机器学习实验五:BP 神经网络算法实现与测试

实验五:BP 神经网络算法实现与测试

一、实验目的

深入理解 BP 神经网络的算法原理,能够使用 Python 语言实现 BP 神经网络的训练与测试,并且使用五折交叉验证算法进行模型训练与评估。

 

二、实验内容

1)从 scikit-learn 库中加载 iris 数据集,使用留出法留出 1/3 的样本作为测试集(注意同分布取样);

2)使用训练集训练 BP 神经网络分类算法;

3)使用五折交叉验证对模型性能(准确度、精度、召回率和 F1 值)进行评估和选择;

4)使用测试集,测试模型的性能,对测试结果进行分析,完成实验报告中实验五的部分。

 

 

三、算法步骤、代码、及结果

   1. 算法伪代码

初始化神经网络参数(权重和偏置)  

定义激活函数(sigmoidReLU等)  

重复以下步骤直到收敛:  

    前向传播:  

        1. 输入层接收输入数据  

        2. 计算隐藏层输出  

        3. 计算输出层结果  

    计算损失(成本函数)  

    反向传播:  

        1. 计算输出层误差  

        2. 计算隐藏层误差  

        3. 更新权重和偏置  

使用交叉验证法评估模型性能

   2. 算法主要代码

完整源代码\调用库方法(函数参数说明)

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# 加载 iris 数据集
iris = load_iris()
X = iris.data
y = iris.target

# 留出法划分训练集和测试集,1/3 作为测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=42, stratify=y)

# 定义 BP 神经网络分类器
bp_clf = MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42)

# 使用训练集训练模型
bp_clf.fit(X_train, y_train)

# 五折交叉验证评估模型性能
cv_scores = cross_val_score(bp_clf, X_train, y_train, cv=5, scoring='accuracy')
print("五折交叉验证准确度:", np.round(cv_scores, 4))
print("平均准确度:", np.round(np.mean(cv_scores), 4))

y_pred_train = bp_clf.predict(X_train)
precision_train = precision_score(y_train, y_pred_train, average='macro')
recall_train = recall_score(y_train, y_pred_train, average='macro')
f1_train = f1_score(y_train, y_pred_train, average='macro')
print("训练集精度:", np.round(precision_train, 4))
print("训练集召回率:", np.round(recall_train, 4))
print("训练集 F1 值:", np.round(f1_train, 4))

# 使用测试集测试模型性能
y_pred_test = bp_clf.predict(X_test)
accuracy_test = accuracy_score(y_test, y_pred_test)
precision_test = precision_score(y_test, y_pred_test, average='macro')
recall_test = recall_score(y_test, y_pred_test, average='macro')
f1_test = f1_score(y_test, y_pred_test, average='macro')
print("测试集准确度:", np.round(accuracy_test, 4))
print("测试集精度:", np.round(precision_test, 4))
print("测试集召回率:", np.round(recall_test, 4))
print("测试集 F1 值:", np.round(f1_test, 4))

# 分析测试结果
if accuracy_test > 0.8:
    print("模型性能良好,在测试集上有较高的准确度。")
else:
    print("模型性能有待提高,可尝试调整模型参数或增加训练数据。")

 

调用库方法

 

1. load_iris

加载 Iris 数据集。

from sklearn.datasets import load_iris  

 

参数:

