Tensorflow2.0笔记 - 自定义Layer和Model实现CIFAR10数据集的训练

       本笔记记录使用自定义Layer和Model来做CIFAR10数据集的训练。

        CIFAR10数据集下载:

        https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

        自定义的Layer和Model实现较为简单,参数量较少,并且没有卷积层和dropout等,最终准确率不高,仅做练习使用。

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255y = tf.cast(y, dtype=tf.int32)return x,ybatchsize = 128
#CIFAR10数据集下载,可以直接使用网络下载
(x,y), (x_val, y_val) = datasets.cifar10.load_data()
#CIFAR10的标签(训练集)数据维度是[50000, 1],通过squeeze消除掉里面1的维度,变成[50000]
print("y.shape:", y.shape)
y = tf.squeeze(y)
print("squeezed y.shape:", y.shape)
y_val = tf.squeeze(y_val)
#进行onehot编码
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print("Datasets: ", x.shape, " ", y.shape, " x.min():", x.min(), " x.max():", x.max())train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsize)sample = next(iter(train_db))
print("Batch:", sample[0].shape, sample[1].shape)#自定义Layer
class MyDense(layers.Layer):def __init__(self, input_dim, output_dim):super(MyDense, self).__init__()self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))#self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])#self.bias = self.add_weight(name='b', shape=[output_dim])def call(self, inputs, training = None):x = inputs@self.kernel + self.biasreturn xclass MyNetwork(keras.Model):def __init__(self):super(MyNetwork, self).__init__()self.fc1 = MyDense(32 * 32 * 3, 512)self.fc2 = MyDense(512, 512)self.fc3 = MyDense(512, 256)self.fc4 = MyDense(256, 256)self.fc5 = MyDense(256, 10)def call(self, inputs, training = None):x = tf.reshape(inputs, [-1, 32 * 32 * 3])x = self.fc1(x)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)x = tf.nn.relu(x)#返回logitsreturn xtotal_epoches = 35
learn_rate = 0.001
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),loss = tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['Accuracy'])
network.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=1)

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/597869.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UNITY实战进阶-BatchRendererGroup+Jobs+Burst+RVO2+GPUAnimation 实现万人团战(一)

研究思路:GPUAnimation把动画放入GPU中处理,BatchRendererGroup进行动态批量渲染处理,JobsBurst进行多线程处理逻辑(移动、攻击等),RVO2采用Jobs的寻路导航。 准备工作: Editor > Project S…

论文阅读——Sat2Vid

Sat2Vid: Street-view Panoramic Video Synthesis from a Single Satellite Image 提出了一种新颖的方法,用于从单个卫星图像和摄像机轨迹合成时间和几何一致的街景全景视频。 即根据单个卫星图像和给定的观看位置尽可能真实地、尽可能一致地合成街景全景视频序列。…

【Easy云盘 | 第二篇】后端统一设计思想

文章目录 4.1后端统一设计思想4.1.1后端统一返回格式对象4.1.2后端统一响应状态码4.1.3后端统一异常处理类4.1.4StringUtils类4.1.5 RedisUtils类 4.1后端统一设计思想 4.1.1后端统一返回格式对象 com.easypan.entity.vo.ResponseVO Data public class ResponseVO<T> …

注意!今明两天广东等地仍有较强降雨

中央气象台监测显示 进入4月以来 我国江南、华南北部强降雨 接连而至 湖南、江西、浙江中南部 福建大部、广东中北部等地降雨量 较常年同期偏多1倍以上 上述地区部分国家观测站 日雨量突破4月历史极值 截至4月7日早晨 广东广州、惠州、清远 韶关、河源等地部分地区 …

JavaScript - 你能说出解决跨域的一些方案吗

难度级别:中高级及以上 提问概率:65% 回答解决跨域之前,首先建议求职者描述什么是跨域。跨域问题是浏览器基于同源策略引起的,同源策略是浏览器的一种安全功能。同源要保证域名相同,或是被访问服务的协议+主机+端口都相同,那么反之这些属…

每日一题|字符迁移【算法赛】|字符数组+前缀和+差分

每日一题|字符迁移【算法赛】 字符迁移 心有猛虎&#xff0c;细嗅蔷薇。你好朋友&#xff0c;这里是锅巴的C\C学习笔记&#xff0c;常言道&#xff0c;不积跬步无以至千里&#xff0c;希望有朝一日我们积累的滴水可以击穿顽石。 字符迁移 注意&#xff1a; 预习知识&#xf…

实景三维在文化旅游领域的应用

实景三维技术&#xff0c;作为一种前沿的科技手段&#xff0c;近年来在文化旅游领域的应用逐渐崭露头角。它能够将真实世界的场景以三维的形式精确呈现&#xff0c;为游客带来身临其境的体验&#xff0c;为文化旅游注入新的活力。本文将探讨实景三维在文化旅游领域的应用及其所…

c语言实现2048小游戏

#include <stdio.h> #include <stdlib.h> #include <time.h> #include <conio.h>int best 0 ;// 定义2048游戏的结构体 typedef struct { int martix[16]; // 当前4*4矩阵的数字 int martixPrior[16]; // 上一步的4*4矩阵的数字 int emptyIndex[16…

持续交付工具Argo CD的部署使用

Background CI/CD&#xff08;Continuous Integration/Continuous Deployment&#xff09;是一种软件开发流程&#xff0c;旨在通过自动化和持续集成的方式提高软件交付的效率和质量。它包括持续集成&#xff08;CI&#xff09;和持续部署&#xff08;CD&#xff09;两个主要阶…

YOLOv8改进 | 细节涨点篇 | 利用YOLOv8自带的RayTune进行超参数调优

一、本文介绍 本文给大家带来的改进机制是利用Ray Tune进行超参数调优,在YOLOv8的项目中目前已经自带了该超参数调优的代码,我们无需进行任何的改动,只需要调用该方法输入我们的一些指令即可,当然了,这些超参数的设置还是比较又学问的,本文的内容也是应群友的需求进行发…

(学习日记)2024.04.05:UCOSIII第三十三节:互斥量

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

吴恩达2022机器学习专项课程(一) 第二周课程实验:多元线性回归(Lab_02)

1.训练集 使用Numpy数组存储数据集。 2.打印数组 打印两个数组的形状和数据。 3.初始化w&#xff0c;b 为了演示&#xff0c;w&#xff0c;b预设出接近最优解的值。w是一个一维数组&#xff0c;w个数对应特征个数。 4.非向量化计算多元线性回归函数 使用for循环&…