【机器学习Python实战】线性回归

🚀个人主页:为梦而生~ 关注我一起学习吧!
💡专栏:机器学习python实战 欢迎订阅!后面的内容会越来越有意思~
内容说明:本专栏主要针对机器学习专栏的基础内容进行python的实现,部分基础知识不再讲解,有需要的可以点击专栏自取~
💡往期推荐(机器学习基础专栏)
【机器学习基础】机器学习入门(1)
【机器学习基础】机器学习入门(2)
【机器学习基础】机器学习的基本术语
【机器学习基础】机器学习的模型评估(评估方法及性能度量原理及主要公式)
【机器学习基础】一元线性回归(适合初学者的保姆级文章)
【机器学习基础】多元线性回归(适合初学者的保姆级文章)
本期内容:针对以上的一元和多元线性回归的梯度下降求解方法,进行代码展示


文章目录

  • 一元线性回归
  • 多元线性回归


一元线性回归

  • 设计思路

首先,class LinearRegression(object):定义一个LinearRegression类,继承自object类。
在这个类中,首先def __init__(self):定义类的构造函数

在构造函数中,初始化线性回归模型的参数self.__Mself.__theta0self.__theta1,以及梯度下降中的步长(学习率)self.__alpha

在这里插入图片描述

线性回归模型是要不断计算输出的,所以定义函数def predict(self, x),用于预测给定输入x对应的输出

同时线性回归的目的是通过迭代,不断的修改参数 θ \theta θ,所以需要定义函数用来做这个工作,它是通过梯度下降的方法来求解的,所以def __cost_theta0(self, X, Y)def __cost_theta1(self, X, Y)这两个方法用于计算代价函数关于参数 θ 0 \theta_0 θ0 θ 1 \theta_1 θ1的偏导数

下面,def train(self, features, target)把上面的每个步骤和到了一起,定义了一个训练方法train,用于通过梯度下降算法找到最优的模型参数 θ 0 \theta_0 θ0 θ 1 \theta_1 θ1,使得代价函数的平方误差最小。在训练过程中,通过迭代更新参数,并输出每次迭代后的参数值

while的每一次迭代中,通过更新参数self.__theta0self.__theta1来逐渐最小化代价函数的平方误差。

if "0:o.5f".format(prevt0) == "0:o.5f".format(self.__theta0) and "0:o.5f".format(prevt1) == "0:o.5f".format(self.__theta1):判断是否达到收敛条件,即两次迭代的参数值没有改变,如果满足条件,则退出循环。

最后,输出最终得到的参数值。

在这里插入图片描述

  • 总体代码实现

定义LinearRegression的class

#!/usr/bin/env python3
# 这是一个Shebang,指定了此脚本要使用的解释器为python3。
import numpyclass LinearRegression(object):# Constructor. Initailize Constants.def __init__(self):super(LinearRegression, self).__init__()self.__M = 0self.__theta0 = 2self.__theta1 = 2# defining Alpha I.E length of steps in gradient descent Or learning Rate.self.__alpha = 0.01def predict(self,x):return (self.__theta0 + x * self.__theta1)# Cost Function fot theta0.def __cost_theta0(self,X,Y):sqrerror = 0.0for i in range(0,X.__len__()):sqrerror += (self.predict(X[i]) - Y[i])return (1/self.__M * sqrerror)# Cost Function fot theta1.def __cost_theta1(self,X,Y):sqrerror = 0.0for i in range(0,X.__len__()):sqrerror += (self.predict(X[i]) - Y[i]) * X[i]return (1/self.__M * sqrerror)# training Data :# Finding Best __theta0 and __theta1 for data such that the Squared  Error is Minimized.def train(self,features,target):# Validate Dataif not features.__len__() == target.__len__():raise Exception("features and target should be of same length")# Initailize M with Size of X and Yself.__M = features.__len__()# gradient descentprevt0, prevt1 = self.__theta0 , self.__theta1while True:tmp0 = self.__theta0 - self.__alpha * (self.__cost_theta0(features,target))tmp1 = self.__theta1 - self.__alpha * (self.__cost_theta1(features,target))self.__theta0, self.__theta1 = tmp0, tmp1print("theta0(b) :", self.__theta0)print("theta1(m) :", self.__theta1)if "0:o.5f".format(prevt0) == "0:o.5f".format(self.__theta0) and "0:o.5f".format(prevt1) == "0:o.5f".format(self.__theta1):breakprevt0, prevt1 = self.__theta0 , self.__theta1# Training Completed.# log __theta0 __theta1print("theta0(b) :", self.__theta0)print("theta1(m) :", self.__theta1)

