【PyTorch][chapter 20][李宏毅深度学习]【无监督学习][ GAN]【实战】

前言

 本篇主要是结合手写数字例子,结合PyTorch 介绍一下Gan 实战

第一轮训练效果

第20轮训练效果,已经可以生成数字了

68 轮


目录: 

  1.   谷歌云服务器(Google Colab)
  2.   整体训练流程
  3.   Python 代码

一  谷歌云服务器(Google Colab)

     个人用的一直是联想小新笔记本,虽然非常稳定方便。但是现在跑深度学习,性能确实有点跟不上. 

   1.1    打开谷歌云服务器(Google Colab)

      https://colab.research.google.com/

    1. 2  新建笔记

                 

1

 1.4  选择T4GPU 

1.5  点击运行按钮

可以看到当前硬件的情况

     


二  整体训练流程


三    PyTorch 例子

# -*- coding: utf-8 -*-
"""
Created on Fri Mar  1 13:27:49 2024@author: chengxf2
"""
import torch.optim as optim #优化器
import numpy as np 
import matplotlib.pyplot  as plt
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn#第一步加载手写数字集
def loadData():#同时归一化数据集(-1,1)style = transforms.Compose([transforms.ToTensor(),   #0-1 归一化0-1, channel,height,widthtransforms.Normalize(mean=0.5, std=0.5) #变成了-1,1 ])trainData = torchvision.datasets.MNIST('data',train=True,transform=style,download=True)dataloader = torch.utils.data.DataLoader(trainData,batch_size= 16,shuffle=True)imgs,_ = next(iter(dataloader))#torch.Size([64, 1, 28, 28])print("\n imgs shape ",imgs.shape)return dataloaderclass Generator(nn.Module):'''定义生成器输入:z 随机噪声[batch, input_size]输出:x: 图片 [batch, height, width, channel]'''def __init__(self,input_size):super(Generator,self).__init__()self.net = nn.Sequential(nn.Linear(in_features = input_size , out_features =256),nn.ReLU(),nn.Linear(in_features = 256 , out_features =512),nn.ReLU(),nn.Linear(in_features = 512 , out_features =28*28),nn.Tanh())def forward(self, z):# z 随机输入[batch, dim]x = self.net(z)#[batch, height, width, channel]#print(x.shape)x = x.view(-1,28,28,1)return xclass Discriminator(nn.Module):'''定义鉴别器输入:x: 图片 [batch, height, width, channel]输出:y:  二分类图片的概率: BCELoss 计算交叉熵损失'''def __init__(self):super(Discriminator,self).__init__()#开始的维度和终止的维度,默认值分别是1和-1self.flatten = nn.Flatten()self.net = nn.Sequential(nn.Linear(in_features = 28*28 , out_features =512),nn.LeakyReLU(), #负值的时候保留梯度信息nn.Linear(in_features = 512 , out_features =256),nn.LeakyReLU(),nn.Linear(in_features = 256 , out_features =1),nn.Sigmoid())def forward(self, x):x = self.flatten(x)#print(x.shape)out =self.net(x)return outdef gen_img_plot(model, epoch, test_input):out = model(test_input).detach().cpu()out = out.numpy()imgs = np.squeeze(out)fig = plt.figure(figsize=(4,4))for i in range(out.shape[0]):plt.subplot(4,4,i+1)img = (imgs[i]+1)/2.0#[-1,1]plt.imshow(img)plt.axis('off')plt.show()def train():#1 初始化参数device ='cuda' if torch.cuda.is_available() else 'cpu'#2 加载训练数据dataloader = loadData()test_input  = torch.randn(16,100,device=device)#3 超参数maxIter = 20 #最大训练次数input_size = 100batchNum = 16input_size =100#4 初始化模型gen = Generator(100).to(device)dis = Discriminator().to(device)#5 优化器,损失函数d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)g_optim = torch.optim.Adam(gen.parameters(),lr=1e-4)loss_fn = torch.nn.BCELoss()#6 loss 变化列表D_loss =[]G_loss= []for epoch in range(0,maxIter):d_epoch_loss = 0.0g_epoch_loss  =0.0#count = len(dataloader)for step ,(realImgs, _) in enumerate(dataloader):realImgs = realImgs.to(device)random_noise = torch.randn(batchNum, input_size).to(device)#先训练判别器d_optim.zero_grad()real_output = dis(realImgs)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))d_real_loss.backward()#不要训练生成器,所以要生成器detachfake_img = gen(random_noise)fake_output = dis(fake_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossd_optim.step()#优化生成器g_optim.zero_grad()fake_output = dis(fake_img.detach())g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss+= d_lossg_epoch_loss+= g_losscount = 16       with torch.no_grad():d_epoch_loss/=countg_epoch_loss/=countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)gen_img_plot(gen, epoch, test_input)print("Epoch: ",epoch)print("-----finised-----")if __name__ == "__main__":train()

