机器学习复习(2)——线性回归SGD优化算法

目录

线性回归代码

线性回归理论

SGD算法

手撕线性回归算法

模型初始化

定义模型主体部分

定义线性回归模型训练过程

数据demo准备

模型训练与权重参数

定义线性回归预测函数

定义R2系数计算

可视化展示 

预测结果

训练过程 

sklearn进行机器学习

线性回归代码

class My_Model(nn.Module):def __init__(self, input_dim):super(My_Model, self).__init__()# 矩阵的维度(dimensions) self.layers = nn.Sequential(nn.Linear(input_dim, 16),nn.ReLU(),nn.Linear(16, 8),nn.ReLU(),nn.Linear(8, 1))def forward(self, x):x = self.layers(x)x = x.squeeze(1) # (B, 1) -> (B)return x

线性回归理论

回归算法是相对分类算法而言的,与我们想要预测的目标变量y的值类型有关。

如果目标变量y是分类型变量,如预测用户的性别(男、女),预测月季花的颜色(红、白、黄……),那我们就需要用分类算法去拟合训练数据并做出预测;

如果y是连续型变量,如预测用户的收入(4千,2万,10万……),预测患肺癌的概率(1%,50%,99%……),我们则需要用回归模型。

有时分类问题也可以转化为回归问题。可以用回归模型先预测出患肺癌的概率,然后再给定一个阈值,例如50%,概率值在50%以下为A类,50%以上为B类。

一元线性回归公式:

 具象化含义:

SGD算法

手撕线性回归算法

模型初始化

### 初始化模型参数
def initialize_params(dims):'''输入:dims:训练数据变量维度输出:w:初始化权重参数值b:初始化偏差参数值'''# 初始化权重参数为零矩阵w = np.zeros((dims, 1))# 初始化偏差参数为零b = 0return w, b
w,b=initialize_params(3)#用于测试
print("w初始化是",w)
print("b初始化是",b)

运行结果:

定义模型主体部分

包括线性回归公式、均方损失和参数偏导三部分
def linear_loss(X, y, w, b):'''输入:X:输入变量矩阵y:输出标签向量w:变量参数权重矩阵b:偏差项输出:y_hat:线性模型预测输出loss:均方损失值dw:权重参数一阶偏导db:偏差项一阶偏导'''# 训练样本数量num_train = X.shape[0]# 训练特征数量num_feature = X.shape[1]# 线性回归预测输出y_hat = np.dot(X, w) + b# 计算预测输出与实际标签之间的均方损失loss = np.sum((y_hat-y)**2)/num_train# 基于均方损失对权重参数的一阶偏导数dw = np.dot(X.T, (y_hat-y)) /num_train# 基于均方损失对偏差项的一阶偏导数db = np.sum((y_hat-y)) /num_trainreturn y_hat, loss, dw, db

定义线性回归模型训练过程

### 定义线性回归模型训练过程
def linear_train(X, y, learning_rate=0.01, epochs=10000):'''输入:X:输入变量矩阵y:输出标签向量learning_rate:学习率epochs:训练迭代次数输出:loss_his:每次迭代的均方损失params:优化后的参数字典grads:优化后的参数梯度字典'''# 记录训练损失的空列表loss_his = []# 初始化模型参数w, b = initialize_params(X.shape[1])# 迭代训练for i in range(1, epochs):# 计算当前迭代的预测值、损失和梯度y_hat, loss, dw, db = linear_loss(X, y, w, b)
#y_hat是预测值,loss是损失,dw是权重参数一阶偏导,db是偏差项一阶偏导# 基于梯度下降的参数更新w += -learning_rate * dwb += -learning_rate * db# 记录当前迭代的损失loss_his.append(loss)# 每1000次迭代打印当前损失信息if i % 10000 == 0:print('epoch %d loss %f' % (i, loss))# 将当前迭代步优化后的参数保存到字典params = {'w': w,'b': b}# 将当前迭代步的梯度保存到字典grads = {'dw': dw,'db': db}     return loss_his, params, grads

其中的shape操作说明:

import numpy as np
# 创建一个示例的训练数据集 X
X = np.array([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12],[13, 14, 15]])
# 计算训练样本数量
shape0 = X.shape[0]
shape1 = X.shape[1]
print("shape0是",shape0)
print("shape1是",shape1)

运行结果:

数据demo准备

