Fisher线性判别分析

Fisher线性判别分析        

        原理

        LDA(Linear Discriminant Analysis)是一种经典的线性判别方法,又称Fisher判别分析。该方法思想比较简单:给定训练集样例,设法将样例投影到一维的直线上,使得同类样例的投影点尽可能接近和密集,异类投影点尽可能远离。

        Fisher线性判别分析主要包括两个目标:

  1. 最大化类间方差(Maximize Between-Class Variance): 通过找到一个投影方向,使得不同类别的样本在投影后的均值之间的距离最大。这确保了不同类别在投影空间中有明显的差异。

  2. 最小化类内方差(Minimize Within-Class Variance): 在类间方差最大的同时,还要保证每个类别内部的样本在投影后尽量聚集在一起,即类内方差最小。

        通过这两个目标,Fisher线性判别分析产生了一个投影方向,可以将原始数据映射到一个低维空间,同时保留类别之间的差异。这个投影方向通常可以用一个权重向量(投影向量)表示。

        在实际应用中,Fisher线性判别分析经常用于模式识别、人脸识别、图像处理等领域,特别是在处理具有多个类别的分类问题时,它可以提供较好的分类性能。与主成分分析(PCA)不同,Fisher线性判别分析是有监督的降维方法,因为它利用了类别信息来优化投影方向。

Python代码 

        详见注释

        

import numpy as np
import matplotlib.pyplot as plt# 读取数据,并根据类别分类
def readdata(filename):fr = open(filename)numberOfLines = len(fr.readlines())  # 获取数据行数data = np.zeros((numberOfLines, 2))label = []fr = open(filename)index = 0# 该函数readdata以文件名作为输入,并从文件中读取数据。# 它初始化一个数组data以存储数据点,以及一个列表label以存储相应的标签。for line in fr.readlines():line = line.strip()listFromLine = line.split()data[index, 0] = float(listFromLine[0])data[index, 1] = float(listFromLine[1])label.append(float(listFromLine[-1]))index += 1# 遍历文件中的每一行。使用strip()去除行首和行尾的空格,# 使用split()将行分割成一个值列表,然后将前两个值转换为浮点数,存储在data数组中。# 最后一个值也被转换为浮点数并附加到label列表中。# 分类index1 = np.array([index for (index, value) in enumerate(label) if value == -1.0])index2 = np.array([index for (index, value) in enumerate(label) if value == 1.0])data0 = data[index1]data1 = data[index2]# 在读取所有数据点之后,它通过基于标签筛选data来创建两个数组data0和data1。# data0包含标签为-1.0的点,# data1包含标签为1.0的点。return data0, data1def calculatesi(datai, ui):si = np.zeros((datai.shape[1], datai.shape[1]))# 这一行创建了一个形状为(datai.shape[1], datai.shape[1])的零矩阵,# 并将其赋给变量si。这个矩阵将用于存储协方差矩阵。for xi in datai:# 这一行开始一个循环,遍历datai中的每个数据点,将每个数据点表示为xi。m = xi - ui# 这一行计算了数据点xi与均值向量ui之间的差异,将结果存储在变量m中。si += m * m.reshape(2, 1)# 这一行更新协方差矩阵si。它将矩阵m与其转置相乘,并将结果累加到si上。# m.reshape(2, 1)是为了确保矩阵乘法的维度匹配。return sidef fish(data0, data1):# 计算两数据集data0和data1的均值向量u0和u1,# 通过np.mean函数计算每个特征的平均值。u0 = np.mean(data0, axis=0)u1 = np.mean(data1, axis=0)# 计算类内离散度矩阵si# 调用calculatesi函数,该函数用于计算协方差矩阵。# 分别对data0和data1使用均值向量u0和u1计算了类内离散度矩阵si。s0 = calculatesi(data0, u0)s1 = calculatesi(data1, u1)# 总类内离散度矩阵# 这一行计算了总的类内离散度矩阵sw,# 将两个类内离散度矩阵s0和s1相加。sw = s0 + s1# 求逆# 使用np.linalg.inv函数计算总类内离散度矩阵sw的逆矩阵,sw_inv = np.linalg.inv(sw)# 计算投影w# 将总类内离散度矩阵的逆矩阵sw_inv与均值向量差异(u0 - u1)相乘得到。w = np.dot(sw_inv, (u0 - u1))w0 = (np.dot(w.T, u0) + np.dot(w.T, u0)) / 2return w, u0, u1def judge(filename, w, u0, u1):# 读取数据# 打开文件filename,获取文件行数,# 初始化一个大小为(numberOfLines, 2)的零数组test_data,以及一个空列表label。# 接着,通过循环读取文件的每一行,将每行的数据提取出来,转换为浮点数,并存储到test_data数组中。fr = open(filename)numberOfLines = len(fr.readlines())  # 获取数据行数test_data = np.zeros((numberOfLines, 2))label = []fr = open(filename)index = 0for line in fr.readlines():line = line.strip()listFromLine = line.split()test_data[index, 0] = float(listFromLine[0])test_data[index, 1] = float(listFromLine[1])index += 1# 判断类别# 计算投影后的数据点在投影向量w上的位置,并根据其与两个类的中心的距离来判断类别。# 如果点到类别0的中心的距离小于点到类别1的中心的距离,则将类别标签设为-1.0,否则设为1.0。center_0 = np.dot(w.T, u0)center_1 = np.dot(w.T, u1)for s in test_data:y = np.dot(w.T, s)if abs(y - center_0) < abs(y - center_1):label.append(-1.0)else:label.append(1.0)# 分类# 根据类别标签将数据点分成两个数组test_data0和test_data1,# 分别包含属于类别-1.0和1.0的数据点,并将它们作为函数的返回值。index1 = np.array([index for (index, value) in enumerate(label) if value == -1.0])index2 = np.array([index for (index, value) in enumerate(label) if value == 1.0])test_data0 = test_data[index1]test_data1 = test_data[index2]return test_data0, test_data1def draw(data0, data1, w):plt.scatter(data0[:, 0], data0[:, 1], c='red', marker='x')plt.scatter(data1[:, 0], data1[:, 1], c='blue', marker='x')plt.show()if __name__ == '__main__':# 读取数据集并根据数据类别分类data0, data1 = readdata("train_data.txt")# 计算最佳投影ww, u0, u1 = fish(data0, data1)# 判断测试集test_data0, test_data1 = judge("test_data.txt", w, u0, u1)# 绘图draw(test_data0, test_data1, w)

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/440102.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mybatis-Plus扩展

