机器学习7:pytorch的逻辑回归

一、说明

        逻辑回归模型是处理分类问题的最常见机器学习模型之一。二项式逻辑回归只是逻辑回归模型的一种类型。它指的是两个变量的分类,其中概率用于确定二元结果,因此“二项式”中的“bi”。结果为真或假 — 0 或 1。

        二项式逻辑回归的一个例子是预测人群中 COVID-19 的可能性。一个人要么感染了COVID-19,要么没有,必须建立一个阈值以尽可能准确地区分这些结果。

二、sigmoid函数

        这些预测不适合一条线,就像线性回归模型一样。相反,逻辑回归模型拟合到右侧所示的 sigmoid 函数。

        对于每个 x,生成的 y 值表示结果为 True 的概率。在 COVID-19 示例中,这表示医生对某人感染病毒的信心。在右图中,阴性结果为蓝色,阳性结果为红色。

图片来源:作者

三、过程

        要进行二项式逻辑回归,我们需要做各种事情:

  1. 创建训练数据集。
  2. 使用 PyTorch 创建我们的模型。
  3. 将我们的数据拟合到模型中。

        逻辑回归问题的第一步是创建训练数据集。首先,我们应该设置一个种子来确保我们的随机数据的可重复性。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seed

我们必须使用 PyTorch 的线性模型,因为我们正在处理一个输入 x 和一个输出 y。因此,我们的模型是线性的。为此,我们将使用 PyTorch 的函数:Linear

model = Linear(in_features=1, out_features=1) # use a linear model

接下来,我们必须生成蓝色 X 和红色 X 数据,确保将它们从行向量重塑为列向量。蓝色的在 0 到 7 之间,红色的在 7 到 10 之间。对于 y 值,蓝点表示 COVID-19 测试阴性,因此它们都将是

  1. 对于红点,它们代表 COVID-19 测试呈阳性,因此它们将为 1。下面是代码及其输出:
blue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

现在,我们的代码应如下所示:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y values

四、优化

        我们将使用梯度下降过程来优化 S 形函数的损失。损失是根据函数拟合数据的优度计算的,数据由 S 形曲线的斜率和截距控制。我们需要梯度下降来找到最佳斜率和截距。

        我们还将使用二进制交叉熵(BCE)作为我们的损失函数,或对数损失函数。对于一般的逻辑回归,不包含对数的损失函数将不起作用。

        为了实现BCE作为我们的损失函数,我们将它设置为我们的标准,并将随机梯度下降作为我们优化它的手段。由于这是我们将要优化的函数,我们需要传入模型参数和学习率。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descent

        现在,我们准备开始梯度下降以优化我们的损失。我们必须将梯度归零,通过将我们的数据插入 sigmoid 函数来找到 y-hat 值,计算损失,并找到损失函数的梯度。然后,我们必须迈出一步,确保存储我们的新斜率并为下一次迭代进行拦截。

optimizer.zero_grad()
Yhat = torch.sigmoid(model(X)) 
loss = criterion(Yhat,Y)
loss.backward()
optimizer.step() 

五、收尾

        为了找到最佳斜率和截距,我们本质上是在训练我们的模型。我们必须对多次迭代或纪元应用梯度下降。在此示例中,我们将使用 2,000 个纪元进行演示。

epochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()

将所有代码片段放在一起,我们应该得到以下代码:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import Lineartorch.manual_seed(42)   # set a random seedmodel = Linear(in_features=1, out_features=1) # use a linear modelblue_x = (torch.rand(20) * 7).reshape(-1,1)   # random floats between 0 and 7
blue_y = torch.zeros(20).reshape(-1,1)red_x = (torch.rand(20) * 7+3).reshape(-1,1)  # random floats between 3 and 10
red_y = torch.ones(20).reshape(-1,1)X = torch.vstack([blue_x, red_x])   # matrix of x values
Y = torch.vstack([blue_y, red_y])   # matrix of y valuesepochs = 2000   # run 2000 iterations
criterion = nn.BCELoss()    # implement binary cross entropy loss functionoptimizer = torch.optim.SGD(model.parameters(), lr = .1) # stochastic gradient descentfor i in range(epochs):optimizer.zero_grad()Yhat = torch.sigmoid(model(X))loss = criterion(Yhat,Y)loss.backward()optimizer.step()print(f"epoch: {i+1}")print(f"loss: {loss: .5f}")print(f"slope: {model.weight.item(): .5f}")print(f"intercept: {model.bias.item(): .5f}")print()
两千个时期后的最终输出:epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314

两千个时期后的最终输出:

epoch: 2000
loss:  0.53861
slope:  0.61276
intercept: -3.17314 

六、可视化

        最后,我们可以将数据与 sigmoid 函数一起绘制,以获得以下可视化效果:

x = np.arange(0,10,.1)
y = model.weight.item()*x + model.bias.item()plt.plot(x, 1/(1 + np.exp(-y)), color="green")plt.xlim(0,10)
plt.scatter(blue_x, blue_y, color="blue")
plt.scatter(red_x, red_y, color="red")plt.show()

图片来源:作者

