变分自编码器【02/3】:训练过程

一、说明

        欢迎来到我们的变分自编码器 (VAE) 系列三部分的第二部分。基于第 1 部分中介绍的介绍和实施细节,本部分将重点关注微调这些模型不可或缺的培训程序。我们将深入探讨每个参数的重要性以及它们对训练过程的贡献。要访问本系列的完整代码,请访问我们的 GitHub 存储库:GitHub - asokraju/ImageAutoEncoder: A repository to learn features from Off Road Navigation Vehicles

        在在本文中,我们概述了设置和管理变分自动编码器 (VAE) 训练过程所涉及的步骤,这是一种复杂的深度学习模型,在与图像生成和修改相关的任务中特别有效。

        训练过程首先接受一系列参数和超参数,这些参数和超参数有助于塑造模型和训练过程的各个方面。这些参数包括图像数据的位置、模型训练日志的存储路径、训练周期数、批量大小和学习率等方面。

        设置这些参数后,我们就建立训练和测试数据的路径。数据被分为一个训练子集,供模型学习,以及一个测试子集,用于衡量模型在以前从未遇到过的数据上的性能。

        接下来,我们设置数据生成器,这是处理大型数据集时的关键组件。这些生成器对训练图像应用一系列转换(称为数据增强),使模型能够更好地泛化并减轻过度拟合。对于测试数据,我们放弃应用这些转换。

        然后,生成器承担按批量大小定义的小批量加载和处理图像的任务。此过程使模型能够逐步增量更新其权重,而不是尝试同时计算整个数据集。

        数据准备就绪并准备好生成器后,VAE 模型现在就可以使用指定的参数和超参数进行训练了。记录此过程以供将来分析、改进和调试之用。

二、解析超参数

        我们首先定义一种方法parse_arguments()来接收模型训练的各种参数和超参数。

def parse_arguments():parser = argparse.ArgumentParser()parser.add_argument('--image-dir', type=str, help='Path to the image data', default= r'Data')parser.add_argument('--logs-dir', type=str, help='Path to store logs', default=r"logs")parser.add_argument('--output-image-shape', type=int, default=56)parser.add_argument('--filters', type=int, nargs='+', default=[32, 64])parser.add_argument('--dense-layer-dim', type=int, default=16)parser.add_argument('--latent-dim', type=int, default=6)parser.add_argument('--beta', type=float, default=1.0)parser.add_argument('--batch-size', type=int, default=128)parser.add_argument('--learning-rate', type=float, default=1e-4)parser.add_argument('--patience', type=int, default=10)parser.add_argument('--epochs', type=int, default=20)parser.add_argument('--train-split', type=float, default=0.8)args = parser.parse_args()return args

        这些参数包括图像数据和日志目录的路径、输出图像形状、卷积层的滤波器数量、密集层和潜在层的尺寸、批量大小、学习率和时期数。

  1. --image-dir:这是训练图像所在的目录。
  2. --logs-dir:这是存储训练日志的位置。这些日志有利于调试和可视化训练进度。
  3. --output-image-shape:定义输出图像的形状。需要注意的是,VAE 需要输出与输入图像形状相同的图像。
  4. --filters:这些是编码器和解码器模型的卷积层的滤波器数量。
  5. --dense-layer-dim:这是编码器模型中密集层的维数。
  6. --latent-dim:这定义了潜在空间的维度。为潜在空间选择合适的大小至关重要,因为它将决定 VAE 表示复杂数据的能力。
  7. --batch-size:一次迭代中使用的训练示例的数量。这会显着影响模型的训练性能。
  8. --learning-rate:优化器在学习时采取的步骤的大小。这需要仔细调整,因为值太小可能会导致收敛速度慢,而值太大可能会阻碍收敛。
  9. --epochs:完整通过训练数据集的次数。正确的纪元数通常取决于模型开始过度拟合的时间。

        接下来,主脚本启动,解析这些参数并设置数据路径,如下所示:

