【PyTorch】模型选择、欠拟合和过拟合

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 完整代码
      • 2.2.2. 输出结果

1. 理论介绍

  • 将模型在训练数据上拟合的比在潜在分布中更接近的现象称为过拟合, 用于对抗过拟合的技术称为正则化
  • 训练误差和验证误差都很严重, 但它们之间差距很小。 如果模型不能降低训练误差,这可能意味着模型过于简单(即表达能力不足),无法捕获试图学习的模式。 这种现象被称为欠拟合
  • 训练误差是指模型在训练数据集上计算得到的误差。
  • 泛化误差是指模型应用在同样从原始样本的分布中抽取的无限多数据样本时,模型误差的期望。我们永远不能准确地计算出泛化误差,在实际中,我们只能通过将模型应用于一个独立的测试集来估计泛化误差, 该测试集由随机选取的、未曾在训练集中出现的数据样本构成。
  • 影响模型泛化的因素
    • 可调整参数的数量。当可调整参数的数量(有时称为自由度)很大时,模型往往更容易过拟合。
    • 参数采用的值。当权重的取值范围较大时,模型可能更容易过拟合。
    • 训练样本的数量。即使模型很简单,也很容易过拟合只包含一两个样本的数据集,而过拟合一个有数百万个样本的数据集则需要一个极其灵活的模型。
  • 在机器学习中,我们通常在评估几个候选模型后选择最终的模型。 这个过程叫做模型选择。候选模型可能在本质上不同,也可能是不同的超参数设置下的同一类模型。
  • 为了确定候选模型中的最佳模型,我们通常会使用验证集。验证集与测试集十分相似,唯一的区别是验证集是用于确定最佳模型,测试集是用于评估最终模型的性能
  • K K K折交叉验证:当训练数据稀缺时,将原始训练数据分成 K K K个不重叠的子集。 然后执行 K K K次模型训练和验证,每次在 ( K − 1 ) (K-1) (K1)个子集上进行训练, 并在剩余的一个子集(在该轮中没有用于训练的子集)上进行验证。 最后,通过对 K K K次实验的结果取平均来估计训练和验证误差。
  • 引起过拟合的因素
    • 模型复杂度
      模型复杂度
    • 数据集大小
      • 训练数据集中的样本越少,我们就越有可能(且更严重地)过拟合。
      • 给出更多的数据,拟合更复杂的模型可能是有益的; 如果没有足够的数据,简单的模型可能更有用。

2. 实例解析

2.1. 实例描述

使用以下三阶多项式来生成训练和测试数据 y = 5 + 1.2 x − 3.4 x 2 2 ! + 5.6 x 3 3 ! + ϵ where  ϵ ∼ N ( 0 , 0. 1 2 ) . y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2). y=5+1.2x3.42!x2+5.63!x3+ϵ where ϵN(0,0.12).并用1阶(线性模型)、3阶、20阶多项式拟合。

2.2. 代码实现

2.2.1. 完整代码

import os
import numpy as np
import math, torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import trackdef evaluate_loss(dataloader, net, criterion):"""评估模型在指定数据集上的损失"""num_examples = 0loss_sum = 0.0with torch.no_grad():for X, y in dataloader:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)num_examples += y.shape[0]loss_sum += loss.sum()return loss_sum / num_examplesdef load_dataset(*tensors):"""加载数据集"""dataset = TensorDataset(*tensors)return DataLoader(dataset, batch_size, shuffle=True)if __name__ == '__main__':# 全局参数设置num_epochs = 400batch_size = 10learning_rate = 0.01# 创建记录器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir())# 生成数据集max_degree = 20             # 多项式最高阶数n_train, n_test = 100, 100  # 训练集和测试集大小true_w = np.zeros(max_degree+1)true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))np.random.shuffle(features)poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))for i in range(max_degree+1):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!labels = np.dot(poly_features, true_w)labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)poly_features, labels = [torch.as_tensor(x, dtype=torch.float32) for x in [poly_features, labels.reshape(-1, 1)]]def loop(model_degree):# 创建模型net = nn.Linear(model_degree+1, 1, bias=False).cuda()nn.init.normal_(net.weight, mean=0, std=0.01)criterion = nn.MSELoss(reduction='none')optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)# 加载数据集dataloader_train = load_dataset(poly_features[:n_train, :model_degree+1], labels[:n_train])dataloader_test = load_dataset(poly_features[n_train:, :model_degree+1], labels[n_train:])# 训练循环for epoch in track(range(num_epochs), description=f'{model_degree}-degree'):for X, y in dataloader_train:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()writer.add_scalars(f"{model_degree}-degree", {"train_loss": evaluate_loss(dataloader_train, net, criterion),"test_loss": evaluate_loss(dataloader_test, net, criterion),}, epoch)print(f"{model_degree}-degree: weights =", net.weight.data.cpu().numpy())for model_degree in [1, 3, 20]:loop(model_degree)writer.close()

