实现一个简单的线性回归和多项式回归(2)

对于多项式回归,可以同样使用前面线性回归中定义的LinearRegression算子、训练函数train、均方误差函数mean_squared_error,生成数据集create_toy_data,这里就不多做赘述咯~

        拟合的函数为

def sin(x):y = torch.sin(2 * math.pi * x)return y

1.数据集的建立

func = sin
interval = (0,1)
train_num = 15
test_num = 10
noise = 0.5
X_train, y_train = create_toy_data(func=func, interval=interval, sample_num=train_num, noise = noise)
X_test, y_test = create_toy_data(func=func, interval=interval, sample_num=test_num, noise = noise)X_underlying = torch.linspace(interval[0], interval[1], 100)
y_underlying = sin(X_underlying)# 绘制图像
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")
#plt.scatter(X_test, y_test, facecolor="none", edgecolor="r", s=50, label="test data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.legend(fontsize='x-large')
plt.savefig('ml-vis2.pdf')
plt.show()

        生成结果: 

2.模型构建

# 多项式转换
def polynomial_basis_function(x, degree=2):"""输入:- x: tensor, 输入的数据,shape=[N,1]- degree: int, 多项式的阶数example Input: [[2], [3], [4]], degree=2example Output: [[2^1, 2^2], [3^1, 3^2], [4^1, 4^2]]注意:本案例中,在degree>=1时不生成全为1的一列数据;degree为0时生成形状与输入相同,全1的Tensor输出:- x_result: tensor"""if degree == 0:return torch.ones(x.size(), dtype=torch.float32)x_tmp = xx_result = x_tmpfor i in range(2, degree + 1):x_tmp = torch.multiply(x_tmp, x)  # 逐元素相乘x_result = torch.concat((x_result, x_tmp), dim=-1)return x_result

        我的理解多项式回归的模型的建立更像是将原来的一个x变成x x^2 x^3 ... x^degree这个过程

3.模型训练

plt.rcParams['figure.figsize'] = (12.0, 8.0)for i, degree in enumerate([0, 1, 3, 8]):  # []中为多项式的阶数model = Linear(degree)X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), degree)X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), degree)model = optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))  # 拟合得到参数y_underlying_pred = model(X_underlying_transformed).squeeze()print(model.params)# 绘制图像plt.subplot(2, 2, i + 1)plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")plt.plot(X_underlying, y_underlying_pred, c='#f19ec2', label="predicted function")plt.ylim(-2, 1.5)plt.annotate("M={}".format(degree), xy=(0.95, -1.4))# plt.legend(bbox_to_anchor=(1.05, 0.64), loc=2, borderaxespad=0.)
plt.legend(loc='lower left', fontsize='x-large')
plt.savefig('ml-vis3.pdf')
plt.show()

         训练结果:

观察可视化结果,红色的曲线表示不同阶多项式分布拟合数据的结果:

* 当 $M=0$$M=1$时,拟合曲线较简单,模型欠拟合;

* 当 $M=8$ 时,拟合曲线较复杂,模型过拟合;

* 当 $M=3$ 时,模型拟合最为合理。

4.模型评估

        下面通过均方误差来衡量训练误差、测试误差以及在没有噪音的加入下sin函数值与多项式回归值之间的误差,更加真实地反映拟合结果。多项式分布阶数从0到8进行遍历。

