Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能

前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。

实现代码

import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 设置超参数
learning_rate = 0.001
num_epochs = 100
batch_size = 32# 定义输入和输出的维度
input_dim = X.shape[1]
output_dim = len(set(y))# 定义权重和偏置项
W1 = tf.Variable(tf.random.normal(shape=(input_dim, 64), dtype=tf.float64))
b1 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W2 = tf.Variable(tf.random.normal(shape=(64, 64), dtype=tf.float64))
b2 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W3 = tf.Variable(tf.random.normal(shape=(64, output_dim), dtype=tf.float64))
b3 = tf.Variable(tf.zeros(shape=(output_dim,), dtype=tf.float64))# 定义前向传播函数
def forward_pass(X):X = tf.cast(X, tf.float64)h1 = tf.nn.relu(tf.matmul(X, W1) + b1)h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)logits = tf.matmul(h2, W3) + b3return logits# 定义损失函数
def loss_fn(logits, labels):return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate)# 定义准确率指标
accuracy_metric = tf.metrics.SparseCategoricalAccuracy()# 定义训练步骤
def train_step(inputs, labels):with tf.GradientTape() as tape:logits = forward_pass(inputs)loss_value = loss_fn(logits, labels)gradients = tape.gradient(loss_value, [W1, b1, W2, b2, W3, b3])optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2, W3, b3]))accuracy_metric(labels, logits)return loss_value# 进行训练
for epoch in range(num_epochs):epoch_loss = 0.0accuracy_metric.reset_states()for batch_start in range(0, len(X_train), batch_size):batch_end = batch_start + batch_sizebatch_X = X_train[batch_start:batch_end]batch_y = y_train[batch_start:batch_end]loss = train_step(batch_X, batch_y)epoch_loss += losstrain_loss = epoch_loss / (len(X_train) // batch_size)train_accuracy = accuracy_metric.result()print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}")# 进行评估
logits = forward_pass(X_test)
test_loss = loss_fn(logits, y_test)
test_accuracy = accuracy_metric(y_test, logits)print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

实现效果

本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘有一定认知和理解,会结合自身科研实践经历不定期分享关于python、机器学习、深度学习基础知识与案例。

致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。

邀请三个朋友关注V订阅号:数据杂坛,即可在后台联系我获取相关数据集和源码,送有关数据分析、数据挖掘、机器学习、深度学习相关的电子书籍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/151970.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

哈希算法:如何防止数据库中的用户信息被脱库?

文章来源于极客时间前google工程师−王争专栏。 2011年CSDN“脱库”事件,CSDN网站被黑客攻击,超过600万用户的注册邮箱和密码明文被泄露,很多网友对CSDN明文保存用户密码行为产生了不满。如果你是CSDN的一名工程师,你会如何存储用…

debian 10 安装apache2 zabbix

nginx 可以略过,改为apache2 apt updateapt-get install nginx -ynginx -v nginx version: nginx/1.14.2mysql 安装参考linux debian10 安装mysql5.7_debian apt install mysql5.7-CSDN博客 Install and configure Zabbix for your platform a. Install Zabbix re…

SpringCore完整学习教程5,入门级别

本章从第6章开始 6. JSON Spring Boot提供了三个JSON映射库的集成: Gson Jackson JSON-B Jackson是首选的和默认的库。 6.1. Jackson 为Jackson提供了自动配置,Jackson是spring-boot-starter-json的一部分。当Jackson在类路径上时,将自动配置Obj…

uniapp 中添加 vconsole

uniapp 中添加 vconsole 一、安装 vconsole npm i vconsole二、使用 vconsole 在项目的 main.js 文件中添加如下内容 // #ifdef H5 // 提交前需要注释 本地调试使用 import * as vconsole from "vconsole"; new vconsole() // 使用 vconsole // #endif三、成功

[17]JAVAEE-HTTP协议

目录 一、什么是HTTP协议 什么时候会用到HTTP协议? HTTP协议的工作流程 二、HTTP的报文格式 抓包 HTTP请求报文格式 1.首行 2.header 常见键值对: 3.空行 4.正文(body)(有的时候可以没有) HTTP…

数据分析和互联网医院小程序:提高医疗决策的准确性和效率

互联网医院小程序已经在医疗领域取得了显著的进展,为患者和医疗从业者提供了更便捷和高效的医疗服务。随着数据分析技术的快速发展,互联网医院小程序能够利用大数据来提高医疗决策的准确性和效率。本文将探讨数据分析在互联网医院小程序中的应用&#xf…

【Kotlin精简】第6章 反射

1 反射简介 反射机制是在运行状态中,对于任意一个类,都能够知道这个类的所有属性和方法,对于任意一个对象,都能够调用它的任意一个方法和属性。 1.1 Kotlin反射 我们对比Kotlin和Java的反射类图。 1.1.1 Kotlin反射常用的数据结…

Egg.js使用MySql数据库

最近在接手一个项目,vuenuxtegg,我也是刚开始学习egg.js,所以会将自己踩的坑都记录下来。 安装mysql 使用sequelize连接数据库,首先安装egg-sequelize和mysql2。 npm install --save egg-sequelize mysql2打开package.json文件…

(a /b)*c的值

系列文章目录 进阶的卡莎C++_睡觉觉觉得的博客-CSDN博客数1的个数_睡觉觉觉得的博客-CSDN博客双精度浮点数的输入输出_睡觉觉觉得的博客-CSDN博客足球联赛积分_睡觉觉觉得的博客-CSDN博客大减价(一级)_睡觉觉觉得的博客-CSDN博客小写字母的判断_睡觉觉觉得的博客-CSDN博客纸币(…

JavaWeb——关于servlet种mapping地址映射的一些问题

6、Servlet 6.4、Mapping问题 一个Servlet可以指定一个映射路径 <servlet-mapping><servlet-name>hello</servlet-name><url-pattern>/hello</url-pattern> </servlet-mapping>一个Servlet可以指定多个映射路径 <servlet-mapping>&…

DevOps持续集成-Jenkins(2)

文章目录 DevOpsDevOps概述Integrate工具&#xff08;centos7-jenkins主机&#xff09;Integrate概述Jenkins介绍CI/CD介绍Linux下安装最新版本的Jenkins⭐Jenkins入门配置安装必备插件⭐安装插件&#xff08;方式一&#xff1a;可能有时会下载失败&#xff09;安装插件&#x…

SpringCloudGateway 入门

目录 POM 依赖一、内容网关的作用Spring-Cloud-Gateway的核心概念 二、基于Ribbon的负载均衡三、核心概念详细3.1 断言 Predicate3.2 过滤器3.2.1 内置过滤器3.2.2 自定义过滤器构造器&#xff08;原理&#xff09;资源结构Route / Predicate 的构造器构造器的增强器整体协同关…