样例测试

from LinearRegression_OneVariables import LinearRegression
import numpy as npX = np.array([1,2,3,4,5,6,7,8,9,10])# Y = 0 + 1X
Y = np.array([1,2,3,4,5,6,7,8,9,10])modal = LinearRegression.LinearRegression()modal.train(X,Y)print(modal.predict(14))

多元线性回归

  • 设计思路

首先,将文件导入,打乱顺序并选择训练集。

data=pd.read_csv("c:\\windquality.csv")data_array=data.values#shuffling for train test spplit
np.random.shuffle(data_array)train,test=data_array[:1200,:],data_array[1200:,:]
x_train=train[:,:-1]
x_test=test[:,:-1]
y_train=train[:,-1]
y_test=test[:,-1]

在这里插入图片描述

然后初始化参数,注意这里是多元的,所以有多个参数需要初始化。包括迭代次数和学习率

coef1=0
coef2=0
coef3=0
coef4=0
coef5=0
coef6=0
coef7=0
coef8=0
coef9=0
coef10=0
coef11=0
epoch=1000
alpha=.0001

在这里插入图片描述

然后使用梯度下降算法进行计算

总体代码实现

import pandas as pd
import numpy as np
import matplotlib.pyplot as pltdata=pd.read_csv("c:\\windquality.csv")data_array=data.values#shuffling for train test spplit
np.random.shuffle(data_array)train,test=data_array[:1200,:],data_array[1200:,:]
x_train=train[:,:-1]
x_test=test[:,:-1]
y_train=train[:,-1]
y_test=test[:,-1]coef1=0
coef2=0
coef3=0
coef4=0
coef5=0
coef6=0
coef7=0
coef8=0
coef9=0
coef10=0
coef11=0
epoch=1000
alpha=.0001
c=0
n=len(y_train)
for i in range(epoch):y_pred=((coef1*x_train[:,0])+(coef2*x_train[:,1])+(coef3*x_train[:,2])+(coef4*x_train[:,3])+(coef5*x_train[:,4])+(coef6*x_train[:,5])+(coef7*x_train[:,6])+(coef8*x_train[:,7])+(coef9*x_train[:,8])+(coef10*x_train[:,9])+(coef11*x_train[:,10])+c)#to predict drivativeintercept=(-1/n)*sum(y_train-y_pred)dev1=(-1/n)*sum(x_train[:,0]*(y_train-y_pred))dev2=(-1/n)*sum(x_train[:,1]*(y_train-y_pred))dev3=(-1/n)*sum(x_train[:,2]*(y_train-y_pred))dev4=(-1/n)*sum(x_train[:,3]*(y_train-y_pred))dev5=(-1/n)*sum(x_train[:,4]*(y_train-y_pred))dev6=(-1/n)*sum(x_train[:,5]*(y_train-y_pred))dev7=(-1/n)*sum(x_train[:,6]*(y_train-y_pred))dev8=(-1/n)*sum(x_train[:,7]*(y_train-y_pred))dev9=(-1/n)*sum(x_train[:,8]*(y_train-y_pred))dev10=-1/n*sum(x_train[:,9]*(y_train-y_pred))dev11=-1/n*sum(x_train[:,10]*(y_train-y_pred))#linec=c-alpha*interceptcoef1=coef1-alpha*dev1coef2=coef2-alpha*dev2coef3=coef3-alpha*dev3coef4=coef4-alpha*dev4coef5=coef5-alpha*dev5coef6=coef6-alpha*dev6coef7=coef7-alpha*dev7coef8=coef8-alpha*dev8coef9=coef9-alpha*dev9coef10=coef10-alpha*dev10coef11=coef11-alpha*dev11print("\nintercept:",c,"\ncoefficient1:",coef1,"\ncoefficient2:",coef2,"\ncoefficient3:",coef3,"\ncoefficient4:",coef4,"\ncoefficient5:",coef5,"\ncoefficient6:",coef6,"\ncoefficient7:",coef7,"\ncoefficient8:",coef8,"\ncoefficient9:",coef9,"\ncoefficient10",coef10,   "\ncoefficient11",coef11)#Calculating the predicted values
predicted_values = []
for i in range(0,399):y_pred = ((coef1 * x_test[i,0]) + (coef2 * x_test[i,1]) + (coef3 * x_test[i,2]) + (coef4 * x_test[i,3]) + (coef5 * x_test[i,4]) + (coef6 * x_test[i,5]) + (coef7 * x_test[i,6]) + (coef8 * x_test[i,7]) + (coef9 * x_test[i,8]) + (coef10 * x_test[i,9]) + (coef11 * x_test[i,10]) + intercept)predicted_values.append(y_pred)for i in range(len(predicted_values)):print(predicted_values[i])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/191055.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Web之CSS笔记

