【信号处理】基于变分自编码器(VAE)的图片典型增强方法实现

关于

深度学习中,经常面临图片数据量较小的问题,此时,对数据进行增强,显得比较重要。传统的图片增强方法包括剪切,增加噪声,改变对比度等等方法,但是,对于后端任务的性能提升有限。所以,变分自编码器被用来实现深度数据增强。

变分自编码器的主要缺点在于生成图像过于平滑和模糊,图像细节重建不足。

常见的图像增强方法:https://www.tensorflow.org/tutorials/images/data_augmentation

工具

数据集下载地址: CIFAR-10 and CIFAR-100 datasets

方法实现

加载数据和必要的库函数
import tensorflow.compat.v1.keras.backend as K
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
import tensorflow_datasets as tfds
import keras
from keras.models import Model
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshapextrain , ytrain = tfds.as_numpy(tfds.load('cifar10',split='train',batch_size=-1,as_supervised=True,))
xtest , ytest = tfds.as_numpy(tfds.load('cifar10',split='test',batch_size=-1,as_supervised=True,))
xtrain = (xtrain.astype('float32'))/255
xtest = (xtest.astype('float32'))/255height=32
width=32
channels=3
print(f"Train Shape: {xtrain.shape},Test Shape: {xtest.shape}")
plt.imshow(xtrain[0])

