Tensorflow2.0笔记 - 不使用layer方式,简单的MNIST训练

        本笔记不使用layer相关API,搭建一个三层的神经网络来训练MNIST数据集。

        前向传播和梯度更新都使用最基础的tensorflow API来做。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import numpy as npdef load_mnist():path = r'./mnist.npz' #放置mnist.py的目录。注意斜杠f = np.load(path)x_train, y_train = f['x_train'], f['y_train']x_test, y_test = f['x_test'], f['y_test']f.close()return (x_train, y_train), (x_test, y_test)#加载mnist数据集
#X_train: [60000, 28, 28] 图片
#Y_train: [60000] 标签
#mnist数据集下载:https://blog.csdn.net/charles_neil/article/details/107851880
#                https://www.zhihu.com/question/56773355
(X_train,Y_train),(X_test,Y_test) = load_mnist()#转换为tensor
#图片数据值转换到0-1
x = tf.convert_to_tensor(X_train, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(Y_train, dtype=tf.int32)
print(x.shape,y.shape)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))#数据集切分为多个batch
train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter = iter(train_db)sample = next(train_iter)
print(sample[0].shape, sample[1].shape)#学习率
lr = 0.1
#用三个神经元,[b:784] => [b,256] => [b,128] => [b,10]
w1 = tf.Variable(tf.random.truncated_normal([784,256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256,128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128,10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))for epoch in range(10):print("[==================Epoch ", epoch, "========================]")for step, (x,y) in enumerate(train_db):x = tf.reshape(x, [-1, 28*28])#对标签进行onehot编码y_onehot = tf.one_hot(y, depth=10)with tf.GradientTape() as tape:#第一层,输入x [128,784]#x@w + b: [batch, 784] [784,256] + [256] => [batch,256]h1 = x@w1 + b1h1 = tf.nn.relu(h1)#第二层:[batch, 256] => [batch, 128]h2 = h1@w2 + b2h2 = tf.nn.relu(h2)#输出层:[batch,128] => [batch,10]out = h2@w3 + b3#计算损失#使用MSE: mean(sum(y - out)^2)loss = tf.reduce_mean(tf.square(y_onehot - out))#计算梯度grads = tape.gradient(loss, [w1,b1,w2,b2,w3,b3])#更新w和b: w = w - lr * w_gradw1.assign_sub(lr * grads[0])b1.assign_sub(lr * grads[1])w2.assign_sub(lr * grads[2])b2.assign_sub(lr * grads[3])w3.assign_sub(lr * grads[4])b3.assign_sub(lr * grads[5])if (step % 100 == 0):print("Batch:", step, "loss:", float(loss))

        运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/416349.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个简单的Web程序(详解创建一个Flask项目后自带的一个简单的Web程序)

程序代码截图如下: 1.应用初始化 在创建 Flask 程序时,通常需要先创建一个应用实例进行应用初始化。 from flask import Flask # 应用的初始化 app Flask(__name__) 上述代码中,使用 Flask 类创建了一个应用实例 app。 __name__ 参数用…

Harbor离线安装

下载安装包 $ wget https://github.com/goharbor/harbor/releases/download/v2.7.4/harbor-offline-installer-v2.7.4.tgz解压 $ tar xvf harbor-offline-installer-v2.7.4.tgz -C /usr/local修改配置 $ cd /usr/local/harbor $ cp harbor.yml.tmpl harbor.yml $ vim harbo…

Web自动化测试中的接口测试

1、背景 1.1 Web程序中的接口 1.1.1 典型的Web设计架构 web是实现了基于网络通信的浏览器客户端与远程服务器进行交互的应用,通常包括两部分:web服务器和web客户端。web客户端的应用有html,JavaScript,ajax,flash等&am…

驾驭车联网的力量:深入车联网网络架构

车联网,作为移动互联网之后的新风口,以网联思想重新定义汽车,将其从简单的出行工具演化为个人的第二空间。车联网涵盖智能座舱和自动驾驶两大方向,构建在网联基础上,犀思云多年深度赋能汽车行业,本文将从车…

鸿蒙HarmonyOS实战-ArkTS语言(基本语法)

🚀一、ArkTS语言基本语法 🔎1.简介 HarmonyOS的ArkTS语言是一种基于TypeScript开发的语言,它专为HarmonyOS系统开发而设计。ArkTS语言结合了JavaScript的灵活性和TypeScript的严谨性,使得开发者能够快速、高效地开发出高质量的Har…

Kali在Vmware无法连接到网络,配置网络及解决办法

一.问题描述: 打开 Kali,无法连接到网络,虚拟机配置正常的。 尝试 ping 百度,出错: ping baidu.com 提示: ping: baidu.com: Temporary failure in name resolution二.解决办法: 1.首先在vmwa…

操作系统-操作系统的运行机制(内核程序 应用程序 特权指令 非特权指令 内核态 用户态 变态)

文章目录 总览预备知识:程序是如何运行的?内核程序vs应用程序特权指令vs非特权指令内核态vs用户态用户态,内核态的切换小结 总览 预备知识:程序是如何运行的? 转换为机器码放入内存,然后按顺序执行 内核…

三极管这个功能比“放大”还常用?

同学们大家好,今天我们继续学习杨欣的《电子设计从零开始》,这本书从基本原理出发,知识点遍及无线电通讯、仪器设计、三极管电路、集成电路、传感器、数字电路基础、单片机及应用实例,可以说是全面系统地介绍了电子设计所需的知识…

策略模式在工作中的运用

前言 在不同的场景下,执行不同的业务逻辑,在日常工作中是很寻常的事情。比如,订阅系统。在收到阿里云的回调事件、与收到AWS的回调事件,无论是收到的参数,还是执行的逻辑都可能是不同的。为了避免,每次新增…

UG机械制图的基本常识

目前来说工程图就是传递产品信息的工具,所以图纸一定不能出错,因为所有的设计都要转化为生产的输入。 一张完整的工程图应由图框,图素,尺寸标注以及技术要求这四部分组成, 图框包括图纸幅面:A0,A1,A2,A3,…

静态路由高级特性(HCIA)

目录 一、静态路由高级特性 1、路由条目六要素 2、路由分类 3、静态路由配置命令 (1)静态路由中下一跳MA和P2P区别 4、静态路由加路由表条件 5、permanent特性 二、路由冗余和负载 1、控制层面control plane 2、数据层面data plane 路由操控精髓&#xf…

Nginx——强化基础配置

1、牢记Context Context是Nginx中每条指令都会附带的信息,用来说明指令在哪个指令块中使用,可以将Context 理解为配置环境。 每个指令都拥有自己的配置环境,如果把配置环境记错了,或者在设计时未考虑配置环境的作用,…