pytorch03:transforms常见数据增强操作

目录

  • 一、数据增强
  • 二、transforms--Crop裁剪
    • 2.1 transforms.CenterCrop
    • 2.2 transforms.RandomCrop
    • 2.3 RandomResizedCrop
    • 2.4 FiveCrop和TenCrop
  • 三、transforms—Flip翻转、旋转
    • 3.1RandomHorizontalFlip和RandomVerticalFlip
    • 3.2 RandomRotation
  • 四、transforms —图像变换
    • 4.1 transforms.Pad
    • 4.2 transforms.ColorJitter
    • 4.3 Grayscale和RandomGrayscale
    • 4.4 RandomAffine
    • 4.5 RandomErasing
  • 五、transforms的操作
    • 5.1 transforms.RandomChoice
    • 5.2 transforms.RandomApply
    • 5.3 transforms.RandomOrder
  • 六、自定义transforms
    • 6.1 自定义transforms要素
    • 6.2 通过类实现多参数传入
    • 6.3 椒盐噪声
    • 6.4 自定义transforms代码实现
  • 七、数据增强策略
    • 数据增强代码实现

一、数据增强

   数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。如下是对一张图片常见的增强操作例如:旋转、裁剪、像素抖动。
在这里插入图片描述

二、transforms–Crop裁剪

2.1 transforms.CenterCrop

功能:从图像中心裁剪图片
• size:所需裁剪图片尺寸

2.2 transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为size的图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• padding:设置填充大小
  当为a时,上下左右均填充a个像素,
  当为(a, b)时,上下填充b个像素,左右填充a个像素,
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• pad_if_need:若图像小于设定size,则填充
• padding_mode:填充模式,有4种模式
  1、constant:像素值由fill设定
  2、edge:像素值由图像边缘像素决定
  3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
  4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]
• fill:constant时,设置填充的像素值

2.3 RandomResizedCrop

功能:随机大小、长宽比裁剪图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• scale:随机裁剪面积比例, 默认(0.08, 1)
• ratio:随机长宽比,默认(3/4, 4/3)
• interpolation:插值方法
PIL.Image.NEAREST
PIL.Image.BILINEAR
PIL.Image.BICUBIC

2.4 FiveCrop和TenCrop

  功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• vertical_flip:是否垂直翻转

三、transforms—Flip翻转、旋转

3.1RandomHorizontalFlip和RandomVerticalFlip

在这里插入图片描述

功能:依概率水平(左右)或垂直(上下)翻转图片
• p:翻转概率

3.2 RandomRotation

功能:随机旋转图片
在这里插入图片描述
在这里插入图片描述

• degrees:旋转角度
  当为a时,在(-a,a)之间选择旋转角度
  当为(a, b)时,在(a, b)之间选择旋转角度
• resample:重采样方法
• expand:是否扩大图片,以保持原图

四、transforms —图像变换

4.1 transforms.Pad

功能:对图片边缘进行填充
在这里插入图片描述
• padding:设置填充大小
  当为a时,上下左右均填充a个像素
  当为(a, b)时,上下填充b个像素,左右填充a个像素
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• padding_mode:填充模式,有4种模式,constant、edge、reflect和symmetric
• fill:constant时,设置填充的像素值,(R, G, B) or (Gray)

4.2 transforms.ColorJitter

功能:调整亮度、对比度、饱和度和色相
在这里插入图片描述

• brightness:亮度调整因子
  当为a时,从[max(0, 1-a), 1+a]中随机选择
  当为(a, b)时,从[a, b]中
• contrast:对比度参数,同brightness
• saturation:饱和度参数,同brightness
• hue:色相参数,当为a时,从[-a, a]中选择参数,注: 0<= a <= 0.5
        当为(a, b)时,从[a, b]中选择参数,注:-0.5 <= a <= b <= 0.5

4.3 Grayscale和RandomGrayscale

功能:依概率将图片转换为灰度图
在这里插入图片描述
• num_ouput_channels:输出通道数只能设1或3
• p:概率值,图像被转换为灰度图的概率

4.4 RandomAffine

功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
在这里插入图片描述
在这里插入图片描述
• degrees:旋转角度设置
• translate:平移区间设置,如(a, b), a设置宽(width),b设置高(height)
    图像在宽维度平移的区间为 -img_width * a < dx < img_width * a
• scale:缩放比例(以面积为单位)
• fill_color:填充颜色设置

4.5 RandomErasing

功能:对图像进行随机遮挡
在这里插入图片描述

• p:概率值,执行该操作的概率
• scale:遮挡区域的面积
• ratio:遮挡区域长宽比
• value:设置遮挡区域的像素值,(R, G, B) or (Gray)

五、transforms的操作

5.1 transforms.RandomChoice

功能:从一系列transforms方法中随机挑选一个