编码器模型搭建
input_shape=(height,width,channels)
latent_dims=3072input_img= Input(shape=input_shape, name='encoder_input')
x=Conv2D(128, 4, padding='same', activation='relu',strides=2)(input_img)
x=Conv2D(256, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(512, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(1024, 4, padding='same', activation='relu',strides=2)(x)
conv_shape = K.int_shape(x)
x=Flatten()(x)
x=Dense(3072, activation='relu')(x)
z_mean=Dense(latent_dims, name='latent_mean')(x)
z_sigma=Dense(latent_dims, name='latent_sigma')(x)def sampler(args):z_mean, z_sigma = argseps = K.random_normal(shape=(K.shape(z_mean)[0], K.int_shape(z_mean)[1]))return z_mean + K.exp(z_sigma / 2) * epsz = Lambda(sampler, output_shape=(latent_dims, ), name='z')([z_mean, z_sigma])encoder = Model(input_img, [z_mean, z_sigma, z], name='encoder')
print(encoder.summary())

 解码器模型构建
decoder_input = Input(shape=(latent_dims, ), name='decoder_input')
x = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
x = Conv2DTranspose(256, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(128, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(64, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(3, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)
decoder = Model(decoder_input, x, name='decoder')
decoder.summary()
z_decoded = decoder(z)class CustomLayer(keras.layers.Layer):def vae_loss(self, x, z_decoded):x = K.flatten(x)z_decoded = K.flatten(z_decoded)# Reconstruction loss (as we used sigmoid activation we can use binarycrossentropy)recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)# KL divergencekl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mean) - K.exp(z_sigma), axis=-1)return K.mean(recon_loss + kl_loss)# add custom loss to the classdef call(self, inputs):x = inputs[0]z_decoded = inputs[1]loss = self.vae_loss(x, z_decoded)self.add_loss(loss, inputs=inputs)return x

 

整体模型构建
y = CustomLayer()([input_img, z_decoded])vae = Model(input_img, y, name='vae')
vae.compile(optimizer='adam', loss=None)
vae.summary()

 

模型训练

history=vae.fit(xtrain, verbose=2, epochs = 100, batch_size = 64, validation_split = 0.2)
 训练可视化
f = plt.figure(figsize=(10,7))
f.add_subplot()
#Adding Subplot
plt.plot(history.epoch, history.history['loss'], label = "loss") # Loss curve for training set
plt.plot(history.epoch, history.history['val_loss'], label = "val_loss") # Loss curve for validation setplt.title("Loss Curve",fontsize=18)
plt.xlabel("Epochs",fontsize=15)
plt.ylabel("Loss",fontsize=15)
plt.grid(alpha=0.3)
plt.legend()
plt.savefig("VAE_Loss_Trial5.png")
plt.show()

 中间编码特征可视化
mu, _, _ = encoder.predict(xtest)
#Plot dim1 and dim2 for mu
plt.figure(figsize=(10, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=ytest, cmap='brg')
plt.xlabel('dim 1')
plt.ylabel('dim 2')
plt.colorbar()
plt.show()
plt.savefig("VAE_Colourbar_Trial5.png")

 

数据增强生成
#RANDOM GENERATION
def generate():n=20figure = np.zeros((width *2 , height * 10, channels))#Create a Grid of latent variables, to be provided as inputs to decoder.predict
#Creating vectors within range -5 to 5 as that seems to be the range in latent spacefor k in range(2):for l in range(10):z_sample =random.rand(3072)z_out=np.array([z_sample])x_decoded = decoder.predict(z_out)digit = x_decoded[0].reshape(width, height, channels)figure[k * width: (k + 1) * width,l * height: (l + 1) * height] = digitplt.figure(figsize=(10, 10))
#Reshape for visualizationfig_shape = np.shape(figure)figure = figure.reshape((fig_shape[0], fig_shape[1],3))plt.imshow(figure, cmap='gnuplot2')plt.show()  plt.savefig("VAE_imagesgen_Trial5.png")

解码器图像重建
#IMAGE RECONSTRUCT USING TEST SET IMGS
def reconstruct():num_imgs = 6rand = np.random.randint(1, xtest.shape[0]-6) xtestsample = xtest[rand:rand+num_imgs]x_encoded = np.array(encoder.predict(xtestsample))latent_xtest=x_encoded[2]x_decoded = decoder.predict(latent_xtest)rows = 2 # defining no. of rows in figurecols = 3 # defining no. of colums in figurecell_size = 1.5f = plt.figure(figsize=(cell_size*cols,cell_size*rows*2)) # defining a figure f.tight_layout()for i in range(rows):for j in range(cols): f.add_subplot(rows*2,cols, (2*i*cols)+(j+1)) # adding sub plot to figure on each iterationplt.imshow(xtestsample[i*cols + j]) plt.axis("off")for j in range(cols): f.add_subplot(rows*2,cols,((2*i+1)*cols)+(j+1)) # adding sub plot to figure on each iterationplt.imshow(x_decoded[i*cols + j]) plt.axis("off")f.suptitle("Autoencoder Results - Cifar10",fontsize=18)plt.savefig("VAE_imagesrecons_Trial5.png")plt.show()

 

代码获取

已经附在文章底部,自行拿取。

项目开发,相关问题咨询,欢迎交流沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/591816.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试复盘1 - 测试相关(实习)

写在前:hello,大家早中晚上好~这里是西西,最近有在准备测试相关的面试,特此开设了新的篇章,针对于面试中的问题来做一下复盘,会把我自己遇到的问题进行整理,除此之外还会进行对一些常见面试题的…

【解决】Unity Profile | FindMainCamera

开发平台:Unity 2020.3.7f1c1 关键词:FindMainCamera   问题背景 ModelViewer 是开发者基于 UnityEngine 编写的相机控制组件。ModelView.Update 中调度52次并触发3次GC.Collect。显然并不期望并尽可能避免 Update 造成的GC 问题。事实上 FindMainCame…

稀碎从零算法笔记Day35-LeetCode:字典序的第K小数字

要考虑完结《稀碎从零》系列了哈哈哈 这道题和【LC.42 接雨水】,我愿称之为【笔试界的颜良&文丑】 题型:字典树、前缀获取、数组、树的先序遍历 链接:440. 字典序的第K小数字 - 力扣(LeetCode) 来源&#xff1…

【面试经典150 | 动态规划】不同路径 II

文章目录 写在前面Tag题目1方法一:动态规划方法二:空间优化 题目2方法一:动态规划空间优化 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主…

跟着Kimi Chat学习提示工程Prompt Engineering!让AI更高效地给你打工!

大家好,我是木易,一个持续关注AI领域的互联网技术产品经理,国内Top2本科,美国Top10 CS研究生,MBA。我坚信AI是普通人变强的“外挂”,所以创建了“AI信息Gap”这个公众号,专注于分享AI全维度知识…

DTFT及其反变换的直观理解

对于离散时间傅里叶变换(DTFT)及其反变换的讲解,教材里通常会先给出DTFT正变换的公式,再举个DTFT的简单变换例子,推导一下DTFT的性质,然后给出DTFT反变换的公式,再证明一下正变换和反变化的对应关系。总的来说就是&…

conda修改默认安装python版本为指定版本

1.查看conda中当前的python版本号: 打开Anaconda Powershell Prompt 输入python -V 回车会输出版本号 2.查看conda所支持的python版本,并选择指定版本安装 选择一个3.9.13版本的进行安装 安装命令: conda install python=3.9.13 如果一直卡在这个画面,请使用管理员权限运行…

python ---- %r %s格式输出的区别

在python中, % s和 % r是我们常用的格式符,它们的用法基本一致,但作用却不尽相同,下面简要说明一下两者的区别: 1. % s是将对象 / 变量传递到str()方法中,并将其转化为面向用户的可阅读的格式。 2. % r是将…

聊一聊单点登录

互联网工程师 一、单点登录的概念 单点登录(Single Sign-On,简称SSO)是一种身份认证和授权技术,旨在解决用户在访问多个应用系统时需要重复登录的问题。该技术允许用户在一个应用系统中完成登录后,就可以访问其他相互信…

C#,简单,精巧,实用的按类型删除指定文件的工具软件

点击下载本文软件(积分): https://download.csdn.net/download/beijinghorn/89059141https://download.csdn.net/download/beijinghorn/89059141 下载审核通过之前,请从百度网盘下载(无积分):…

web框架的本质初识

1.什么是HTML HTML是一个超文本语言,是一种创建网页结构的标记语言。就是你女朋友化妆之后的样子 2.什么是HTTP协议 是一种用于在Web上传输数据的协议。它是客户端和服务器之间进行相互通信的基础的协议 3.HTTP的特点 无连接:每个http请求都是独立的…

母婴商品销售数据分析-Python数据分析项目

文章目录 母婴商品销售数据分析项目介绍项目背景分析目的 数据准备导入数据数据清洗 数据分析整体市场情况第一季度销量下滑原因第四季度销量上升原因复购率 商品销售情况 母婴商品销售数据分析 项目介绍 项目背景 政策Politics:国家发展改革委2013年5月28日表示…