from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
data = diabetes.data
target = diabetes.target 
print(data.shape)
print(target.shape)
print(data[:5])
print(target[:5])
###########################################
# 导入sklearn diabetes数据接口
from sklearn.datasets import load_diabetes
# 导入sklearn打乱数据函数
from sklearn.utils import shuffle
# 获取diabetes数据集
diabetes = load_diabetes()
# 获取输入和标签
data, target = diabetes.data, diabetes.target 
# 打乱数据集
X, y = shuffle(data, target, random_state=13)
# 按照8/2划分训练集和测试集
offset = int(X.shape[0] * 0.8)
# 训练集
X_train, y_train = X[:offset], y[:offset]
# 测试集
X_test, y_test = X[offset:], y[offset:]
# 将训练集改为列向量的形式
y_train = y_train.reshape((-1,1))
# 将验证集改为列向量的形式
y_test = y_test.reshape((-1,1))
# 打印训练集和测试集维度
print("X_train's shape: ", X_train.shape)
print("X_test's shape: ", X_test.shape)
print("y_train's shape: ", y_train.shape)
print("y_test's shape: ", y_test.shape)

模型训练与权重参数

# 线性回归模型训练
loss_his, params, grads = linear_train(X_train, y_train, 0.01, 200000)
# 打印训练后得到模型参数
print(params)

定义线性回归预测函数

### 定义线性回归预测函数
def predict(X, params):'''输入:X:测试数据集params:模型训练参数输出:y_pred:模型预测结果'''# 获取模型参数w = params['w']b = params['b']# 预测y_pred = np.dot(X, w) + breturn y_pred
# 基于测试集的预测
y_pred = predict(X_test, params)
# 打印前五个预测值
y_pred[:5]

定义R2系数计算

R2系数,也称为决定系数(Coefficient of Determination),是一种用于评估回归模型拟合优度的统计指标。它表示模型对观测数据的方差解释比例,通常用于衡量回归模型的拟合程度。

R2系数的取值范围在0到1之间,具体含义如下:

  • 如果R2等于0,表示模型未能解释目标变量的任何方差,即模型无法拟合数据。
  • 如果R2等于1,表示模型完美拟合了数据,能够解释目标变量的所有方差。
  • 如果R2在0和1之间,表示模型能够解释一部分目标变量的方差,数值越接近1,说明模型的拟合程度越好。

计算公式如下:

其中:

  • SSR(Sum of Squares of Residuals)表示模型的残差平方和,即实际观测值与模型预测值之间的差异的平方和。
  • SST(Total Sum of Squares)表示总平方和,即实际观测值与观测值的均值之间的差异的平方和。

R2系数越接近1,说明模型对数据的拟合越好,而越接近0则表示模型的拟合效果较差。这个指标对于评估回归模型的性能非常有用,帮助我们了解模型解释数据方差的程度。

### 定义R2系数函数
def r2_score(y_test, y_pred):'''输入:y_test:测试集标签值y_pred:测试集预测值输出:r2:R2系数'''# 测试标签均值y_avg = np.mean(y_test)# 总离差平方和ss_tot = np.sum((y_test - y_avg)**2)# 残差平方和ss_res = np.sum((y_test - y_pred)**2)# R2计算r2 = 1 - (ss_res/ss_tot)return r2

可视化展示 

预测结果

import matplotlib.pyplot as plt
f = X_test.dot(params['w']) + params['b']plt.scatter(range(X_test.shape[0]), y_test)
plt.plot(f, color = 'darkorange')
plt.xlabel('X_test')
plt.ylabel('y_test')
plt.show();

运行结果:

训练过程 

plt.plot(loss_his, color='blue')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.show()

运行结果:

sklearn进行机器学习

 和torch.nn类似:封装好了linear函数,直接掉包

### sklearn版本为1.0.2
# 导入线性回归模块
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score
# 创建模型实例
regr = linear_model.LinearRegression()
# 模型拟合
regr.fit(X_train, y_train)
# 模型预测
y_pred = regr.predict(X_test)
# 打印模型均方误差
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_pred))
# 打印R2
print('R2 score: %.2f' % r2_score(y_test, y_pred))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/449710.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SqlSever查询某个表的列名称、说明、备注、注释,类型等信息