return_X_y: 如果为 True,返回特征和目标。如果为 False,返回一个包含数据的对象(默认值为 False

返回值:

返回一个包含特征和目标的对象,通常通过 iris.data iris.target 获取。

 

2. train_test_split

将数据随机划分为训练集和测试集。

from sklearn.model_selection import train_test_split  

 

参数:

test_size: 测试集占比(0-1之间的小数,或具体数目)。

random_state: 随机种子(确保划分可重现)。

stratify: 按类别比例划分(确保训练集和测试集类别分布一致)。

 

返回值:

返回划分后的训练数据和测试数据。

 

3. fit

用法: clf.fit(X_train, y_train)

作用: 训练模型。

 

4. cross_val_score

 

用法: cross_val_score(estimator, X, y, cv, scoring)

 

参数:

estimator: 需要评估的模型。

X: 特征数据。

y: 类别标签。

cv: 交叉验证的折数。

scoring: 评估指标(如准确率、精确率)。

 

  1. MLPClassifier()

BP 神经网络分类器

MLPClassifier(hidden_layer_sizes, max_iter, random_state)

 

参数:

hidden_layer_sizes: 隐藏层的层数和每层的神经元数,如 (100,) 表示有一层隐藏层,包含100个神经元。

max_iter: 最大迭代次数。

random_state: 随机数种子,确保结果可重复。

 

 

6. accuracy_score

计算模型在给定数据上的准确度。

from sklearn.metrics import accuracy_score  

 

参数:

y_true: 真实标签。

y_pred: 预测标签。

 

返回值:

返回预测准确率(在 0 1 之间的小数)。

 

7. classification_report

生成分类绩效的详细报告。

from sklearn.metrics import classification_report  

 

参数:

y_true: 真实标签。

y_pred: 预测标签。

target_names: 可选,类标签名称的列表,以便于输出可读性。

 

返回值:

返回一个字符串,包含每个类的精确率、召回率和 F1 值。

 

 

   3. 训练结果截图(包括:准确率、精度(查准率)、召回率(查全率)、F1

 

 

四、实验结果分析

1. 测试结果截图(包括:准确率、精度(查准率)、召回率(查全率)、F1

 

 

2. 对比分析

模型在 iris 数据集上的表现非常优秀。五折交叉验证显示大多数折的准确率为1,平均准确度为0.98,说明模型在训练集上学习得很好。在训练集上,精度、召回率和 F1 值均接近0.99,表明模型几乎完美地预测了样本。测试集的准确度同样为0.98,其他评估指标也很高,显示出模型在未见数据上的良好泛化能力。尽管结果令人满意,但仍需注意过拟合的风险。总体而言,模型在这个数据集上的表现非常出色,值得进一步测试和应用。

 

五、心得体会

 

通过本次实验,我深入理解了 BP 神经网络的算法原理,包括前向传播、损失函数计算、反向传播和权重更新等关键步骤。使用 TensorFlow 框架,我成功实现了 BP 神经网络的训练与测试,并通过五折交叉验证对模型性能进行了全面评估。

在实验过程中,我遇到了一些挑战。首先,BP 神经网络的超参数选择(如隐藏层神经元数量、学习率、迭代次数等)对模型性能有很大影响,需要通过实验进行调优。其次,由于 iris 数据集较小且特征维度较低,BP 神经网络可能容易过拟合,因此需要注意正则化方法和早停策略的应用。

通过五折交叉验证,我得到了模型在不同训练集和验证集上的性能表现,这有助于我更全面地了解模型的泛化能力。测试集上的结果也验证了模型的有效性。

总的来说,这次实验不仅提高了我的编程能力,还加深了我对 BP 神经网络的理解和应用能力。我认识到,在实际应用中,除了算法本身外,数据预处理、特征选择和模型调优等方面也同样重要。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/856394.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM专题学习之类加载器(二)

类加载器 三层类加载器 1.启动类加载器-BootstrapClassLoader AppClassLoader负责加载核心类,存放在lib目录下的jar包或class文件。 2.扩展类加载器-ExtensionClassLoader ExtensionClassLoader负责加载\lib\ext目录下的jar包或class文件,我们可以将通用性的功能,打成jar包放…

2024-2025-1 20241417 《计算机基础与程序设计》第十三周学习总结

2024-2025-1 20241417 《计算机基础与程序设计》第十三周学习总结 作业信息这个作业属于哪个课程 <班级的链接>(如2024-2025-1-计算机基础与程序设计)这个作业要求在哪里 <作业要求的链接>2024-2025-1计算机基础与程序设计第十三周作业这个作业的目标 <复习前…

28.Python基础篇-logging模块

介绍: logging 模块是Python内置的强大日志记录工具,支持多种输出方式、格式化选项及多进程支持。 日志的级别 logging 模块有五个内置的日志级别,从低到高:DEBUG:详细信息,用于诊断问题。 INFO:常规信息,表示程序正常运行的状态。 WARNING:警告信息,表示潜在问题或即…

Redis安装配置

安装gcc环境sudo yum install -y gcc-c++查看gcc环境gcc -v

我们的电视Our tv 3.6.0安卓+TV 一款全新电视直播软件-内置稳定直播源

应用简介 我们的电视(ourtv)是一款完全无广告的电视直播软件,清晰度可选择高清,超清,蓝光等播放。安装即可使用,再也不用费劲去找各种不稳定的直播源了。 “我们的电视”播放线路(直播源)是来自央视频,因此画质和稳定性还可以。不过随之而来的问题是跟央视频 App 不兼…

[HTML/Web] HTML5之`Video`元素

概述:video 元素 核心属性:playbackRate/播放速率在HTML5中,<video> 元素提供了一个 playbackRate 属性,可以用来设置视频的播放速度。这个属性允许你设置视频的倍速播放,比如正常速度、慢速或快速。以下是如何设置 <video> 元素的倍速播放:html<video id…

鸿蒙HarmonyOS应用开发 | HarmonyOS Next-从应用开发到上架全流程解析

HarmonyOS Next-从应用开发到上架全流程解析 随着智能设备的不断普及,操作系统的竞争变得愈加激烈。在这个背景下,华为推出的HarmonyOS(鸿蒙操作系统)逐渐崭露头角,成为一个引人注目的新兴平台。本文将深入探讨HarmonyOS Next的应用开发流程,并特别关注鸿蒙应用上架的全过…

2024-2025-1 20241307《计算机基础与程序设计》第十三周学习总结

作业信息这个作业属于哪个课程 (2024-2025-1-计算机基础与程序设计)这个作业要求在哪里 ([2024-2025-1计算机基础与程序设计第十三周作业]这个作业的目标作业正文 (2024-2025-1 学号20241307《计算机基础与程序设计》第十三周学习总结)教材学习内容总结 C语言程序设计第十二…

移动端笔记应用,markdown应用选用

要求不能有广告。作为使用频率较高的软件,有广告就是恶心人。 支持markdown,包括且不限于代码块、标题、图片等格式。 支持同步,至少拥有WebDav云同步,或者本地导入导出。 全局搜索功能。以上功能必须免费,至少我不明白导入导出有什么好付费的。云同步这种付费理所当然。背…

一个.NET开源、易于使用的屏幕录制工具

前言 一款高效、易用的屏幕录制工具能够极大地提升我们的工作效率和用户体验,今天大姚给大家分享一个.NET开源、免费、易于使用的屏幕录制工具:Captura。 工具介绍 Captura是一款基于.NET开源、免费、易于使用的屏幕录制、截图工具,允许用户录制屏幕活动、捕获屏幕截图、录制…

CDN信息收集

引子:这篇是对架构信息收集中CDN部分的补充,由于Web应用先得注册域名才能使用CDN服务,而我国境内的域名注册需先要备案。又因为笔者目前并没有这方面的需求,因此本文仅简单介绍该如何识别CDN,以及一些常见的CDN绕过方式。免责声明:本文章仅用于交流学习,因文章内容而产生…

20结构伪类-borderz制图-网络字体-字体图标

一、结构伪类-:nth-child 在一些特殊的场景使用结构伪类还是非常方便的。 是真正有用的东西。 之前使用最主要的东西是nth-child() :nth-child(1)这个是选择父元素中的第一个子元素如果是下图这样就不能选中了。这里需要使用另外一个东西,叫做:nth-of-type()用这个东西可以选择…