pytorch-解决过拟合之regularization

目录

  • 1.解决过拟合的方法
  • 2. regularization
  • 2. regularization分类
  • 3. pytorch L2 regularization
  • 4. 自实现L1 regularization
  • 5. 完整代码

1.解决过拟合的方法

  • 更多的数据
  • 降低模型复杂度
    regularization
  • Dropout
  • 数据处理
  • 早停止

2. regularization

以二分类的cross entropy为例,就是在其公式后增加一项参数一范数累加和,并乘以一个超参数用来权衡参数配比。
模型优化是要使得前部分loss尽量小,那么同时也要后半部分范数接近于0,但是为了保持模型的表达能力还要保留比如 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,那么可能使得 β 0 β_{0} β0 β 2 β_{2} β2 β 3 β_{3} β3 = 0.01 而 β 4 β_{4} β4- β n β_{n} βn很小很小,比如0.0001,这样就使得比如f(x)= x 7 x^7 x7,退化为 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,这样即保证了模型的表达能力,也降低了模型的复杂度。从而防止过拟合。
在这里插入图片描述
下图是未增加regularization和增加了regularization的区别展示图
可以看出未增加regularization的时候,模型可以将噪点也拟合进去了,因此图形很不平滑,发生了过拟合。而增加regularization之后,图形变得很平滑。
在这里插入图片描述

2. regularization分类

regularization有两类分别是L1和L2,L1增加的是参数的一范数,L2增加的二范数
在这里插入图片描述
最常用的是L2regularization

3. pytorch L2 regularization

pytorch中L2 regularization叫weight_decay
在这里插入图片描述

4. 自实现L1 regularization

在这里插入图片描述

5. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsfrom visdom import Visdombatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
criteon = nn.CrossEntropyLoss().to(device)viz = Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',legend=['loss', 'acc.']))
global_step = 0for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()global_step += 1viz.line([loss.item()], [global_step], win='train_loss', update='append')if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()viz.line([[test_loss, correct / len(test_loader.dataset)]],[global_step], win='test', update='append')viz.images(data.view(-1, 1, 28, 28), win='x')viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/659681.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql基础知识汇总

本文自行整理,只做学习记忆之用,若有不当之处请指出 一、数据库三层结构 (1)所谓安装Mysql数据库,就是在主机安装一个数据库管理系统(DBMS),这个管理程序可以管理多个数据库。DBMS(database manage system) &#xf…

vue知识

一、初始vue Vue核心 Vue简介 初识 (yuque.com) 1.想让Vue工作,就必须创建一个Vue实例,且要传入一个配置对象 2.root容器里的代码依然符合html规范,只不过混入了一些特殊的Vue语法 3.root容器里的代码被称为【Vue模板】 4.Vue实例和容器…

docker各目录含义

目录含义builder构建docker镜像的工具或过程buildkit用于构建和打包容器镜像,官方构建引擎,支持多阶段构建、缓存管理、并行化构建和多平台构建等功能containerd负责容器生命周期管理,能起、停、重启,确保容器运行。负责镜管理&am…

Git Tag:为你的代码版本打上优雅的标签

为你的代码版本打上优雅的标签 在软件开发过程中,版本控制是项目管理的重要一环。Git 作为最流行的版本控制系统之一,为我们提供了强大的工具来管理代码版本。其中,git tag 命令允许我们为代码仓库中的特定提交打上标签,这些标签…

没有京牌车如何面对五一小长假?

随着五一小长假的来临,许多人都计划着出游,享受假期的快乐。然而,对于没有京牌车的市民来说,出行可能会面临一些困扰。那么,没有京牌车该如何应对五一小长假的出行需求呢?下面北京盛昂京牌小编沐沐将为您提…

设计模式-01 设计模式单例模式

设计模式-01 设计模式单例模式 目录 设计模式-01 设计模式单例模式 1定义 2.内涵 3.使用示例 4.具体代码使用实践 5.注意事项 6.最佳实践 7.总结 1 定义 单例模式是一种设计模式,它确保一个类只能被实例化一次。它通过在类内部创建类的唯一实例并提供一个全…

实验报告5-Spring MVC实现页面

实验报告5-SpringMVC实现页面 一、需求分析 使用Spring MVC框架,从视图、控制器和模型三方面实验动态页面。模拟实现用户登录,模拟的用户名密码以模型属性方式存放在Spring容器中,控制器相应用户请求并映射参数,页面收集用户数据或…

力扣HOT100 - 131. 分割回文串

解题思路&#xff1a; class Solution {List<List<String>> res new ArrayList<>();List<String> pathnew ArrayList<>();public List<List<String>> partition(String s) {backtrack(s,0);return res;}public void backtrack(Str…

C++复盘(一)

文章目录 常量标识符命名规则数据类型sizeof关键字浮点数字符型转义字符字符串型布尔类型bool 比较运算符switch-case语句rand()随机数种子srand() goto语句一维数组函数函数的声明函数的分文件编写 指针指针所占内存空间空指针野指针const修饰指针1、常量指针2、指针常量3、co…

SpringBoot集成Flowable案例

前言 Flowable 是一个使用 Java 编写的轻量级业务流程引擎。Flowable 流程引擎可用于部署 BPMN2.0 流程定义&#xff08;用于定义流程的行业 XML 标准&#xff09;&#xff0c;创建这些流程定义的流程实例&#xff0c;进行查询&#xff0c;访问运行中或历史的流程实例与相关数…

SAP-MM-SD批次管理的影响点M3530

业务场景: 业务部门在创建物料主数据时,勾选了“批次管理”实际不需要。收货时提示输入批次,不能收货了,那回到物料主数据修改,取消勾选“批次管理”发现取消不了,报错M3530,大致内容如下: “显示错误”按钮仅在对话框模式下出现,而不是在数据传输或大规模维护中。 步…

【代码问题】【Pytorch】训练模型时Loss为NaN或INF

解决方法或者问题排查&#xff1a; 加归一化层&#xff1a; 我的问题是我新增的一个模块与原来的模块得到的张量相加&#xff0c;原张量是归一化后的&#xff0c;我的没有&#xff1a; class Module(nn.Module):def __init__(self,dim,):super().__init__()# 新增一个LayerNo…