政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

目录

介绍

编写组件

训练模型


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

通过 Keras,您可以编写自定义层、模型、度量指标、损失和优化器,并在同一代码库中跨 TensorFlow、JAX 和 PyTorch 运行

老规矩,咱们还是先准备环境(参考我本专栏目录中的文章,其中有搭建环境的部分):

政安晨:【TensorFlow与Keras实战演绎机器学习】专栏 —— 目录icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136985399

准备好环境后,咱们开始。

编写组件

让我们先来看看自定义层

{keras.ops 命名空间包含}
1. NumPy API 的实现,例如 keras.ops.stack 或 keras.ops.matmul
2. 一组 NumPy 中没有的神经网络特定操作,如 keras.ops.conv 或 keras.ops.binary_crossentropy

让我们创建一个可与所有后端配合使用的自定义密集层

class MyDense(keras.layers.Layer):def __init__(self, units, activation=None, name=None):super().__init__(name=name)self.units = unitsself.activation = keras.activations.get(activation)def build(self, input_shape):input_dim = input_shape[-1]self.w = self.add_weight(shape=(input_dim, self.units),initializer=keras.initializers.GlorotNormal(),name="kernel",trainable=True,)self.b = self.add_weight(shape=(self.units,),initializer=keras.initializers.Zeros(),name="bias",trainable=True,)def call(self, inputs):# Use Keras ops to create backend-agnostic layers/metrics/etc.x = keras.ops.matmul(inputs, self.w) + self.breturn self.activation(x)

接下来,让我们制作一个依赖于keras.random命名空间的自定义Dropout层

class MyDropout(keras.layers.Layer):def __init__(self, rate, name=None):super().__init__(name=name)self.rate = rate# Use seed_generator for managing RNG state.# It is a state element and its seed variable is# tracked as part of `layer.variables`.self.seed_generator = keras.random.SeedGenerator(1337)def call(self, inputs):# Use `keras.random` for random ops.return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)

接下来,让我们编写一个自定义子类模型,使用我们的两个自定义层:

class MyModel(keras.Model):def __init__(self, num_classes):super().__init__()self.conv_base = keras.Sequential([keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),keras.layers.MaxPooling2D(pool_size=(2, 2)),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),keras.layers.GlobalAveragePooling2D(),])self.dp = MyDropout(0.5)self.dense = MyDense(num_classes, activation="softmax")def call(self, x):x = self.conv_base(x)x = self.dp(x)return self.dense(x)

让我们编译并适配它:

model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)model.fit(x_train,y_train,batch_size=batch_size,epochs=1,  # For speedvalidation_split=0.15,
)

现在咱们演绎如下

在本地的TensorFlow虚拟环境中,首先导入keras:

from tensorflow import keras

(可以在Jupyter Notebook中运行)

如果在演绎执行中出错,可能是Keras版本问题,使用如下命令升级keras

sudo pip install --upgrade keras

执行结果:

训练模型

在任意数据源上训练模型

所有的Keras模型都可以在各种数据来源上进行训练和评估,与您使用的后端无关。这包括:

NumPy数组 Pandas数据框 TensorFlow tf.data.Dataset对象 PyTorch DataLoader对象 Keras PyDataset对象 无论您使用TensorFlow、JAX还是PyTorch作为Keras后端,它们都可以工作。

让我们尝试使用PyTorch DataLoader:

import torch# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)
)# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(val_torch_dataset, batch_size=batch_size, shuffle=False
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)

现在让我们尝试使用tf.data来完成这个任务

import tensorflow as tftrain_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)
test_dataset = (tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
)model = MyModel(num_classes=10)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),optimizer=keras.optimizers.Adam(learning_rate=1e-3),metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc"),],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/578200.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

封装性练习

练习 1 : 创建程序:在其中定义两个类: Person 和 PersonTest 类。定义如下: 用 setAge() 设置人的合法年龄 (0~130) ,用 getAge() 返回人的年龄。在 PersonTest 类中实例化 Person 类的对象 b ,调用 set…

李雅普诺夫函数

