Python深度学习基于Tensorflow(3)Tensorflow 构建模型

文章目录

        • 数据导入和数据可视化
        • 数据集制作以及预处理
        • 模型结构
        • 低阶 API 构建模型
        • 中阶 API 构建模型
        • 高阶 API 构建模型
        • 保存和导入模型

这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。

这里以CIFAR-10为数据集,CIFAR-10为小型数据集,一共包含10个类别的 RGB 彩色图像:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图像的尺寸为 32×32(像素),3个通道 ,数据集中一共有 50000 张训练圄片和 10000 张测试图像。CIFAR-10数据集有3个版本,这里使用Python版本。

数据导入和数据可视化

这里不用书中给的CIFAR-10数据,直接使用TensorFlow自带的玩意导入数据,可能需要魔法,其实TensorFlow中的数据特别的经典。

![[Pasted image 20240506194103.png]]

接下来导入cifar10数据集并进行可视化展示

import matplotlib.pyplot as plt
import tensorflow as tf(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape
# ((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))index_name = {0:'airplane',1:'automobile',2:'bird',3:'cat',4:'deer',5:'dog',6:'frog',7:'horse',8:'ship',9:'truck'
}def plot_100_img(imgs, labels):fig = plt.figure(figsize=(20,20))for i in range(10):for j in range(10):plt.subplot(10,10,i*10+j+1)plt.imshow(imgs[i*10+j])plt.title(index_name[labels[i*10+j][0]])plt.axis('off')plt.show()plot_100_img(x_test[:100])

![[Pasted image 20240506200312.png]]

数据集制作以及预处理

数据集预处理很简单就能实现,直接一行代码。

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))# 提取出一行数据
# train_data.take(1).get_single_element()

这里接着对数据预处理操作,也很容易就能实现。

def process_data(img, label):img = tf.cast(img, tf.float32) / 255.0return img, labeltrain_data = train_data.map(process_data)# 提取出一行数据
# train_data.take(1).get_single_element()

这里对数据还有一些存储和提取操作

dataset 中 shuffle()、repeat()、batch()、prefetch()等函数的主要功能如下。
1)repeat(count=None) 表示重复此数据集 count 次,实际上,我们看到 repeat 往往是接在 shuffle 后面的。为何要这么做,而不是反过来,先 repeat 再 shuffle 呢? 如果shuffle 在 repeat 之后,epoch 与 epoch 之间的边界就会模糊,出现未遍历完数据,已经计算过的数据又出现的情况。
2)shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 表示将数据打乱,数值越大,混乱程度越大。为了完全打乱,buffer_size 应等于数据集的数量。
3)batch(batch_size, drop_remainder=False) 表示按照顺序取出 batch_size 大小数据,最后一次输出可能小于batch ,如果程序指定了每次必须输入进批次的大小,那么应将drop_remainder 设置为 True 以防止产生较小的批次,默认为 False。
4)prefetch(buffer_size) 表示使用一个后台线程以及一个buffer来缓存batch,提前为模型的执行程序准备好数据。一般来说,buffer的大小应该至少和每一步训练消耗的batch数量一致,也就是 GPU/TPU 的数量。我们也可以使用AUTOTUNE来设置。创建一个Dataset便可从该数据集中预提取元素,注意:examples.prefetch(2) 表示将预取2个元素(2个示例),而examples.batch(20).prefetch(2) 表示将预取2个元素(2个批次,每个批次有20个示例),buffer_size 表示预提取时将缓冲的最大元素数返回 Dataset。

![[Pasted image 20240506201344.png]]

最后我们对数据进行一些缓存操作

learning_rate = 0.0002
batch_size = 64
training_steps = 40000
display_step = 1000AUTOTUNE = tf.data.experimental.AUTOTUNE
train_data = train_data.map(process_data).shuffle(5000).repeat(training_steps).batch(batch_size).prefetch(buffer_size=AUTOTUNE)

目前数据准备完毕!

模型结构