参考:

10.完整课程简介_哔哩哔哩_bilibili

理论【PyTorch][chapter 19][李宏毅深度学习]【无监督学习][ GAN]【理论】-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/503300.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习_16_权重衰退调整过拟合

所谓过拟合即模型复杂度较高,但用于训练数据集过于简单,最后导致模型将过多无用渣质作为学习对象 这个在上篇 深度学习_15_过拟合&欠拟合 已经详细介绍,以下便不再赘述。 上篇提到要想解决过拟合现象可以试着降低模型复杂度&#xff0c…

Windows 10 合并磁盘分区 (G and H)

Windows 10 合并磁盘分区 [G and H] 1. 设备和驱动器2. 计算机 -> 管理 -> 存储 -> 磁盘管理3. 删除卷4. 新建简单卷5. 设备和驱动器References 1. 设备和驱动器 2. 计算机 -> 管理 -> 存储 -> 磁盘管理 3. 删除卷 H: -> right-click -> 删除卷 H: 变…

html2canvas 将DOM节点转成图片

官网地址:html2canvas - Screenshots with JavaScript 将js文件保存到本地 可以新建一个txt文件,然后丢进去修改后缀名称即可。 在项目中引入js文件: import html2canvas from "../html2canvas.min.js" 这是我准备画的DOM节点。…

【AIGC】微笑的秘密花园:红玫瑰与少女的美好相遇

在这个迷人的画面中,我们目睹了一个迷人的时刻,女子则拥有一头柔顺亮丽的秀发,明亮的眼睛如同星河般璀璨,优雅而灵动,她的微笑如春日暖阳,温暖而又迷人。站在红玫瑰花瓣的惊人洪水中。 在一片湛蓝无云的晴…

【AI Agent系列】【MetaGPT多智能体学习】5. 多智能体案例拆解 - 基于MetaGPT的智能体辩论(附完整代码)

本系列文章跟随《MetaGPT多智能体课程》(https://github.com/datawhalechina/hugging-multi-agent),深入理解并实践多智能体系统的开发。 本文为该课程的第四章(多智能体开发)的第三篇笔记。主要是对课程刚开始环境搭…

C#,哈夫曼编码(Huffman Code)压缩(Compress )与解压缩(Decompress)算法与源代码

David A. Huffman 1 哈夫曼编码简史(Huffman code) 1951年,哈夫曼和他在MIT信息论的同学需要选择是完成学期报告还是期末考试。导师Robert M. Fano给他们的学期报告的题目是,寻找最有效的二进制编码。由于无法证明哪个已有编码是…

Facebook直播网络需要满足什么条件

Facebook直播已经成为了企业、个人和组织开展在线活动、互动和营销的重要平台之一。然而,要确保Facebook直播的顺利进行和观众体验的良好,需要满足一系列关键条件。本文将探讨Facebook直播网络 需要满足的关键条件。 1、稳定的互联网连接: 稳…

7. 构建简单 IPv6 网络

7.1 实验介绍 7.1.1 关于本实验 IPv6(Internet Protocol Version 6)也被称为IPng(IP Next Generation)。它是Internet工程任务组IETF(Internet Engineering Task Force)设计的一套规范,是IPv4…

babylonjs入门-半球光

基于babylonjs封装的一些功能和插件 ,希望有更多的小伙伴一起玩babylonjs; 欢迎加群(点击群号传送):464146715 官方文档 中文文档 案例传送门 懒得打字 粘贴复制 一气呵成

OSCP靶场--Craft

OSCP靶场–Craft 考点(1.odt恶意宏文档getshell 2.SeImpersonatePrivilege土豆提权【PrintSpoofer】) 1.nmap扫描 nmap -Pn -sCV — open -p- — min-rate 10000 -oN nmap/open 192.168.249.169 Starting Nmap 7.92 ( https://nmap.org ) at 2022–10–23 06:58 EDT Nmap sc…

如何克隆树莓派系统到较小的硬盘/SD卡上(如何分区、设置修复引导)

最近有个老固态硬盘空下来了,虽然写入速度没那么快,但是足够满足千兆网络了,所以我就想把现在给树莓派使用的固态硬盘换下来。由于一些设置很浪费时间,所以我不打算重装系统。此外这个老固态是 120GB 的,要小于正在使用…

【Unity】机器人末端执行器仿真

机械手臂的末端执行器使用多项式来计算转动角度可能有几个原因: 精确控制:机械臂的运动通常需要高度的精确性,特别是在精密工作或复杂运动轨迹的情况下。多项式,特别是高阶的,可以很好地近似复杂的非线性关系和运动轨迹…