【深度学习】四种天气分类 模版函数 从0到1手敲版本

引入该引入的库

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import torch.optim as optim
%matplotlib inline
import os
import shutil
import glob
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

注意:os.environ[“KMP_DUPLICATE_LIB_OK”]=“TRUE” 必须要引入否则用plt出错

数据集整理

img_dir = r"F:\播放器\1、pytorch全套入门与实战项目\课程资料\参考代码和部分数据集\参考代码\参考代码\29-42节参考代码和数据集\四种天气图片数据集\dataset2"
base_dir = r"./dataset/4weather"img_list = glob.glob(img_dir+"/*.*")
test_dir = "test"
train_dir = "train"
species = ["cloudy","rain","shine","sunrise"]
for idx,img_path in enumerate(img_list):_,img_name = os.path.split(img_path)if idx%5==0:for specie in species:if img_path.find(specie) > -1:dst_dir = os.path.join(test_dir,specie)os.makedirs(dst_dir,exist_ok=True)dst_path = os.path.join(dst_dir,img_name)else:for specie in species:if img_path.find(specie) > -1:dst_dir = os.path.join(train_dir,specie)os.makedirs(dst_dir,exist_ok=True)dst_path = os.path.join(dst_dir,img_name)shutil.copy(img_path,dst_path)

生成测试和训练的文件夹,
目录结构如下:
在这里插入图片描述
rain 下面就是图片了
在这里插入图片描述

构建ds和dl

from torchvision import transforms
transform = transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
train_ds=torchvision.datasets.ImageFolder(train_dir,transform)
test_ds = torchvision.datasets.ImageFolder(train_dir,transform)

在这里插入图片描述
在这里插入图片描述
一张图片效果,这是rain图片 这里需要转换维度,把channel放到最后。同时把数据拉到0-1之间,原本std 和mean 【0.5,0,5】数据在-0.5~0.5之间
在这里插入图片描述
类的映射
在这里插入图片描述

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):img = (img.permute(1, 2, 0).numpy() + 1)/2plt.subplot(2, 3, i+1)plt.title(id_to_class.get(label.item()))plt.imshow(img)

这个方法要学会
在这里插入图片描述

定义网络

class Net(nn.Module):def __init__(self) -> None:super().__init__()self.conv1 = nn.Conv2d(3,16,3)self.conv2 = nn.Conv2d(16,32,3)self.conv3 = nn.Conv2d(32,64,3)self.pool = nn.MaxPool2d(2,2)self.dropout = nn.Dropout(0.3)self.fc1 = nn.Linear(64*10*10,1024)self.fc2 = nn.Linear(1024,4)def forward(self,x):x = F.relu(self.conv1(x))x = self.pool(x)x = F.relu(self.conv2(x))x = self.pool(x)x = F.relu(self.conv3(x))x = self.pool(x)x = self.dropout(x)# print(x.size()) 这里是可以计算出来的,需要掌握计算方法x = x.view(-1,64*10*10)x = F.relu(self.fc1(x))x = self.dropout(x)return self.fc2(x)
model = Net()        
preds = model(imgs)
preds.shape, preds

在这里插入图片描述
定义损失函数和优化函数:

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(),lr=0.001)

定义网络

def fit(epoch, model, trainloader, testloader):correct = 0total = 0running_loss = 0for x, y in trainloader:if torch.cuda.is_available():x, y = x.to('cuda'), y.to('cuda')y_pred = model(x)loss = loss_fn(y_pred, y)optim.zero_grad()loss.backward()optim.step()with torch.no_grad():y_pred = torch.argmax(y_pred, dim=1)correct += (y_pred == y).sum().item()total += y.size(0)running_loss += loss.item()epoch_loss = running_loss / len(trainloader.dataset)epoch_acc = correct / totaltest_correct = 0test_total = 0test_running_loss = 0 with torch.no_grad():for x, y in testloader:if torch.cuda.is_available():x, y = x.to('cuda'), y.to('cuda')y_pred = model(x)loss = loss_fn(y_pred, y)y_pred = torch.argmax(y_pred, dim=1)test_correct += (y_pred == y).sum().item()test_total += y.size(0)test_running_loss += loss.item()epoch_test_loss = test_running_loss / len(testloader.dataset)epoch_test_acc = test_correct / test_totalprint('epoch: ', epoch, 'loss: ', round(epoch_loss, 3),'accuracy:', round(epoch_acc, 3),'test_loss: ', round(epoch_test_loss, 3),'test_accuracy:', round(epoch_test_acc, 3))return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

训练:

epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,model,train_dl,test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)
epoch:  0 loss:  0.043 accuracy: 0.714 test_loss:  0.029 test_accuracy: 0.809
epoch:  1 loss:  0.03 accuracy: 0.807 test_loss:  0.023 test_accuracy: 0.867
epoch:  2 loss:  0.024 accuracy: 0.857 test_loss:  0.018 test_accuracy: 0.888
epoch:  3 loss:  0.021 accuracy: 0.869 test_loss:  0.017 test_accuracy: 0.894
epoch:  4 loss:  0.018 accuracy: 0.886 test_loss:  0.014 test_accuracy: 0.921
epoch:  5 loss:  0.017 accuracy: 0.897 test_loss:  0.022 test_accuracy: 0.869
epoch:  6 loss:  0.013 accuracy: 0.923 test_loss:  0.008 test_accuracy: 0.944
epoch:  7 loss:  0.009 accuracy: 0.947 test_loss:  0.011 test_accuracy: 0.924
epoch:  8 loss:  0.006 accuracy: 0.966 test_loss:  0.004 test_accuracy: 0.988
epoch:  9 loss:  0.004 accuracy: 0.979 test_loss:  0.002 test_accuracy: 0.998
epoch:  10 loss:  0.004 accuracy: 0.979 test_loss:  0.005 test_accuracy: 0.966

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
比较重要的点,
1.分类的数据集布局要记住
2.图片经过conv2 多次后的值要会算 todo
3.图片展示的方法要会

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/563010.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker搭建LNMP环境实战(02):Win10下安装VMware

实战开始,先安装 VMware 虚拟机。话不多说,上手就干! 1、基本环境检查 1.1、本机Bios是否支持虚拟化 进入:任务管理器- 性能,查看“虚拟化”是否启用,如果已启用,则满足要求,如果未…

string类的详细模拟实现

string类的模拟实现 文章目录 string类的模拟实现前言1. 类的框架设计2. 构造函数与析构函数3. 拷贝构造与重载赋值运算符函数4. 运算符重载5. 成员函数6. 迭代器的实现7. 非成员函数8. 单元测试总结 前言 ​ 在现代编程中,字符串处理是每个程序员都会遇到的基本任…

亚稳态及其解决办法

异步电路 亚稳态 亚稳态亚稳态的产生原因什么是同步异步信号怎么消除亚稳态 亚稳态 在数字电路中,每一位数据不是1(高电平)就是0(低电平)。当然对于具体的电路来说,并非1(高电平)就是…

【JavaEE初阶系列】——带你了解volatile关键字以及wait()和notify()两方法背后的原理

目录 🚩volatile关键字 🎈volatile 不保证原子性 🎈synchronized 也能保证内存可见性 🎈Volatile与Synchronized比较 🚩wait和notify 🎈wait()方法 💻wait(参数)方法 🎈noti…

C# WPF编程-控件

C# WPF编程-控件 概述WPF控件类别包括以下控件:背景画刷和前景画刷字体文本装饰和排版字体继承字体替换字体嵌入文本格式化模式鼠标光标 内容控件Label(标签)Button(按钮) 概述 在WPF领域,控件通常被描述为…

xilinx的高速接口构成原理和连接结构

本文来源: V3学院 尤老师的培训班笔记【高速收发器】xilinx高速收发器学习记录Xilinx-7Series-FPGA高速收发器使用学习—概述与参考时钟GT Transceiver的总体架构梳理 文章目录 一、概述:二、高速收发器结构:2.1 QUAD2.1.1 时钟2.1.2 CHANNEL…

计算机视觉之三维重建(2)---摄像机标定

文章目录 一、回顾线代1.1 线性方程组的解1.2 齐次线性方程组的解 二、透镜摄像机的标定2.1 标定过程2.2 提取摄像机参数2.3 参数总结 三、径向畸变的摄像机标定3.1 建模3.2 求解 四、变换4.1 2D平面上的欧式变换4.2 2D平面上的相似变换和仿射变换4.3 2D平面上的透射变换4.4 3D…

MySQL 8.0-索引- 不可见索引(invisible indexes)

概述 MySQL 8.0引入了不可见索引(invisible index),这个在实际工作用还是用的到的,我觉得可以了解下。 在介绍不可见索引之前,我先来看下invisible index是个什么或者定义。 我们依然使用拆开来看,然后再把拆出来的词放到MySQL…

Spring Boot从入门到实战

课程介绍 本课程从SpringBoot的最基础的安装、配置开始到SpringBoot的日志管理、Web业务开发、数据存储、数据缓存,安全控制及相关企业级应用,全程案例贯穿,案例每一步的都会讲解实现思路,全程手敲代码实现。让你不仅能够掌Sprin…

HTTPS协议的工作原理:保护网络通信的安全盾牌

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

网络安全笔记-day7,共享文件服务器

文件共享服务器 准备阶段 打开虚拟机win2003 创建文件 D:. —share   –down   |  test1.txt   |   —up     01xxx.txt     02xxx.txt 配置IP win2003 192.168.1.10 255.255.255.0 winxp 192.168.1.20 255.255.255.0 创建共享文件夹 创建共享&#xff1…

2024智能EDM邮件营销系统使用攻略

在数字化营销领域,智能EDM(Electronic Direct Mail)邮件营销作为一种高效、精准的推广方式,正日益受到企业的高度重视。而要实现这一策略的成功落地,一个高可靠性和高稳定性的专业邮件发送平台则是不可或缺的关键环节。…