模型的结构如下,现在使用低阶,中阶,高阶 API 来构建这一个模型

![[Pasted image 20240506202450.png]]

低阶 API 构建模型
import matplotlib.pyplot as plt
import tensorflow as tf## 定义模型
class CustomModel(tf.Module):def __init__(self, name=None):super(CustomModel, self).__init__(name=name)self.w1 = tf.Variable(tf.initializers.RandomNormal()([32*32*3, 256]))self.b1 = tf.Variable(tf.initializers.RandomNormal()([256]))self.w2 = tf.Variable(tf.initializers.RandomNormal()([256, 128]))self.b2 = tf.Variable(tf.initializers.RandomNormal()([128]))self.w3 = tf.Variable(tf.initializers.RandomNormal()([128, 64]))self.b3 = tf.Variable(tf.initializers.RandomNormal()([64]))self.w4 = tf.Variable(tf.initializers.RandomNormal()([64, 10]))self.b4 = tf.Variable(tf.initializers.RandomNormal()([10]))def __call__(self, x):x = tf.cast(x, tf.float32)x = tf.reshape(x, [x.shape[0], -1])x = tf.nn.relu(x @ self.w1 + self.b1)x = tf.nn.relu(x @ self.w2 + self.b2)x = tf.nn.relu(x @ self.w3 + self.b3)x = tf.nn.softmax(x @ self.w4 + self.b4)return x
model = CustomModel()## 定义损失
def compute_loss(y, y_pred):y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)return tf.reduce_mean(loss)## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义准确率
def compute_accuracy(y, y_pred):correct_pred = tf.equal(tf.argmax(y_pred, axis=1), tf.cast(tf.reshape(y, -1), tf.int64))correct_pred = tf.cast(correct_pred, tf.float32)return tf.reduce_mean(correct_pred)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)accuracy = compute_accuracy(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))return loss.numpy(), accuracy.numpy()## 开始训练loss_list, acc_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):loss, acc = train_one_epoch(batch_x, batch_y)loss_list.append(loss)acc_list.append(acc)if i % 10 == 0:print(f'第{i}次训练->', 'loss:' ,loss, 'acc:', acc)
中阶 API 构建模型
## 定义模型
class CustomModel(tf.Module):def __init__(self):super(CustomModel, self).__init__()self.flatten = tf.keras.layers.Flatten()self.dense_1 = tf.keras.layers.Dense(256, activation='relu')self.dense_2 = tf.keras.layers.Dense(128, activation='relu')self.dense_3 = tf.keras.layers.Dense(64, activation='relu')self.dense_4 = tf.keras.layers.Dense(10, activation='softmax')def __call__(self, x):x = self.flatten(x)x = self.dense_1(x)x = self.dense_2(x)x = self.dense_3(x)x = self.dense_4(x)return xmodel = CustomModel()## 定义损失以及准确率
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss)train_accuracy(y, y_pred)## 开始训练
loss_list, accuracy_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):train_one_epoch(batch_x, batch_y)loss_list.append(train_loss.result())accuracy_list.append(train_accuracy.result())if i % 10 == 0:print(f"第{i}次训练: loss: {train_loss.result()} accuarcy: {train_accuracy.result()}")
高阶 API 构建模型
## 定义模型
model = tf.keras.Sequential([tf.keras.layers.Input(shape=[32,32,3]),tf.keras.layers.Flatten(),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax'),
])## 定义optimizer,loss, accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),loss = tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy']
)## 开始训练
model.fit(train_data.take(10000))
保存和导入模型

保存模型

tf.keras.models.save_model(model, 'model_folder')

导入模型

model = tf.keras.models.load_model('model_folder')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/671176.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

企业网站 | 被攻击时该怎么办?

前言 每天,数以千计的网站被黑客入侵。发生这种情况时,被入侵网站可用于从网络钓鱼页面到SEO垃圾邮件或者其它内容。如果您拥有一个小型网站,很容易相信黑客不会对它感兴趣。不幸的是,通常情况并非如此。 黑客入侵网站的动机与所…

