机器学习第5天:多项式回归与学习曲线

文章目录

多项式回归介绍

方法与代码

方法描述

分离多项式

学习曲线的作用

场景

学习曲线介绍

欠拟合曲线

示例

结论

过拟合曲线

示例

​结论 


多项式回归介绍

当数据不是线性时我们该如何处理呢,考虑如下数据

import matplotlib.pyplot as plt
import numpy as npnp.random.seed(42)x = 8 * np.random.rand(100, 1) - 4
y = 2*x**2+3*x+np.random.randn(100, 1)plt.scatter(x, y)
plt.show()


方法与代码

方法描述

先讲思路,以这个二元函数为例

y=3*x^{2}+2*x+c

将多项式化为多个单项的,也就是将x的平方和x两个项分离开,然后单独给线性模型处理,求出参数,最后再组合在一起,很好理解,让我们来看一下代码


分离多项式

我们使用机器学习库的PolynomialFeatures来分离多项式

from sklearn.preprocessing import PolynomialFeaturespoly_features = PolynomialFeatures(degree=2, include_bias=False)
x_poly = poly_features.fit_transform(x)
print(x[0])
print(x_poly[0])

运行结果

可以看到,4, 5行代码将原始x和x平方挑选了出来,这时我们再把这个数据进行线性回归

model = LinearRegression()
model.fit(x_poly, y)
print(model.coef_)

 这段代码使用处理后的x拟合y,再打印模型拟合的参数,可以看到模型的两个参数分别是2.9和2左右,而我们的方程的一次参数和二次参数分别是3和2,可见效果还是很好的

把预测的结果绘制出来

model = LinearRegression()
model.fit(x_poly, y)
pre_y = model.predict(x_poly)# 这里是为了让x升序的排序算法, 可以尝试不加这段代码图会变成什么样
sorted_indices = sorted(range(len(x)), key=lambda k: x[k])
x_sorted = [x[i] for i in sorted_indices]
y_sorted = [pre_y[i] for i in sorted_indices]plt.plot(x_sorted, y_sorted, "r-")
plt.scatter(x, y)
plt.show()


学习曲线的作用

场景

设想一下,当你需要预测房价,你也有多组数据,包括离学校距离,交通状况等,但是问题来了,你只知道这些特征可能与房价有关,但并不知道这些特征与房价之间的方程关系,这时我们进行回归任务时,就可能导致欠拟合或者过拟合,幸运的是,我们可以通过学习曲线来判断


学习曲线介绍

学习曲线图就是以损失函数为纵坐标,数据集大小为横坐标,然后在图上画出训练集和验证集两条曲线的图,训练集就是我们用来训练模型的数据,验证集就是我们用来验证模型性能的数据集,我们往往将数据集分成训练集与验证集

我们先定义一个学习曲线绘制函数

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()

 简单介绍一下,这个函数接收模型参数,x,y参数,然后在for循环中,取不同数据集大小来计算RMSE损失(就是\sqrt{MSE}),然后把曲线绘制出来


欠拟合曲线

我们知道欠拟合就是模拟效果不好的情况,可以想象的到,无论在训练集还是验证集上,他的损失都会比较高

示例

我们将线性模型的学习曲线绘制出来

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()x = np.random.rand(100, 1)
y = 2 * x + np.random.rand(100, 1)model = LinearRegression()
plot_learning_curves(model, x, y)

 

结论

可以看到,在只有一点数据时,模型在训练集上效果很好(因为就是开始这一些数据训练出来的),而在验证集上效果不好,但随着训练集增加(模型学习到的越多),验证集上的误差逐渐减小,训练集上的误差增加(因为是学到了一个趋势,不会完全和训练集一样了)

这个图的特征是两条曲线非常接近,且误差都较大(差不多在0.3) ,这是欠拟合的表现(模型效果不好)


过拟合曲线

过拟合就是完全以数据集来模拟曲线,泛化能力很差

示例

我们来试试将一次函数模拟成三次函数,再来看看学习曲线(毫无疑问过拟合了)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipelinedef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()np.random.seed(10)
x = np.random.rand(200, 1)
y = 2 * x + np.random.rand(200, 1)poly_regression = Pipeline([("Poly", PolynomialFeatures(degree=3, include_bias=False)),("Line", LinearRegression())
])plot_learning_curves(poly_regression, x, y)

结论 