背景:在工程项目中,有时需要对数据查询进行展示,常规的表格展示虽然能解决大部分问题;但在数据量比较大的情况就如果一次完整的展示信息,势必会造成数据加载中增加耗时,影响数据的展示效果;常规的解决方案都是在数据加载中采取分页的模式,降低数据的加载耗时;但如果要…

AutoCAD .NET 层次结构介绍

AutoCAD .NET API 提供了一种面向对象的编程接口,通过它可以与AutoCAD进行深度集成和自定义功能开发。以下是基于.NET框架下AutoCAD对象层次结构的基本介绍: Autodesk.AutoCAD.ApplicationServices 命名空间 根对象,代表运行中的AutoCAD应用程…

网络空间测绘在安全领域的应用(上)

近年来,网络空间测绘已经跻身为网络通信技术、网络空间安全、地理学等多学科融合的前沿领域。 该领域聚焦于构建网络空间信息的“全息地图”,致力于建立面向全球网络的实时观测、准确采样、映射和预测的强大基础设施。 通过采用网络探测、数据采集、信…

Docker搭建MySQL8主从复制

之前文章我们了解了面试官:说一说Binlog是怎么实现的,这里我们用Docker搭建主从复制环境。 docker安装主从MySQL 这里我们使用MySQL8.0.32版本: 主库配置 master.cnf //基础配置 [client] port3306 socket/var/run/mysqld/mysql.sock [m…

HTMLCSS JavaScript 基础

HTML复杂建立骨架。 CSS复杂装修。 JS负责定义行为和交互。 示例功能&#xff0c;点击按钮&#xff0c;数量增加&#xff0c;图片交互显示。 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"…

C++结构体拷贝时发生的vector iterators incompatible等崩溃情况

文章目录 结构体拷贝时的容器异常崩溃结构体拷贝崩溃的另一种情况结构体拷贝时的容器异常崩溃 自定义一个结构体 struct MMM{int a;std::vector<int> b; }在拷贝时发生异常 代码是 MMM m = mi

Pytest测试用例参数化

pytest.mark.parametrize(参数名1,参数名2...参数n, [(参数名1_data1,参数名2_data1...参数名n_data1),(参数名1_data2,参数名2_data2...参数名n_data2)]) 场景&#xff1a; 定义一个登录函数test_login,传入参数为name,password&#xff0c;需要用多个账号去测试登录功能 # …

【笔记】Android 常用编译模块和输出产物路径

模块&产物路径 具体编译到软件的路径要看编译规则的分区&#xff0c;代码中模块编译输出的产物基本对应。 Android 代码模块 编译产物路径设备adb路径Comment 模块device/mediatek/system/common/ 资源overlay/telephony/frameworks/base/core 文件举例res/res/values-m…

【奶奶看了都会】《幻兽帕鲁》云服务器部署教程

在帕鲁的世界&#xff0c;你可以选择与神奇的生物「帕鲁」一同享受悠闲的生活&#xff0c;也可以投身于与偷猎者进行生死搏斗的冒险。帕鲁可以进行战斗、繁殖、协助你做农活&#xff0c;也可以为你在工厂工作。你也可以将它们进行售卖&#xff0c;或肢解后食用。 《幻兽帕鲁》官…

关于Ubuntu下docker-mysql:ERROR 2002报错

报错场景&#xff1a; mysql容器创建好后登录mysql时即使密码正确也是报出下方提示&#xff1a; 原因是在创建mysql容器在创建时本地目录缺失&#xff0c; 先去自建一个目录&#xff0c;例如&#xff1a; /opt/my_sql 正确完整目录如下&#xff1a; docker run --namemys…

oracle19C 密码包含特殊字符@ 导致ORA-12154

oracle 19C 密码包含特殊字符 出现登录失败&#xff0c;针对此问题一次说个明白 ORA-12154: TNS:could not resolve the connect identifier specified Oracle 19c之前密码是可以包含特殊字符&#xff0c;但是如果包含特殊字符需要双引号 比如oracle11g 正常 如果密码包含特殊…

Scrum敏捷开发企业培训-敏捷研发管理

课程简介 Scrum是目前运用最为广泛的敏捷开发方法&#xff0c;是一个轻量级的项目管理和产品研发管理框架。 这是一个两天的实训课程&#xff0c;面向研发管理者、项目经理、产品经理、研发团队等&#xff0c;旨在帮助学员全面系统地学习Scrum和敏捷开发, 帮助企业快速启动敏…