2.2.2. 输出结果

权重

  • 采用1阶多项式(线性模型)拟合
    1
  • 采用3阶多项式拟合
    3
  • 采用20阶多项式拟合
    20

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/255230.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

activemq启动成功但web管理页面却无法访问

前提: 在linux启动activemq成功!本地能ping通linux 处理方案: 确定防火墙是否关闭, 有两种处理方案:第一种-关闭防火墙;第二种-暴漏8161和61616两个端口 netstat -lnpt查看8161和61616端口 注意&#xf…

HarmonyOS开发(十):通知和提醒

1、通知概述 1.1、简介 应用可以通过通知接口发送通知消息,终端用户可以通过通知栏查看通知内容,也可以点击通知来打开应用。 通知使用的的常见场景: 显示接收到的短消息、即使消息...显示应用推送消息显示当前正在进行的事件&#xff0c…

Unity UGUI TextMeshPro实现输入中文和表情包(Emoji)表情

目录 实现中文显示 准备工作 1、打开Window——TextMeshPro——FontAssetCreator 2、把字体文件放入SourceFont中 3、把CharacterSet改为Characters from File 4、把字体库文件放入Characters File 5、设置好参数点击Generate Font Atlas等待完成后保存 6、把生成后保存…

洗鞋机行业分析:2023年市场发展前景及消费现状

随着消费主力的转移,年轻群体在消费中的话语权和影响力越来越大,“精致懒”正在成为潮流。洗鞋机作为消费升级时代的产物,自诞生以来,经过十几年的发展,逐渐被年轻消费者熟知,洗鞋机品牌阵营和产品种类也变…

Redis——某马点评day02——商铺缓存

什么是缓存 添加Redis缓存 添加商铺缓存 Controller层中 /*** 根据id查询商铺信息* param id 商铺id* return 商铺详情数据*/GetMapping("/{id}")public Result queryShopById(PathVariable("id") Long id) {return shopService.queryById(id);} Service…

域名证书(SSL)申请

获取域名证书的步骤如下: 选择认证机构:域名证书必须从受信任的认证机构中申请,如JoySSL、GeoTrust、Thawte等。收集信息:在申请域名证书之前,需要准备一些证明信息,如域名认证、身份证明等。创建CSR&…

C //例10.4 从键盘输入10个学生的有关数据,然后把它们转存到磁盘文件上去。

C程序设计 (第四版) 谭浩强 例10.4 例10.4 从键盘输入10个学生的有关数据,然后把它们转存到磁盘文件上去。 IDE工具:VS2010 Note: 使用不同的IDE工具可能有部分差异。 代码块 方法:使用指针,函数的模块…

一次显著的性能提升,从8s到0.7s

前言 最近我在公司优化了一些慢查询SQL,积累了一些SOL调优的实战经验。 这篇文章从实战的角度出发,给大家分享一下如何做SQL调优。 经过两次优化之后,慢SQL的性能显著提升了,耗时从8s优化到了0.7s。 现在拿出来给大家分享一下…

老老实实的程序员该如何描述自己的缺点

答辩的时候,晋升的时候,面试的时候,你有没有经常遇到一个问题,那就是你觉得自己有什么缺点吗? 目录 1. 每个人都有缺点 2. 这道题在考什么? 3. 我之前是怎么回答的 4. 你可以这样回答试一试 5. 总结 …

ECharts的颜色渐变

目录 一、直接配置参数实现颜色渐变 二、使用ECharts自带的方法实现颜色渐变 一、两种渐变的实现方法 1、直接配置参数实现颜色渐变 横向的渐变: //主要代码 option {xAxis: {type: category,boundaryGap: false,data: [Mon, Tue, Wed, Thu, Fri, Sat, Sun]},yA…

路径规划之PRM算法

系列文章目录 路径规划之Dijkstra算法 路径规划之Best-First Search算法 路径规划之A *算法 路径规划之D *算法 路径规划之PRM算法 路径规划之PRM算法 系列文章目录前言一、前期准备1.栅格地图2.采样3.路标 二、PRM算法1.起源2.流程3. 优缺点4. 实际效果 前言 之前提到的几种…

如何解决el-table中动态添加固定列时出现的行错位

问题描述 在使用el-table组件时,我们有时需要根据用户的操作动态地添加或删除一些固定列,例如操作列或选择列。但是,当我们使用v-if指令来控制固定列的显示或隐藏时,可能会出现表格的行错位的问题,即固定列和非固定列…