args = parse_arguments()
IMAGE_DIR = args.image_dir
LOGS_DIR = args.logs_dir
all_image_paths = get_image_data(IMAGE_DIR)
image_count = len(all_image_paths)
TRAIN_SPLIT  = args.train_split
OUTPUT_IMAGE_SHAPE = args.output_image_shape
INPUT_SHAPE = (OUTPUT_IMAGE_SHAPE, OUTPUT_IMAGE_SHAPE, 1)
FILTERS = args.filters
DENSE_LAYER_DIM = args.dense_layer_dim
LATENT_DIM = args.latent_dim
BATCH_SIZE = args.batch_size
EPOCHS = args.epochs
LEARNING_RATE = args.learning_rateLOGDIR = os.path.join(LOGS_DIR, datetime.now().strftime("%Y%m%d-%H%M%S"))
os.mkdir(LOGDIR)

三、用于训练和测试的数据提取、预处理和分离

        如前所述,下一步是提取图像数据并将其划分为训练集和测试集。

        我们首先使用该函数将所有图像路径收集到一个列表中get_image_data()。然后我们计算图像总数,以便稍后将它们划分以进行训练和测试。

def get_image_data(all_dirs):# List to store all image file pathsall_dirs = [all_dirs]all_image_paths = []# Loop through all directories and subdirectories in the data directoryfor data_dir in all_dirs:for root, dirs, files in os.walk(data_dir):for file in files:# Check if the file is an image file (you can add more extensions as needed)if file.endswith('.jpg') or file.endswith('.jpeg') or file.endswith('.png'):# If the file is an image file, append its path to the listall_image_paths.append(os.path.join(root, file))print(data_dir)image_count = len(all_image_paths)print("Total number of imges:", image_count)return all_image_paths

        一旦我们有了图像列表,我们就把它分成两部分:较大的部分用于训练(如 所示TRAIN_SPLIT),较小的部分用于测试。这是通过根据分割比对图像路径列表进行切片,创建两个单独的列表来完成的。然后,我们将这些图像路径存储在各自的数据帧中,df_train并且df_test

all_image_paths = get_image_data(IMAGE_DIR)
image_count = len(all_image_paths)
df_train = pd.DataFrame({'image_paths': all_image_paths[:int(image_count*TRAIN_SPLIT)]})
df_test = pd.DataFrame({'image_paths': all_image_paths[int(image_count*TRAIN_SPLIT):]})

        该过程的下一部分涉及准备我们的数据生成器。这些本质上是处理图像加载、预处理和批处理的管道。我们为训练和测试数据定义单独的生成器。对于训练数据,我们使用剪切、缩放、翻转和旋转等图像增强技术来人为地增加数据集的大小和可变性,这有助于在训练过程中更好地泛化。对于测试数据,我们简单地标准化像素值。

        定义数据生成器后,我们将函数应用于flow_from_dataframe()训练和测试数据帧。该函数直接从给定的数据帧生成批量图像,将图像转换为灰度,将其大小调整为所需的形状,最后对它们进行洗牌以确保训练过程中的随机性。

train_datagen_args = dict(rescale=1.0 / 255,  # Normalize pixel values between 0 and 1shear_range=0.2,zoom_range=0.2,horizontal_flip=True,rotation_range=90,width_shift_range=0.1,height_shift_range=0.1,
)
test_datagen_args = dict(rescale=1.0 / 255)train_datagen = ImageDataGenerator(**train_datagen_args)
test_datagen = ImageDataGenerator(**test_datagen_args)
# Use flow_from_dataframe to generate data batches
train_data_generator = train_datagen.flow_from_dataframe(dataframe=df_train,color_mode='grayscale',x_col='image_paths',y_col=None,target_size=(OUTPUT_IMAGE_SHAPE, OUTPUT_IMAGE_SHAPE),  # Specify the desired size of the input imagesbatch_size=BATCH_SIZE,class_mode=None,  # Set to None since there are no labelsshuffle=True  # Set to True for randomizing the order of the images
)test_data_generator = test_datagen.flow_from_dataframe(dataframe=df_test,color_mode='grayscale',x_col='image_paths',y_col=None,target_size=(OUTPUT_IMAGE_SHAPE, OUTPUT_IMAGE_SHAPE),  # Specify the desired size of the input imagesbatch_size=BATCH_SIZE,class_mode=None,  # Set to None since there are no labelsshuffle=True  # Set to True for randomizing the order of the images
)