七、局限性

        二元分类的最大问题之一是需要阈值。在逻辑回归的情况下,此阈值应为 x 值,其中 y 为 50%。我们试图回答的问题是将阈值放在哪里?

        在 COVID-19 测试的情况下,原始示例说明了这种困境。如果我们将阈值设置为 x=5,我们可以清楚地看到应该是红色的蓝点和应该是蓝色的红点。

        悬垂的红点称为误报,即模型错误地预测正类的区域。悬垂的蓝点称为假阴性 - 模型错误地预测负类的区域。

 八、结论

        成功的二项式逻辑回归模型将减少假阴性的数量,因为这些假阴性通常会导致最大的危险。患有COVID-19但检测呈阴性对他人的健康和安全构成严重风险。

        通过对可用数据使用二项式逻辑回归,我们可以确定放置阈值的最佳位置,从而有助于减少不确定性并做出更明智的决策。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/127315.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《低代码指南》——低代码维格云服务菜单

简介​ 快速了解付费客户能够获得维格服务团队哪些服务,本篇内容不包含使用免费试用版本的客户。 了解维格表产品价格与功能权益:戳我看价格与权益​ 客户付费后能得到哪些服务项目?​ 常规服务项目:

Covert Communication 与选择波束(毫米波,大规模MIMO,可重构全息表面)

Covert Communication for Spatially Sparse mmWave Massive MIMO Channels 2023 TOC abstract 隐蔽通信,也称为低检测概率通信,旨在为合法用户提供可靠的通信,并防止任何其他用户检测到合法通信的发生。出于下一代通信系统安全链路的强烈…

数据结构P46(2-1~2-4)

2-1编写算法查找顺序表中值最小的结点&#xff0c;并删除该结点 #include <stdio.h> #include <stdlib.h> typedef int DataType; struct List {int Max;//最大元素 int n;//实际元素个数 DataType *elem;//首地址 }; typedef struct List*SeqList;//顺序表类型定…

基于遗传算法的新能源电动汽车充电桩与路径选择(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

如何用ChatGPT学或教英文?5个使用ChatGPT的应用场景!

原文&#xff1a;百度安全验证 AI工具ChatGPT的出现大幅改变许多领域的运作方式&#xff0c;就连「学英文」也不例外&#xff01;我发现ChatGPT应用在英语的学习与教学上非常有意思。 究竟ChatGPT如何改变英文学习者(学生)与教学者(老师)呢&#xff1f; 有5个应用场景我感到…

docker部署Vaultwarden密码共享管理系统

Vaultwarden是一个开源的密码管理器&#xff0c;它是Bitwarden密码管理器的自托管版本。它提供了类似于Bitwarden的功能&#xff0c;允许用户安全地存储和管理密码、敏感数据和身份信息。 Vaultwarden的主要特点包括&#xff1a; 1. 安全的数据存储&#xff1a;Vaultwarden使…

SAP-MM-委外订单合并有什么不同?

计划下了两个委外采购申请&#xff0c;用ME59N合并时会有哪些不同&#xff1f; 1、采购申请2个都是委外采购申请&#xff0c;有BOM组件 因为两个采购申请的条件一致&#xff0c;所以在转采购订单ME59时&#xff0c;会被合并 但是合并之后&#xff0c;BOM组件却不显示了

【Vue】vscode格式刷插件Prettier以及配置项~~保姆级教程

文章目录 前言一、下载插件二、在项目内创建配置文件1.在根目录创建&#xff0c;src同级2.写入配置3.每个字段含义 总结 前言 vscode格式刷&#xff0c;有太多插件了&#xff0c;但是每个的使用&#xff0c;换行都不一样。 这里我推荐一个很多人都推荐了的Prettier 一、下载插…

【SpringCloud】微服务技术栈入门5 - ElasticSearch

ElasticSearch 倒排索引 倒排索引建立&#xff1a;对文章标题进行分词&#xff0c;将每个词存入 term&#xff0c;这些词也对应一个 id 也就是文档 倒排索引检索&#xff1a;假设我们搜索华为手机 分词&#xff1a;“华为”“手机”从数据库中找到对应的两个 key&#xff0c;…

CSS鼠标指针表

(机翻)搬运自:cursor - CSS: Cascading Style Sheets | MDN (mozilla.org) 类型Keyword演示注释全局autoUA将基于当前上下文来确定要显示的光标。例如&#xff0c;相当于悬停文本时的文本。default 依赖于平台的默认光标。通常是箭头。none不会渲染光标。链接&状态contex…

【开发篇】二十、SpringBoot整合RocketMQ

文章目录 1、整合2、消息的生产3、消费4、发送异步消息5、补充&#xff1a;安装RocketMQ 1、整合 首先导入起步依赖&#xff0c;RocketMQ的starter不是Spring维护的&#xff0c;这一点从starter的命名可以看出来&#xff08;不是spring-boot-starter-xxx&#xff0c;而是xxx-s…

微信小程序-2

微信开发文档 https://developers.weixin.qq.com/miniprogram/dev/framework/ 一、app.js中的生命周期函数与globalData(全局变量) 指南 - - - 小程序框架 - - - 注册小程序 删除app.js里的东西&#xff0c;输入App回车&#xff0c;调用生命周期 选项 - - - 重新打开此项目…