# 训练误差和测试误差
training_errors = []
test_errors = []
distribution_errors = []# 遍历多项式阶数
for i in range(9):model = Linear(i)X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), i)X_test_transformed = polynomial_basis_function(X_test.reshape([-1, 1]), i)X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), i)optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))y_train_pred = model(X_train_transformed).squeeze()y_test_pred = model(X_test_transformed).squeeze()y_underlying_pred = model(X_underlying_transformed).squeeze()train_mse = mean_squared_error(y_true=y_train, y_pred=y_train_pred).item()training_errors.append(train_mse)test_mse = mean_squared_error(y_true=y_test, y_pred=y_test_pred).item()test_errors.append(test_mse)# distribution_mse = mean_squared_error(y_true=y_underlying, y_pred=y_underlying_pred).item()# distribution_errors.append(distribution_mse)print("train errors: \n", training_errors)
print("test errors: \n", test_errors)
# print ("distribution errors: \n", distribution_errors)# 绘制图片
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.plot(training_errors, '-.', mfc="none", mec='#e4007f', ms=10, c='#e4007f', label="Training")
plt.plot(test_errors, '--', mfc="none", mec='#f19ec2', ms=10, c='#f19ec2', label="Test")
# plt.plot(distribution_errors, '-', mfc="none", mec="#3D3D3F", ms=10, c="#3D3D3F", label="Distribution")
plt.legend(fontsize='x-large')
plt.xlabel("degree")
plt.ylabel("MSE")
plt.savefig('ml-mse-error.pdf')
plt.show()

        可视化结果如下:

由可视化结果可得:

  1. 当阶数较低的时候,模型的表示能力有限,训练误差和测试误差都很高,代表模型欠拟合;
  2. 当阶数较高的时候,模型表示能力强,但将训练数据中的噪声也作为特征进行学习,一般情况下训练误差继续降低而测试误差显著升高,代表模型过拟合。

        对于模型过拟合的情况,可以引入正则化方法,通过向误差函数中添加一个惩罚项来避免系数倾向于较大的取值。