此过程会产生一批经过处理的图像流,准备用于训练和测试我们的变分自动编码器模型。

四、训练

        建立数据管道后,我们继续构建变分自动编码器模型的关键步骤。这是一个由两部分组成的过程,涉及编码器和解码器的组装,这在本系列的第一部分中进行了讨论。一旦我们将这些单独的组件组合在一起,我们就会将它们组合起来创建完整的 VAE 模型。

        简而言之,我们的编码器模块负责将输入图像转换为一组定义潜在空间分布的参数。然后,解码器从该分布中取出一个点,并将其变换回原始图像空间。因此,VAE 学习将有关输入数据的有用信息编码到潜在空间中,然后可以将其用于各种目的,例如生成新图像、查找相似图像等。

        除了构建模型之外,我们还定义了一个自定义指标类TotalLoss来监控训练期间 VAE 的性能。

# custom metrics
class TotalLoss(Metric):def __init__(self, name="total_loss", **kwargs):super(TotalLoss, self).__init__(name=name, **kwargs)self.total_loss = self.add_weight(name="tl", initializer="zeros")def update_state(self, y_true, y_pred, sample_weight=None):# Compute the total lossz_mean, z_log_var, z, reconstruction = y_predreconstruction_loss = reduce_mean(reduce_sum(binary_crossentropy(y_true, reconstruction), axis=(1, 2)))kl_loss = -0.5 * (1 + z_log_var - tf_square(z_mean) - tf_exp(z_log_var))kl_loss = reduce_mean(reduce_sum(kl_loss, axis=1))total_loss = reconstruction_loss + kl_lossself.total_loss.assign(total_loss)def result(self):return self.total_lossdef reset_states(self):# The state of the metric will be reset at the start of each epoch.self.total_loss.assign(0.)

        该指标计算总损失,其中包括重建损失和 KL 散度损失,提供了 VAE 学习重新创建输入图像和在潜在空间中保持良好分布的整体度量。

encoder, encoder_layers_dim = encoder_model(input_shape = INPUT_SHAPE, filters=FILTERS, dense_layer_dim=DENSE_LAYER_DIM, latent_dim=LATENT_DIM)
decoder = decoder_model(encoder_layers_dim)vae = VAE(encoder, decoder)
vae.compile(optimizer=Adam(learning_rate=LEARNING_RATE), metrics=[TotalLoss()])

进一步细分:

  1. 构建编码器:我们首先调用我们的encoder_model函数,它返回编码器模型及其层的尺寸。编码器模型将输入图像的形状、每个卷积层中使用的滤波器数量、密集层的维度以及潜在空间的维度作为输入。
  2. 构建解码器:接下来,我们调用该decoder_model函数,为其提供编码器层的尺寸。解码器模型使用这些维度构建一系列反映编码器结构的反卷积层,使其能够从潜在空间准确地重建输入图像。
  3. 组合编码器和解码器:一旦我们有了编码器和解码器模型,我们就可以通过将这些模型传递给类来实例化我们的 VAE 模型VAE。此类将编码器和解码器组合成一个端到端模型,可以在我们的图像数据上进行训练。
  4. 定义总损失指标:最后,我们定义一个自定义指标来监控模型在训练期间的总损失。该类TotalLoss是Keras类的子类Metric,它通过将重建损失和KL散度损失相加来计算总损失。这为我们提供了反映模型整体性能的单一值。