transforms.RandomChoice([transforms1, transforms2, transforms3])

5.2 transforms.RandomApply

功能:依据概率执行一组transforms操作

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

5.3 transforms.RandomOrder

功能:对一组transforms操作打乱顺序

transforms.RandomOrder([transforms1, transforms2, transforms3])

六、自定义transforms

6.1 自定义transforms要素

1.仅接收一个参数,返回一个参数
2.注意上下游的输出与输入
当前transforms的输入是上一个transforms的输出,所以要保证数据类型匹配:
在这里插入图片描述

6.2 通过类实现多参数传入

在这里插入图片描述

在Python中,__call__是一个特殊的方法,用于使一个对象可以像函数一样被调用。如果一个类定义了__call__方法,那么实例化的对象就可以被当作函数一样调用,而调用的实际上是__call__方法。

class CallableClass:def __init__(self):print("Initializing the CallableClass")def __call__(self, *args, **kwargs):print("Calling the CallableClass with arguments:", args, kwargs)# 实例化对象
obj = CallableClass()# 调用对象,实际上调用了__call__方法
obj(1, 2, 3, keyword_arg="hello")

上面的例子中,CallableClass定义了__call__方法,这意味着实例obj可以像函数一样被调用。当你调用obj(1, 2, 3, keyword_arg=“hello”)时,实际上是在调用obj.call(1, 2, 3, keyword_arg=“hello”)。

6.3 椒盐噪声

椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点, 白点称为盐噪声,黑色为椒噪声
信噪比(Signal-Noise Rate, SNR)是衡量噪声的比例,图像中为图像像素的占比,从下图可以看出,信噪比越小,图片丢失的像素越多。
在这里插入图片描述

6.4 自定义transforms代码实现

class AddPepperNoise(object):"""增加椒盐噪声Args:snr (float): Signal Noise Rate 信噪比p (float): 概率值,依概率执行该操作Attributes:snr (float): 信噪比p (float): 操作执行的概率"""def __init__(self, snr, p=0.9):# 确保传入的snr和p是float类型assert isinstance(snr, float) and isinstance(p, float)self.snr = snrself.p = pdef __call__(self, img):"""对图像应用椒盐噪声操作。Args:img (PIL Image): PIL Image对象Returns:PIL Image: 处理后的PIL Image对象"""# 根据概率决定是否执行噪声操作if random.uniform(0, 1) < self.p:img_ = np.array(img).copy()h, w, c = img_.shapesignal_pct = self.snrnoise_pct = (1 - self.snr)# 生成噪声掩码,表示每个像素是原始图像、盐噪声还是椒噪声mask = np.random.choice((0, 1, 2), size=(h, w, 1),p=[signal_pct, noise_pct / 2., noise_pct / 2.])mask = np.repeat(mask, c, axis=2)# 根据噪声类型修改图像像素值img_[mask == 1] = 255  # 盐噪声img_[mask == 2] = 0    # 椒噪声# 将NumPy数组转换回PIL Image对象,并确保数据类型为uint8,颜色通道为RGBreturn Image.fromarray(img_.astype('uint8')).convert('RGB')else:return img

在这里插入图片描述

七、数据增强策略

原则:让训练集与测试集更接近可以使用下面这些方法
• 空间位置:平移
• 色彩:灰度图,色彩抖动
• 形状:仿射变换
• 上下文场景:遮挡,填充

例如我们训练集白猫比较多,可以改变白猫色彩,让白猫的颜色接近黑猫。
在这里插入图片描述

数据增强代码实现

要求:使用第四套RMB进行训练,要求能对第5套RMB识别正确。

我们只进行普通的图片处理训练好的模型,发现将第五套100元都识别成一元,因为第四套人民币的1元和第五套人民的100元颜色相近,所以会导致识别错误:
在这里插入图片描述
解决方法,将所有训练集颜色都进行灰度处理,代码修改如下:

train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),  #图片灰度化transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])

修改后的预测结果如下:
在这里插入图片描述
训练完整代码如下:

# -*- coding: utf-8 -*-import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from lenet import LeNet
from my_dataset import RMBDataset
from common_tools import transform_invertdef set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1# ============================ step 1/5 数据 ============================split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.RandomCrop(32, padding=4),transforms.RandomGrayscale(p=0.9),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================net = LeNet(classes=2)
net.initialize_weights()# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.net.train()for i, data in enumerate(train_loader):# forwardinputs, labels = dataoutputs = net(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.net.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = dataoutputs = net(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().sum().numpy()loss_val += loss.item()valid_curve.append(loss_val)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()# ============================ inference ============================BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)for i, data in enumerate(valid_loader):# forwardinputs, labels = dataoutputs = net(inputs)_, predicted = torch.max(outputs.data, 1)rmb = 1 if predicted.numpy()[0] == 0 else 100img_tensor = inputs[0, ...]  # C H Wimg = transform_invert(img_tensor, train_transform)plt.imshow(img)plt.title("LeNet got {} Yuan".format(rmb))plt.show()plt.pause(0.5)plt.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/310236.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Github项目推荐-vocal-separate