degree = 8 # 多项式阶数
reg_lambda = 0.0001 # 正则化系数X_train_transformed = polynomial_basis_function(X_train.reshape([-1,1]), degree)
X_test_transformed = polynomial_basis_function(X_test.reshape([-1,1]), degree)
X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1,1]), degree)model = Linear(degree) optimizer_lsm(model,X_train_transformed,y_train.reshape([-1,1]))y_test_pred=model(X_test_transformed).squeeze()
y_underlying_pred=model(X_underlying_transformed).squeeze()model_reg = Linear(degree) optimizer_lsm(model_reg,X_train_transformed,y_train.reshape([-1,1]),reg_lambda=reg_lambda)y_test_pred_reg=model_reg(X_test_transformed).squeeze()
y_underlying_pred_reg=model_reg(X_underlying_transformed).squeeze()mse = mean_squared_error(y_true = y_test, y_pred = y_test_pred).item()
print("mse:",mse)
mes_reg = mean_squared_error(y_true = y_test, y_pred = y_test_pred_reg).item()
print("mse_with_l2_reg:",mes_reg)# 绘制图像
plt.scatter(X_train, y_train, facecolor="none", edgecolor="#e4007f", s=50, label="train data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.plot(X_underlying, y_underlying_pred, c='#e4007f', linestyle="--", label="$deg. = 8$")
plt.plot(X_underlying, y_underlying_pred_reg, c='#f19ec2', linestyle="-.", label="$deg. = 8, \ell_2 reg$")
plt.ylim(-1.5, 1.5)
plt.annotate("lambda={}".format(reg_lambda), xy=(0.82, -1.4))
plt.legend(fontsize='large')
plt.savefig('ml-vis4.pdf')
plt.show()

可视化结果为:

         多项式回归的难点就是是否真正理解了线性回归中的算子,均方误差等函数的目的和用法,其他的就是简单的函数调用问题啦~不懂的话还是建议多看看线性函数,先把线性函数的看明白最好~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/130414.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电机控制——高数基础

最近开始学习电机控制,将会写一个系列学习笔记,作为一个新手,肯定有理解不到位或错误的地方,仅供大家参考,欢迎交流指正。 先复习一下微积分和自动控制原理,大学学的忘得一干二净。本文大多摘自维基百科 …

信息系统项目管理师第四版学习笔记——项目进度管理

项目进度管理过程 项目进度管理过程包括:规划进度管理、定义活动、排列活动顺序、估算活动持续时间、制订进度计划、控制进度。 规划进度管理 规划进度管理是为规划、编制、管理、执行和控制项目进度而制定政策、程序和文档的过程。本过程的主要作用是为如何在…

uniapp获取公钥、MD5,‘keytool‘ 不是内部或外部命令,也不是可运行的程序 或批处理文件。

获取MD5、SHA1、SHA256指纹信息 通过命令的形式获取 winr调出黑窗口cd到证书所在目录输入keytool -list -v -keystore test.keystore,其中 test.keystore为你的证书名称加文件后缀按照提示输入你的证书密码,就可以查看证书的信息 通过uniapp云端查看(证书是在DClou…

微服务 BFF 架构设计

在现代软件开发中,由于程序、团队、数据规模太大,需要把企业的业务能力进行复用,将领域服务剥离,提供通用能力,避免重复建设和代码;另外服务功能的弹性能力不一样,比如定时任务、数据同步明确的…

目标识别项目实战:基于Yolov7-LPRNet的动态车牌目标识别算法模型(二)

前言 目标识别如今以及迭代了这么多年,普遍受大家认可和欢迎的目标识别框架就是YOLO了。按照官方描述,YOLOv8 是一个 SOTA 模型,它建立在以前 YOLO 版本的成功基础上,并引入了新的功能和改进,以进一步提升性能和灵活性…

```,```中间添加 # + 空格 + 空行后遇到的底部空行出错,书接上回,处理空行

【python查找替换:查找空行,空行前后添加,中间添加 # 空格 空行后遇到的第1行文字? - CSDN App】http://t.csdnimg.cn/QiKCV def is_blank(line):return len(line.strip()) 0txt 时间戳: ("%Y-%m-%d %H:%M:…

【Solidity】智能合约案例——①食品溯源合约

目录 一、合约源码分析: 二、合约整体流程: 1.部署合约 2.管理角色 3.食品信息管理 4.食品溯源管理 一、合约源码分析: Producer.sol:生产者角色的管理合约,功能为:添加新的生产者地址、移除生产者地址、判断角色地址…

关于Go语言的底层,Channel

1.Channel 介绍一下Channel(有缓冲和无缓冲) Go 语言中,不要通过共享内存来通信,而要通过通信来实现内存共享。Go 的CSP(Communicating Sequential Process)并发模型,中文可以叫做通信顺序进程,是通过 gor…

35.树与二叉树练习(1)(王道第5章综合练习)

【所用的树,队列,栈的基本操作详见上一节代码】 试题1(王道5.3.3节第3题): 编写后序遍历二叉树的非递归算法。 参考:34.二叉链树的C语言实现_北京地铁1号线的博客-CSDN博客https://blog.csdn.net/qq_547…

Tomcat 9.0.41在IDEA中乱码问题(IntelliJ IDEA 2022.1.3版本)

1. 乱码的产生是由于编码和解码的编码表不一致引起的。 2. 排查乱码原因 2.1 在idea中启动Tomcat时控制台乱码排查 Tomcat输出日志乱码: 首先查看IDEA控制台,检查发现默认编码是GBK。 再查看Tomcat日志(conf文件下logging.properties)的默…

【鼠标右键菜单添加用VSCode打开文件或文件夹】

鼠标右键菜单添加用VSCode打开文件或文件夹 演示效果如下: 右击文件 或右击文件夹 或在文件夹内空白处右击 方法一:重装软件 重装软件,安装时勾选如图所示方框(如果登录的有账号保存有配置信息可以选择重装软件&#xff0c…

rustlings本地开发环境配置

克隆自己的仓库 首先我们要在github上找到自己仓库并把它克隆到本地 git clone https://github.com/LearningOS/rust-rustlings-2023-autumn-******.git下载插件 rust-analyzer和Git Graph一个可以用来解析rust代码,另一个可以可视化管理git代码库 下载rustling…