李雅普诺夫函数是一种用于描述动力系统稳定性的数学工具。它在动力系统和控制理论中具有广泛的应用,尤其是在研究非线性系统的稳定性方面。 李雅普诺夫函数通常用于证明动力系统在一些条件下是稳定的。一个李雅普诺夫函数是一个实数值函数,通常表示为 V…

【Django学习笔记(二)】CSS语言介绍

CSS语言介绍 前言正文1、CSS 快速了解2、CSS 应用方式2.1 在标签上应用2.2 在head标签中写style标签2.3 写到文件中 3、问题探讨:用Flask框架开发不方便4、选择器4.1 ID选择器4.2 类选择器4.3 标签选择器4.4 属性选择器4.5 后代选择器4.6 注意事项 5、样式5.1 高度和…

使用mysql官网软件包安装mysql

确定你的操作系统,我的是Centos myqsl 所有安装包的地址:https://repo.mysql.com/yum/ 如果你是使用rpm安装你可以到对应的版本里面找到对应的包。 mysql 发行包的地址:http://repo.mysql.com/ 在这里你可以找到对应的发布包安装。 这里使用y…

【Qt】:坐标

坐标 一.常用快捷键二.使用帮助文档三.Qt坐标体系1.理论2.代码 一.常用快捷键 注释:ctrl / • 运⾏:ctrl R • 编译:ctrl B • 字体缩放:ctrl ⿏标滑轮 • 查找:ctrl F • 整⾏移动:ctrl shift ⬆/…

什么是防火墙,部署防火墙有什么好处?

与我们的房屋没有围墙或界限墙一样,没有防护措施的计算机和网络将容易受到黑客的入侵,这将使我们的网络处于巨大的风险之中。因此,就像围墙保护我们的房屋一样,虚拟墙也可以保护和安全我们的设备,使入侵者无法轻易进入…

3D汽车模型线上三维互动展示提供视觉盛宴

VR全景虚拟看车软件正在引领汽车展览行业迈向一个全新的时代,它不仅颠覆了传统展览的局限,还为参展者提供了前所未有的高效、便捷和互动体验。借助于尖端的vr虚拟现实技术、逼真的web3d开发、先进的云计算能力以及强大的大数据处理,这一在线展…

蓝桥备赛——堆队列

AC code import os import sys import heapq a [] b [] n,k map(int,input().split())for _ in range(n):x,y map(int,input().split())a.append(x)b.append(y) q []# 第一种情况:不打第n个怪兽# 将前n-1个第一次所需能量加入堆 for i in range(n-1):heapq.h…

使用anime.js实现列表滚动轮播

官网&#xff1a;https://animejs.com/ html <div id"slide1"><div class"weather-item" v-for"item in weatherList"><div><img src"../../images/hdft/position.png" alt"">{{item.body.cityInf…

DVWA-CSRF通关教程-完结

DVWA-CSRF通关教程-完结 文章目录 DVWA-CSRF通关教程-完结Low页面使用源码分析漏洞利用 Medium源码分析漏洞利用 High源码分析漏洞利用 impossible源码分析 Low 页面使用 当前页面上&#xff0c;是一个修改admin密码的页面&#xff0c;只需要输入新密码和重复新密码&#xff…

LeetCode:718最长重复子数组 C语言

718. 最长重复子数组 提示 给两个整数数组 nums1 和 nums2 &#xff0c;返回 两个数组中 公共的 、长度最长的子数组的长度 。 示例 1&#xff1a; 输入&#xff1a;nums1 [1,2,3,2,1], nums2 [3,2,1,4,7] 输出&#xff1a;3 解释&#xff1a;长度最长的公共子数组是 [3,…

protobuf学习笔记(二):结合grpc生成客户端和服务端

上一篇文章大概讲了如何将自定义的protobuf类型的message转换成相应的go文件&#xff0c;这次就结合grpc写一个比较认真的客户端和服务器端例子 一、项目结构 client存放rpc服务的客户端文件 server存放rpc服务的服务端文件 protobuf存放自定义的proto文件 grpc存放生成的g…