项目地址 vocal-separate: 项目简述 这是一个音乐和人声分离的项目&#xff0c;基于python开发。有图形化操作界面&#xff0c;看起来还不错。 项目截图

二叉树BFS

前置知识 二叉树节点的定义 二叉树是递归定义的 /*** Definition for a binary tree node.&#xff08;LeetCode&#xff09;*/public class TreeNode {int val;TreeNode left;TreeNode right;TreeNode() {}TreeNode(int val) { this.val val; }TreeNode(int val, TreeNode…

基本运算器实验静态随机存储器实验

1.1 基本运算器实验 1. 实验记录 ①运算结果 首先按照实验指导书进行连线&#xff0c;然后打开试验箱电源&#xff0c;把A&#xff0c;B两个数存到寄存器中&#xff0c;然后改变s3 s2 s1 s0 的值&#xff0c;产生脉冲&#xff0c;观察对应的数据总线上的值以及两个标志位。 …

微信小程序开发系列-09自定义组件样式特性

微信小程序开发系列目录 《微信小程序开发系列-01创建一个最小的小程序项目》《微信小程序开发系列-02注册小程序》《微信小程序开发系列-03全局配置中的“window”和“tabBar”》《微信小程序开发系列-04获取用户图像和昵称》《微信小程序开发系列-05登录小程序》《微信小程序…

【Unity入门】热更新框架之xLua

目录 一、xLua概述1.1xLua简介1.2xLua安装 二、Lua文件加载2.1执行字符串2.2加载Lua文件2.3自定义loader 三、xLua文件配置3.1打标签3.2静态列表3.3动态列表 四、Lua与C#交互4.1 C#访问Lua4.1.1 获取一个全局基本数据类型4.1.2 访问一个全局的table4.1.3 访问一个全局的functio…

Maven项目提示Ignored pom.xml问题

1 环境 &#xff08;1&#xff09;IDEA开发工具&#xff1a;2022.2.1 &#xff08;2&#xff09;JDK&#xff1a;Java17&#xff08;Spring6要求JDK最低版本是Java17&#xff09; &#xff08;3&#xff09;Spring&#xff1a;6.1.2 &#xff08;4&#xff09;Maven 3.8.8 2 …

pytest --collectonly 收集测试案例

pytest --collectonly 是一条命令行指令&#xff0c;用于在运行 pytest 测试时仅收集测试项而不执行它们。它会显示出所有可用的测试项列表&#xff0c;包括测试模块、测试类和测试函数&#xff0c;但不会执行任何实际的测试代码。 这个命令对于查看项目中的测试结构和确保所有…

大模型LLM的微调技术:LoRA

0 引言 LoRA(Low-Rank Adaptation)出自2021年的论文“LoRA: Low-Rank Adaptation of Large Language Models” LoRA技术冻结预训练模型的权重&#xff0c;并在每个Transformer块中注入可训练层&#xff08;称为秩分解矩阵&#xff09;&#xff0c;即在模型的Linear层的旁边增…

Java EE 网络原理之HTTPS

文章目录 1. HTTPS 是什么&#xff1f;2. "加密" 是什么&#xff1f;3. HTTPS 的工作过程3.1 引入对称加密3.2 引入非对称加密3.3 中间人攻击3.4 引入证书 4. Tomecat4.1 tomcat 的作用 1. HTTPS 是什么&#xff1f; HTTPS也是⼀个应用层协议&#xff0c;是在 HTTP …

【计算机毕业设计】python+django数码电子论坛系统设计与实现

本系统主要包括管理员和用户两个角色组成&#xff1b;主要包括&#xff1a;首页、个人中心、用户管理、分类管理、数码板块管理、数码评价管理、数码论坛管理、畅聊板块管理、系统管理等功能的管理系统。 后端&#xff1a;pythondjango 前端&#xff1a;vue.jselementui 框架&a…

Java强软弱虚引用

面试&#xff1a; 1.强引用&#xff0c;软引用&#xff0c;弱引用&#xff0c;虚引用分别是什么&#xff1f; 2.软引用和弱引用适用的场景&#xff1f; 3.你知道弱引用的话&#xff0c;能谈谈WeakHashMap吗&#xff1f; 目录 一、Java引用 1、强引用&#xff08;默认支持模式…

ES6的默认参数和rest参数

✨ 专栏介绍 在现代Web开发中&#xff0c;JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性&#xff0c;还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言&#xff0c;JavaScript具有广泛的应用场景&#x…