ArcGIS中SHP转CAD如何分图层以及颜色等(保留属性信息)

很多小伙伴在使用ArcGIS时,想要将SHP图层转成CAD,但结果发现生成的CAD数据在打开时只保留了线条或者面块,其余的属性信息全部丢失,甚至无法做到分层,分颜色。在ArcGIS中想要实现SHP分图层以及颜色转CAD需要对CAD的字段…

数据分析之Tebleau可视化:树状图、日历图、气泡图

树状图(适合子分类比较多的) 1.基本树状图的绘制 同时选择产品子分类和销售金额----选择智能推荐----选择树状图 2.双层树状图的绘制 将第二个维度地区拖到产品分类的下面---大的划分区域是上面的维度(产品分类),看着…

设计模式之传输对象模式

在编程江湖里,有一种模式,它如同数据的“特快专递”,穿梭于系统间,保证信息的快速准确送达,它就是——传输对象模式(Data Transfer Object, DTO)。这不仅仅是数据的搬运工,更是提升系…

小程序激励广告视频多次回调问题

1.问题 2. 激励视频使用及解决方案 官方文档 let videoAd null; // 在页面中定义激励视频广告 Page({/*** 页面的初始数据*/data: {},/*** 生命周期函数--监听页面加载*/onLoad(options) {let that this;// 创建激励视频广告实例if (wx.createRewardedVideoAd) {videoAd w…

打破 AI 算力天花板,Meta超大规模AI基础设施架构解读

Meta超大规模AI智算基础设施架构设计 摘要 双重 GPU 集群,每群配备 2.4 万个 H100 芯片,分别采用 RoCE 和 InfiniBand 网络连接。LLaMA3 就是在这两个集群上训练出来的;Meta AI 将部署庞大算力集群,拥有 35 万张 H100 GPU&#x…

(数据分析方法)长期趋势分析

目录 一、定义 二、目的 三、方法 1、移动平均法 (1)、简单移动平均法 (2)、加权移动平均法 (3)、指数平滑法 2、最小二乘法 3、线性回归 1、数据预处理 2、观察数据分布建立假设模型 3、定义损失函数 4、批量梯度下降 5、优化 4、LSTM 时序分析 5、特征工程 一…

分布式事务了解吗?你们是如何解决分布式事务问题的?(文末有福利)

目录 一、面试官心理分析 二、面试题剖析 1.XA 方案(两阶段提交方案) 2.TCC 方案 3.Saga方案 4.本地消息表 5.可靠消息最终一致性方案 6.最大努力通知方案 7.你们公司是如何处理分布式事务的? 福利放送: 一、面试官心理分析 只要聊…

cmake进阶:文件操作之读文件

一. 简介 cmake 提供了 file() 命令可对文件进行一系列操作,譬如读写文件、删除文件、文件重命名、拷贝文件、创建目录等等。 接下来 学习这个功能强大的 file() 命令。 前一篇文章学习了 CMakeLists.txt语法中写文件操作。文章如下: cmake进阶&…

[极客大挑战 2019]PHP

1.通过目录扫描找到它的备份文件,这里的备份文件是它的源码。 2.源码当中涉及到的关键点就是魔术函数以及序列化与反序列化。 我们提交的select参数会被进行反序列化,我们要构造符合输出flag条件的序列化数据。 但是,这里要注意的就是我们提…

【数据结构与算法】之五道链表进阶面试题详解!

目录 1、链表的回文结构 2、相交链表 3、随机链表的复制 4、环形链表 5、环形链表(||) 6、完结散花 个人主页:秋风起,再归来~ 数据结构与算法 个人格言:悟已往之不谏,知…

本地运行AI大模型简单示例

一、引言 大模型LLM英文全称是Large Language Model,是指包含超大规模参数(通常在十亿个以上)的神经网络模型。2022年11月底,人工智能对话聊天机器人ChatGPT一经推出,人们利用ChatGPT这样的大模型帮助解决很多事情&am…