TensorFlow 和 Keras 中的回调提供了一种在训练的各个阶段执行某些操作的方法。它们是训练过程的重要组成部分,因为它们允许我们在训练期间添加自定义行为。

class VAECallback(Callback):"""Randomly sample 5 images from validation_data set and shows the reconstruction after each epoch"""def __init__(self, vae, validation_data, log_dir, n=5):self.vae = vaeself.validation_data = validation_dataself.n = nself.log_dir = log_dirdef on_epoch_end(self, epoch, logs=None):# check every 10 epochsif epoch % 10 ==0:# Generate decoded images from the validation inputvalidation_batch = next(iter(self.validation_data))_, _, _, reconstructed_images = self.vae.predict(validation_batch)# Rescale pixel values to [0, 1]reconstructed_images = np.clip(reconstructed_images, 0.0, 1.0)# Plot the original and reconstructed images side by sideplt.figure(figsize=(10, 2*self.n))  # Adjusted the figure sizefor i in range(self.n):plt.subplot(self.n, 2, 2*i+1)plt.imshow(validation_batch[i], cmap='gray')plt.axis('off')plt.subplot(self.n, 2, 2*i+2)plt.imshow(reconstructed_images[i], cmap='gray')plt.axis('off')fig_name = os.path.join(self.log_dir , 'decoded_images_epoch_{:04d}.png'.format(epoch))plt.savefig(fig_name)# plt.show()vae_callback = VAECallback(vae, test_data_generator, LOGDIR)
tensorboard_cb = TensorBoard(log_dir=LOGDIR, histogram_freq=1)
checkpoint_cb = ModelCheckpoint(filepath=vae_path, save_weights_only=True, verbose=1)
earlystopping_cb = EarlyStopping(monitor="total_loss",min_delta=1e-2,patience=5,verbose=1,)

在此代码中,使用了四种类型的回调:

  1. VAECallback:此自定义回调从验证数据集中采样一些图像,并在每个训练周期结束时可视化 VAE 对这些图像的重建。通过保存这些图像,我们可以直观地跟踪模型的性能如何随着时间的推移而提高。
  2. TensorBoard: TensorBoard 是 TensorFlow 附带的可视化工具。此回调记录每个时期的各种指标和参数,允许您在 TensorBoard 中可视化它们。这对于监控培训过程和诊断问题非常有用。
  3. ModelCheckpoint:此回调会按一定的时间间隔保存模型权重,以便您可以使用它们在以后继续训练或评估模型在不同指标上的性能。在这种情况下,它保存迄今为止看到的最佳模型的权重(通过验证损失来衡量)。
  4. EarlyStopping:当监控的指标停止改善时,此回调将停止训练,在本例中为total_loss. 它对于防止过度拟合和减少计算浪费很有用。“patience”参数是指标停止改进后停止之前等待的时期数。

最后,fit在 VAE 模型上调用该函数以开始训练过程。训练数据、纪元数、验证数据和回调将传递给此函数。

history = vae.fit(train_data_generator,epochs=EPOCHS,validation_data=test_data_generator,callbacks=[tensorboard_cb, vae_callback, checkpoint_cb, earlystopping_cb]
)

        在训练过程结束时,history对象包含每个时期的损失和度量值,可用于绘制学习曲线并评估模型。

        因此,总而言之,在本系列的这一部分中,我们讨论了训练变分自动编码器的过程,包括设置数据管道、构建模型、定义自定义指标和使用回调。在本系列的下一部分也是最后一部分中,我们将深入研究超参数调整以优化模型的性能。敬请关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/288223.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android的基础开发

基础开发 listView ListView就是列表条目&#xff0c;可以向下滚动&#xff0c;也可以点击。 首先设置两个视图布局 activity_main2.xml【充当容器{ListView}】 <ListViewandroid:layout_width"match_parent"android:layout_height"match_parent"a…

邮政快递单号查询入口,标记需要的单号记录