Web之HTML、CSS、JS 二、CSS(Cascading Style Sheets层叠样式表)CSS与HTML的结合方式CSS选择器CSS基本属性CSS伪类DIVCSS轮廓CSS边框盒子模型CSS定位 Web之HTML笔记 二、CSS(Cascading Style Sheets层叠样式表) Css是种格式化网…

浏览器页面被恶意控制时的解决方法

解决360流氓软件控制浏览器页面 提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、接受360安全卫士的好意(尽量不要选)二、拒绝360安全卫士的好意(强烈推荐)第…

Leetcode—876.链表的中间结点【简单】

2023每日刷题(三十三) Leetcode—876.链表的中间结点 实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* middleNode(struct ListNode* head) {struct ListNod…

Redis持久化机制详解

使用缓存的时候,我们经常需要对内存中的数据进行持久化也就是将内存中的数据写入到硬盘中。大部分原因是为了之后重用数据(比如重启机器、机器故障之后恢复数据),或者是为了做数据同步(比如 Redis 集群的主从节点通过 …

传输层——TCP协议

文章目录 一.TCP协议二.TCP协议格式1.序号与确认序号2.窗口大小3.六个标志位 三.确认应答机制(ACK)四.超时重传机制五.连接管理机制1.三次握手2.四次挥手 六.流量控制七.滑动窗口八.拥塞控制九.延迟应答十.捎带应答十一.面向字节流十二.粘包问题十三.TCP…

字符串函数详解

一.字母大小写转换函数. 1.1.tolower 结合cppreference.com 有以下结论&#xff1a; 1.头文件为#include <ctype.h> 2.使用规则为 #include <stdio.h> #include <ctype.h> int main() {char ch A;printf("%c\n",tolower(ch));//大写转换为小…

vscode编写verilog的插件【对齐、自动生成testbench文件】

vscode编写verilog的插件&#xff1a; 插件名称&#xff1a;verilog_testbench,用于自动生成激励文件 安装教程&#xff1a;基于VS Code的Testbench文件自动生成方法——基于VS Code的Verilog编写环境搭建SP_哔哩哔哩_bilibili 优化的方法&#xff1a;https://blog.csdn.net…

Jenkins持续集成

1. 持续集成及Jenkins介绍 1.1. 软件开发生命周期 软件开发生命周期又叫做SDLC&#xff08;Software Development Life Cycle&#xff09;&#xff0c;它是集合了计划、开发、测试和部署过程的集合。如下图所示 &#xff1a; 需求分析 这是生命周期的第一阶段&#xff0c;根据…

【C++】模板初阶 【 深入浅出理解 模板 】

模板初阶 前言&#xff1a;泛型编程一、函数模板&#xff08;一&#xff09;函数模板概念&#xff08;二&#xff09;函数模板格式&#xff08;三&#xff09;函数模板的原理&#xff08;四&#xff09;函数模板的实例化&#xff08;五&#xff09;模板参数的匹配原则 三、类模…

毅速丨嫁接打印在模具制造中应用广泛

在模具行业中&#xff0c;3D打印随形水路已经被广泛认可&#xff0c;它可以提高冷却效率&#xff0c;从而提高产品良率。然而&#xff0c;全打印模具制造的成本相对较高&#xff0c;因为需要使用金属3D打印机和专用材料。为了节省打印成本&#xff0c;同时利用3D打印的优势&…

IDEA创建文件添加作者及时间信息

前言 当使用IDEA进行软件开发时&#xff0c;经常需要在代码文件中添加作者和时间信息&#xff0c;以便更好地维护和管理代码。 但是如果每次都手动编辑 以及修改那就有点浪费时间了。 实践 其实我们可以将注释日期 作者 配置到 模板中 同时配置上动态获取内容 例如时间 这样…

ThinkPHP 系列漏洞

目录 2、thinkphp5 sql注入2 3、thinkphp5 sql注入3 4、 thinkphp5 SQL注入4 5、 thinkphp5 sql注入5 6、 thinkphp5 sql注入6 7、thinkphp5 文件包含漏洞 8、ThinkPHP5 RCE 1 9、ThinkPHP5 RCE 2 10、ThinkPHP5 rce3 11、ThinkPHP 5.0.X 反序列化漏洞 12、ThinkPHP…