7 MybatisX插件[扩展] 7.1 MybatisX插件介绍 MybatisX 是一款基于 IDEA 的快速开发插件&#xff0c;为效率而生。 安装方法&#xff1a;打开 IDEA&#xff0c;进入 File -> Settings -> Plugins -> Browse Repositories&#xff0c;输入 mybatisx 搜索并安装。 功…

负载均衡下Webshell连接思路及难点

君衍. 一、应用场景二、环境搭建三、思路以及难点1、查看内部结构2、查看webshell3、使用蚁剑进行连接4、难点1 shell文件上传问题5、难点2 命令执行时飘逸6、难点3 大工具上传失败7、难点4 脚本失效 四、解决方式1、关闭对方节点服务器2、基于IP地址判断是否执行3、脚本实现流…

c#窗体捕捉方向键

方法1 实现方法参考代码&#xff1a; private void Form1_Load(object sender, EventArgs e){this.KeyPreview true;}protected override bool ProcessDialogKey(Keys keyData){if (keyData Keys.Left || keyData Keys.Right || keyData Keys.Up || keyData Keys.Down){s…

Linux下安装edge

edge具有及其强大的功能&#xff0c;受到很多人的喜爱&#xff0c;它也开发Linux版本&#xff0c;下面是安装方法&#xff1a; 1.去edge官网下载Linux(.deb)文件。 https://www.microsoft.com/zh-cn/edge/download?formMA13FJ 2.下载之后输入以下指令&#xff08;后面是安装…

【计算机网络】——TCP协议

&#x1f4d1;前言 本文主要是【计算机网络】——传输层TCP协议的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是青衿&#x1f947; ☁️博客首页&#xff1a;CSDN主页放风讲故事 &#x1f304;每日一句…

【2023地理设计组一等奖】基于GIS的桥梁隧道三维建模与可视化

作品介绍 1 设计背景和意义 随着我国基础建设规模不断扩大和深入,构建桥梁可视化管理模型,全面推动智慧桥梁,已成为现代隧道桥梁建设行业的发展趋势。传统的桥梁建模工作需要复杂的算法设计并需要熟练编程实践技能,实现周期长。开发自主知识版权的桥梁建模软件系统或专用插…

AI嵌入式K210项目(23)-人脸检测

文章目录 前言一、实验准备二、实验步骤总结 前言 本章使用预训练好的模型&#xff0c;进行人脸检测&#xff0c;将摄像头采集的画面分析&#xff0c;比对模型&#xff0c;如果有人脸则框出来&#xff0c;并打印相关信息。 一、实验准备 请先将模型文件导入内存卡上&#xf…

装机打不开BIOS怎么办?如何进入Windows10的BIOS页面,如何关闭快速启动

电脑有快速启动&#xff0c;想进去BIOS页面非常困难&#xff0c;在临开机的页面&#xff0c;按触发按键不管用。 然后我看到了一种新的进入BIOS的方式&#xff1a; &#xff08;1&#xff09;win8以上的系统&#xff0c;按住shift&#xff0c;然后鼠标点击重启&#xff0c;再…

Softmax分类器

文章目录 回顾使用Sigmoid构建多分类器&#xff1f; SoftMax函数交叉熵损失函数例子 MINIST多分类器数据集步骤实现1.数据集2.构建模型3.构建损失函数和优化器4. 训练和测试 完整代码 回顾 上节课利用糖尿病数据集做了二分类任务 MNIST数据集有10个类别我们又该如何进行分类呢…

java反射常用方法

反射思维导图 使用案例 package Reflection.Work.WorkTest01;import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.Arrays;public class WorkDe…

基于数字签名技术的挑战/响应式认证方式

挑战/响应式认证方式简便灵活&#xff0c;实现起来也比较容易。当网络需要验证用户身份时&#xff0c;客户端向服务器提出登录请求&#xff1b;当服务器接收到客户端的验证请求时&#xff0c;服务器端向客户端发送一个随机数&#xff0c;这就是这种认证方式的“冲击&#xff08…

java学习之路(2)-编译java文件运行Java文件

创建.java后缀文本文件HelloWorld .java 写入代码&#xff1a; public class HelloWorld { public static void main(String []args) { System.out.println("Hello World"); } } 运行cmd命令 找到代码所在目录 输入javac编译Java文件生成HelloWorld.class 编译:…