批量查询邮政快递单号的物流信息&#xff0c;对需要的单号记录进行标记。 所需工具&#xff1a; 一个【快递批量查询高手】软件 邮政快递单号若干 操作步骤&#xff1a; 步骤1&#xff1a;运行【快递批量查询高手】软件&#xff0c;并登录 步骤2&#xff1a;点击主界面左上角…

生信学院|12月22日《快速产品图像渲染》

课程主题&#xff1a;快速产品图像渲染 课程时间&#xff1a;2023年12月22日 14:00-14:30 主讲人&#xff1a;陈伟 生信科技 售后服务工程师 1、SOLIDWORKS Visualize介绍 2、操作演示 3、答疑 请安装腾讯会议客户端或APP&#xff0c;微信扫描海报中的二维码报名哦~~~ 或…

基于SSM的视康眼镜网店销售系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

PyQt6 QFontDialog字体对话框控件

锋哥原创的PyQt6视频教程&#xff1a; 2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计50条视频&#xff0c;包括&#xff1a;2024版 PyQt6 Python桌面开发 视频教程(无废话版…

大型语言模型:RoBERTa — 一种稳健优化的 BERT 方法

slavahead 一、介绍 BERT模型的出现BERT模型带来了NLP的重大进展。 BERT 的架构源自 Transformer&#xff0c;它在各种下游任务上取得了最先进的结果&#xff1a;语言建模、下一句预测、问答、NER标记等。 尽管 BERT 性能出色&#xff0c;研究人员仍在继续尝试其配置&#xff0…

【MySQL】Sql优化之索引的使用方式(145)

索引分类 1.单值索引 单的意思就是单列的值&#xff0c;比如说有一张数据库表&#xff0c;表内有三个字段&#xff0c;分别是 id name numberNo&#xff0c;我给name 这个字段加一个索引&#xff0c;这就是单值索引&#xff0c;因为只有name 这一列是索引&#xff1b; 一个表…

3d游戏公司选择云电脑进行云办公有哪些优势

随着游戏行业的不断发展&#xff0c;很多的游戏制作公司也遇到了很多的难题&#xff0c;比如硬件更换成本高、团队协同难以及效率低下等问题&#xff0c;那么如何解决游戏行业面临的这些行业痛点&#xff0c;以及游戏制作公司选择云电脑进行云办公有哪些优势&#xff1f;一起来…

Axure中如何使用交互样式交互事件交互动作情形

&#x1f3ac; 艳艳耶✌️&#xff1a;个人主页 &#x1f525; 个人专栏 &#xff1a;《产品经理如何画泳道图&流程图》 ⛺️ 越努力 &#xff0c;越幸运 目录 一、Axure中交互样式 1、什么是交互样式&#xff1f; 2、交互样式的作用&#xff1f; 3、Axure中如何…

【HarmonyOS开发】ArkUI实现下拉刷新/上拉加载

列表下拉刷新、上拉加载更多&#xff0c;不管在web时代还是鸿蒙应用都是一个非常常用的功能&#xff0c;基于ArkUI中TS扩展的声明式开发范式实现一个下拉刷新&#xff0c;上拉加载。 上拉加载、下拉刷新 如果数据量过大&#xff0c;可以使用LazyForEach代替ForEach 高阶组件-…

Leetcode—128.最长连续序列【中等】

2023每日刷题&#xff08;六十四&#xff09; Leetcode—128.最长连续序列 实现代码 class Solution { public:int longestConsecutive(vector<int>& nums) {unordered_set<int> s;for(auto num: nums) {s.insert(num);}int longestNum 0;for(auto num: s) …

Linux Mint 21.3 代号为“Virginia”开启下载

Linux Mint 团队今天放出了 Linux Mint 21.3 Beta ISO 镜像&#xff0c;正式版计划在今年圣诞节发布。 支持 在实验性支持 Wayland 之外&#xff0c;Cinnamon 6.0 版 Linux Mint 21.3 Beta 镜像还带来了其它改进&#xff0c;Nemo 文件夹管理器右键菜单支持下载相关操作。 Cin…