这条曲线的特征是训练集的效果比验证集好(两条线之间有一定间距),这往往是过拟合的表现(在训练集上效果好,验证集差,表面泛化能力差) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/189328.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker与Kubernetes结合的难题与技术解决方案

文章目录 1. **版本兼容性**技术解决方案 2. **网络通信**技术解决方案 3. **存储卷的管理**技术解决方案 4. **安全性**技术解决方案 5. **监控和日志**技术解决方案 6. **扩展性与自动化**技术解决方案 7. **多集群管理**技术解决方案 结语 🎈个人主页&#xff1a…

服务器数据恢复—热备盘同步中断导致Raid5数据丢失的数据恢复案例

服务器数据恢复环境: 某单位一台服务器上有一组raid5阵列,该raid5阵列有15块成员盘。上层是一个xfs裸分区,起始位置是0扇区。 服务器故障&检测: 服务器raid5阵列中有硬盘性能表现不稳定,但是由于管理员长时间没有关…

小程序授权获取头像

wxml <view class"header"><text>头像</text><button class"butt" plain"true" open-type"chooseAvatar" bind:chooseavatar"chooseAvatar"><image src"{{HeadUrl}}" mode"&quo…

Python实验项目7 :tkinter GUI编程

&#xff08;1&#xff09;利用tkinter 制作界面&#xff0c;效果图如下&#xff1a; from tkinter import * # winTk() for i in range(1,20):Button(width5,height10,bg"black" if i%20 else"white").pack(side"left") win.geometry("8…

联想Win11系统的任务栏格式调整为居中或居左

一 .目的 联想Win11系统的任务栏格式调整为居中或居左 二 .方法 2.1 鼠标任意放到电脑桌面位置&#xff0c;点击鼠标右键&#xff0c;显示后县级【显示设置】 2.2 个性化→任务栏→任务栏行为→对其方式&#xff1a;按需或个人习惯进行选择【靠左】 2.3 成功调整&#x…

BUG 随想录 - Java: 程序包 com.example.xxx 不存在

目录 一、BUG 复现 二、解决问题 一、BUG 复现 背景&#xff1a;通过 feign 的最佳实践&#xff0c;将 feign 单独提取成一个微服务&#xff0c;接着在需要远程调用的微服务中引入 feign 模块&#xff0c;并在启动类通过 EnableFeignClients 声明指定的 Feign 客户端. 出现问题…

【数据结构与算法】JavaScript实现双向链表

文章目录 一、双向链表简介二、封装双向链表类2.0.创建双向链表类2.1.append(element)2.2.toString()汇总2.3.insert(position,element)2.4.get(position)2.5.indexOf(element)2.7.update(position,element)2.8.removeAt(position)2.9.其他方法2.10.完整实现 三、链表结构总结3…

nginx学习(1)

一、下载安装NGINX&#xff1a; 先安装gcc-c编译器 yum install gcc-c yum install -y openssl openssl-devel&#xff08;1&#xff09;下载pcre-8.3.7.tar.gz 直接访问&#xff1a;http://downloads.sourceforge.net/project/pcre/pcre/8.37/pcre-8.37.tar.gz&#xff0c;就…

硬件开发笔记(十二):RK3568底板电路电源模块和RTC模块原理图分析

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/134429973 红胖子网络科技博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬…

深度学习论文解读:比较ResNet和ViT差异

前言 计算机视觉、机器学习&#xff0c;这两个词会让你想到什么&#xff1f; 相信绝大多数人第一反应都是CNN&#xff0c;而持续关注这些领域发展的人&#xff0c;则会进一步联想到近几年大火的Transformer&#xff0c;它不仅在自然语言相关任务上表现优秀&#xff0c;在图像…

【JUC】五、线程的第三种创建方式 Callable

文章目录 1、Callable概述2、FutureTask Java基础中&#xff0c;了解到的创建线程的两种方式为&#xff1a; 继承Thread类实现Runnable接口 除了以上两种&#xff0c;还可以通过&#xff1a; Callable接口&#xff08;since JDK1.5&#xff09;线程池方式 1、Callable概述 …

【uniapp】Google Maps

话不多说 直接上干货 提前申请谷歌地图账号一、新建地图 使用h5获取当前定位或者使用三方uniapp插件 var coords ""navigator.geolocation.getCurrentPosition(function(position) {coords {lat: position.coords.latitude,lng: position.coords.longitude};lats …