机器学习——感知机模型

机器学习系列文章

入门必读:机器学习介绍


文章目录

  • 机器学习系列文章
  • 前言
  • 1. 感知机
    • 1.1 感知机定义
    • 1.2 感知机学习策略
  • 2. 代码实现
    • 2.1 构建数据
    • 2.2 编写函数
    • 2.3 迭代
  • 3. 总结


前言

大家好,大家好✨,这里是bio🦖。这次为大家带来的是感知机模型。下面跟我一起来了解感知机模型吧!

感知机 (Perceptron) 是二类分类的线性分类模型 ,其输入为实例的特征向量 ,输出为实例的类别 ,分别为 +1-1。1957年,由康奈尔航空实验室(Cornell Aeronautical Laboratory)弗兰克·罗森布拉特 (Frank Rosenblatt)提出。它可以被视为一种最简单形式的前馈神经网络,是一种二元线性分类器。在人工神经网络领域中,感知机也被指为单层的人工神经网络,以区别于较复杂的多层感知机。在这里插入图片描述


1. 感知机

1.1 感知机定义

感知器使用特征向量来表示二元分类器,把矩阵上的输入 x \mathcal{x} x(实数值向量)映射到输出值 y \mathcal{y} y 上(一个二元的值)。
f ( x ) = { + 1 , i f w ⋅ x + b > 0 − 1 , e l s e f(x) = \begin{cases} +1,\,\, if\,w\cdot x+b>0\\ -1,\,\,else\\ \end{cases} f(x)={+1,ifwx+b>01,else

w \mathcal{w} w 是实数的表示权重的向量, w ⋅ x \mathcal{w} \cdot \mathcal{x} wx 是点积。 b \mathcal{b} b 是偏置,一个不依赖于任何输入值的常数。


1.2 感知机学习策略

假设训练数据集是线性可分的 ,如下图所示。感知机学习的目标是求得一个能够将训练集正实例点和负实例点完全正确分开的直线 L \mathcal{L} L。 为了找出这样的超平而 , 即确定感知机模型参数 w \mathcal{w} w b \mathcal{b} b ,需要确定一个学习策略 , 即定义损失函数并将损失函数极小化 。

损失函数的一个选择是误分类数据点的数量 。 但是这样的损失函数不是参数 w \mathcal{w} w b \mathcal{b} b 的连续可导函数,不易优化 。 损失函数的另一个选择是误分类数据点到直线 L \mathcal{L} L 的总距离。感知机所采用的就是后者 。
在这里插入图片描述

  • 对于错误分类的数据点 ( x i , y i ) (\mathcal{x_i, y_i}) (xi,yi),总有:
    − y i ⋅ ( w ⋅ x i + b ) > 0 \mathcal{-y_i\cdot(w\cdot x_i+b) > 0} yi(wxi+b)>0
  • 错误分类点到直线 L \mathcal{L} L 的距离为:
    1 ∥ w ∥ ∣ w ⋅ x i + b ∣ \mathcal{\frac{1}{\|w\|} \vert w \cdot x_i +b \vert} w1wxi+b
  • 假设直线 L \mathcal{L} L 的误分类点集合为 m \mathcal{m} m , 那么所有误分类点到直线 L \mathcal{L} L 的总距离为:
    − 1 ∥ w ∥ ∑ i m y i ⋅ ( w ⋅ x i + b ) \mathcal{-\frac{1}{\|w\|} \sum_{i}^{m}y_i\cdot(w \cdot x_i +b)} w1imyi(wxi+b)
  • 不考虑 1 ∥ w ∥ \mathcal{\frac{1}{\|w\|}} w1,感知机的损失函数为:
    K ( w , b ) = − ∑ i m y i ⋅ ( w ⋅ x i + b ) \mathcal{K(w, b)= - \sum_{i}^{m}y_i\cdot(w \cdot x_i +b)} K(w,b)=imyi(wxi+b)

显然,损失函数 K \mathcal{K} K 是非负的。如果没有误分类点,损失函数值是 0 。而且,误分类点越少,误分类点离超平面越近,损失函数值就越小 。

而感知机的优化算法采用的是随机梯度下降算法 (Stochastic Gradient Descent)(后续更新),用误分类数据驱动损失函数 K \mathcal{K} K 不断减小。本文将采取二维数据,来展示感知机的工作过程。


2. 代码实现

2.1 构建数据

首先创建二维数据,并用线性回归模型拟合出直线 L \mathcal{L} L。代码如下:

import numpy as np
from sklearn.datasets import make_classification
from sklearn import linear_model
import matplotlib.pyplot as plt
import random# two-dimention data
td_data = make_classification(n_samples=20, n_features=2, n_informative=2, n_redundant=0, n_clusters_per_class=1, random_state=24)td_data = list(td_data)
td_data[1] = np.array([1 if i == 0 else -1 for i in td_data[1]])
td_data = tuple(td_data)# visualized data
fig, ax = plt.subplots()
scatter = ax.scatter(td_data[0][:, 0], td_data[0][:, 1], c=td_data[1], alpha=0.6, cmap="cool")
legend_1 = ax.legend(*scatter.legend_elements(), title="Classes", loc="upper left")
ax.add_artist(legend_1)
ax.set_xlabel("Feature_1")
ax.set_ylabel("Feature_2")# add minimal residual sum of squares line as gold standard
reg = linear_model.LinearRegression()# reshape for model fitting
reg.fit(td_data[0][:, 0].reshape(-1, 1), td_data[0][:, 1].reshape(-1, 1))
print(f"the intercept is {reg.intercept_[0]} and the coefficient is {reg.coef_[0][0]}")
formula = f"f(x)={round(reg.coef_[0][0], 2)}*x1-x2{round(reg.intercept_[0], 2)}"# create a x axis for plotting
create_x_axis = np.linspace(min(td_data[0][:, 0]), max(td_data[0][:, 0]), 100).reshape(-1, 1)
predicted_value = reg.predict(create_x_axis)ax.plot(create_x_axis, predicted_value, c="gold", alpha=0.8, label=formula)
handles, labels = ax.get_legend_handles_labels()
legend_2 = ax.legend(handles, labels, loc="lower right")plt.show()

根据代码输出的结果可知,由线性回归模型拟合出的直线 L = 0.53 x 1 + x 2 − 0.6 \mathcal{L = 0.53x_1+x_2-0.6} L=0.53x1+x20.6
在这里插入图片描述

2.2 编写函数

接下来编写可复用的函数,减少代码编写量。partial_derivative_w 函数用于对变量 w \mathcal{w} w 求偏导,partial_derivative_b 函数用于对变量 b \mathcal{b} b 求偏导,decision_funtion 函数用于决策是否继续进行迭代,plot_function 函数绘制迭代结果图。

# take the partial derivative of w and b
def partial_derivative_w(data_point, label_point):# feature_1 * feature_2 * yresult_w_1 = data_point[0] * label_pointresult_w_2 = data_point[1] * label_pointreturn [result_w_1, result_w_2]def partial_derivative_b(label_point):# labelresult_b = label_pointreturn result_b# decision function. w and b will be change if exist data point make 
def decision_funtion(weight_1, weigh_2, intercept):# if y*(w*x+b) < 0, the data point is wrongly classified.result = td_data[1] * ((td_data[0][:, 0] * weight_1) + (td_data[0][:, 1] * weight_2) + intercept)if len(result[np.where(result < 0)]) != 0:print(result)wrong_dp_index = np.where(result == result[np.where(result < 0)][0])[0][0]wrong_dp = td_data[0][wrong_dp_index]wrong_lb = td_data[1][wrong_dp_index]return [True, wrong_dp, wrong_lb]else:print("interation end")return [False, None, None]def plot_function(weight_1, weight_2, intercept):fig, ax = plt.subplots()scatter = ax.scatter(td_data[0][:, 0], td_data[0][:, 1], c=td_data[1], alpha=0.6)ax.legend(*scatter.legend_elements(), title="Classes")ax.set_xlabel("Feature_1")ax.set_ylabel("Feature_2")b = intercept/weight_2hyperplane = [(-(weight_1/weight_2) * i) - b for i in create_x_axis]ax.plot(create_x_axis, hyperplane, c='green', alpha=0.5)plt.show()

2.3 迭代

设置特征一的初始权重为 0,特征二的初始权重为 0,初始截距为 0,学习率为 0.1,迭代次数为1000次,随机从数据中选择一个数据点作为分类错误数据点后开始迭代。

# initiate weight, intercept and learning rate
weight_1 = 0
weight_2 = 0
intercept = 0
learn_rate = 0.1# iteration times
iteration_times = 1000# random value in two dimention data
random_index = random.randint(0, 19)
feature_point = td_data[0][random_index]
label_point = td_data[1][random_index]
# it is not correctly classified for any data point resulting in loss function equte 0.for iteration in range(iteration_times):# w1 = w0 + (learn_rate * y * x)new_weight_1 = weight_1 + (learn_rate * partial_derivative_w(feature_point, label_point)[0])new_weight_2 = weight_2 + (learn_rate * partial_derivative_w(feature_point, label_point)[1])# b1 = b0 + learn_rate * ynew_intercept = intercept + (learn_rate * partial_derivative_b(label_point))# decisiondecision_condition, wrong_dp, wrong_lp = decision_funtion(new_weight_1, new_weight_2, new_intercept)if decision_condition:weight_1 = new_weight_1weight_2 = new_weight_2intercept = new_intercept# wrong data pointfeature_point = wrong_dplabel_point = wrong_lpprint(f"The {iteration + 1} iteration\tweight_1={weight_1}\tweight_2={weight_2}\tintercept={intercept}\n")plot_function(weight_1, weight_2, intercept)else:print(f"The {iteration + 1} iteration\tweight_1={new_weight_1}\tweight_2={new_weight_2}\tintercept={new_intercept}\n")plot_function(new_weight_1, new_weight_2, new_intercept)break

迭代结果如下表所示,在迭代到第八次的时候,感知机模型成功将所有数据点正确分类。

迭代次数效果图片
1在这里插入图片描述
2在这里插入图片描述
3在这里插入图片描述
4在这里插入图片描述
5在这里插入图片描述
6在这里插入图片描述
7在这里插入图片描述
8在这里插入图片描述

3. 总结

以上就是本次更新的全部内容。关于本次内容有一下缺点:

  • 用于迭代的错误分类数据点没有被绘制出来。
  • 由于跳过了大部分数学知识,内容衔接没有做好。
  • 迭代数据点完全随机,复现过程可能不同。

后续将会更新:

  • 感知机模型的数学解释
  • 随机梯度算法的解释
  • 可视化迭代过程的错误分类数据点(可能)

喜欢本次内容的小伙伴麻烦👍点赞+👍关注。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/519412.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【R语言实战】聚类分析及可视化

&#x1f349;CSDN小墨&晓末:https://blog.csdn.net/jd1813346972 个人介绍: 研一&#xff5c;统计学&#xff5c;干货分享          擅长Python、Matlab、R等主流编程软件          累计十余项国家级比赛奖项&#xff0c;参与研究经费10w、40w级横向 文…

消息队列-Kafka-如何进行顺序消费

全局有序 只有 1 个分区&#xff0c;那这个时候就是能够保证消息的顺序消费。 分区有序 如果我们还是想同时消费多个分区并且保证有序&#xff0c;这个时候我们需要将需要保证顺序的消息路由到同一个分区。 在发送消息的时候我们可以看到&#xff1a; 上面的代码定义了消息…

1 数据分析概述与职业操守 (3%)

1、 EDIT数字化模型 E——exploration探索 &#xff08;是什么&#xff09; 业务运行探索&#xff1a;探索关注企业各项业务的运行状态、各项指标是否合规以及各项业务的具体数据情况等。 D——diagnosis 诊断 (为什么) 问题根源诊断&#xff1a;当业务指标偏离正常值时&…

C语言从入门到精通 第十二章(程序的编译及链接)

写在前面&#xff1a; 本系列专栏主要介绍C语言的相关知识&#xff0c;思路以下面的参考链接教程为主&#xff0c;大部分笔记也出自该教程。除了参考下面的链接教程以外&#xff0c;笔者还参考了其它的一些C语言教材&#xff0c;笔者认为重要的部分大多都会用粗体标注&#xf…

【C++】string类的基础操作

&#x1f497;个人主页&#x1f497; ⭐个人专栏——C学习⭐ &#x1f4ab;点击关注&#x1f929;一起学习C语言&#x1f4af;&#x1f4ab; 目录 导读 1. 基本概述 2. string类对象的常见构造 3. string类对象的容量操作 4. string类对象的访问及遍历操作 5. 迭代器 6.…

第五十二回 戴宗二取公孙胜 李逵独劈罗真人-飞桨AI框架安装和使用示例

吴用说只有公孙胜可以破法术&#xff0c;于是宋江请戴宗和李逵去蓟州。两人听说公孙胜的师傅罗真人在九宫县二仙山讲经&#xff0c;于是到了二仙山&#xff0c;并在山下找到了公孙胜的家。 两人请公孙胜去帮助打高唐州&#xff0c;公孙胜说听师傅的。罗真人说出家人不管闲事&a…

Milvus 向量数据库实践 - 1

假定你已经安装了docker、docker-compose 环境 参考的文档如下&#xff1a; Milvus技术探究 - 知乎 MilvusClient() - Pymilvus v2.3.x for Milvus 一文带你入门向量数据库milvus 一、在docker上安装单机模式milvus数据库 1、 进入milvus官网&#xff1a; Install Milvus Stand…

Lazada本土店与跨境店区别,附店铺防关联攻略

许多新手商家在初入跨境电商时&#xff0c;对于平台账号类别并不清楚。Lazada是最大的东南亚在线购物平台之一&#xff0c;如果你的跨境目标正指向东南亚&#xff0c;那么Lazada一定是是你的首选平台。那么接下来让小编带大家认识Lazada本土店与跨境店的区别&#xff01; 一、本…

根据标准化开发流程---解析LIN总线脉冲唤醒的测试方法和用例设计思路

前言&#xff1a;本文从标准化开发流程的角度&#xff0c;以LIN总线脉冲唤醒为切入点。从测试工程师的角度来讲测试工作应当如何展开&#xff08;结合我干测试总结出来的测试经验&#xff09;。希望大家都能从中有收获&#xff01;&#xff01;谢谢&#xff01;&#xff01; 1…

“揭秘网络握手与挥别:TCP三次握手和四次挥手全解析“

前言 在计算机网络中&#xff0c;TCP&#xff08;传输控制协议&#xff09;是一种重要的通信协议&#xff0c;用于在网络中的两台计算机之间建立可靠的连接并交换数据。TCP协议通过“三次握手”和“四次挥手”的过程来建立和终止连接&#xff0c;确保数据的准确传输。 一、三…

模拟实现std::string类(包含完整、分文件程序)

std库中的string是一个类&#xff0c;对string的模拟实现&#xff0c;既可以复习类的特性&#xff0c;也可以加深对std::string的理解。 &#x1f308;一、搭建框架 ☀️1.新命名空间 本质上string是一个储存在库std里面的类&#xff0c;现在需要模拟实现一个string类&#…

Scrapy与分布式开发(2.3):lxml+xpath基本指令和提取方法详解

lxmlxpath基本指令和提取方法详解 一、XPath简介 XPath&#xff0c;全称为XML Path Language&#xff0c;是一种在XML文档中查找信息的语言。它允许用户通过简单的路径表达式在XML文档中进行导航。XPath不仅适用于XML&#xff0c;还常用于处理HTML文档。 二、基本指令和提取…