深度学习框架:Pytorch与Keras的区别与使用方法

  

☁️主页 Nowl

🔥专栏《机器学习实战》 《机器学习》

📑君子坐而论道,少年起而行之 

文章目录

Pytorch与Keras介绍

Pytorch

模型定义

模型编译

模型训练

输入格式

完整代码

Keras

模型定义

模型编译

模型训练

输入格式

完整代码

区别与使用场景

结语


Pytorch与Keras介绍

pytorch和keras都是一种深度学习框架,使我们能很便捷地搭建各种神经网络,但它们在使用上有一些区别,也各自有其特性,我们一起来看看吧

Pytorch

模型定义

我们以最简单的网络定义来学习pytorch的基本使用方法,我们接下来要定义一个神经网络,包括一个输入层,一个隐藏层,一个输出层,这些层都是线性的,给隐藏层添加一个激活函数Relu,给输出层添加一个Sigmoid函数

import torch
import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.Sigmoid(x)return x

模型编译

我们在之前的机器学习文章中反复提到过,模型的训练是怎么进行的呢,要有一个损失函数与优化方法,我们接下来看看在pytorch中怎么定义这些

import torch.optim as optim# 实例化模型对象
model = SimpleNet()
# 定义损失函数
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

我们上面创建的神经网络是一个类,所以我们实例化一个对象model,然后定义损失函数为mse,优化器为随机梯度下降并设置学习率

模型训练

# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()

以上步骤是先创建了一些随机样本,作为模型的训练集,然后定义训练轮次为100次,然后前向传播数据集,计算损失,再优化,如此反复

输入格式

关于输入格式是很多人在实战中容易出现问题的,对于pytorch创建的神经网络,我们的输入内容是一个torch张量,怎么创建呢

data = torch.Tensor([[1], [2], [3]])

很简单对吧,上面这个例子创建了一个torch张量,有三组数据,每组数据有1个特征

我们可以把这个数据输入到训练好的模型中,得到输出结果,如果输出不是torch张量,代码就会报错

完整代码

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(1, 32)self.relu = nn.ReLU()self.fc2 = nn.Linear(32, 1)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.sigmoid(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()# 定义优化器
learning_rate = 0.01
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 创建随机输入数据和目标数据
input_data = torch.randn((100, 1))  # 100个样本,每个样本有1个特征
target_data = torch.randn((100, 1))  # 100个样本,每个样本有1个目标值# 训练模型
epochs = 100for epoch in range(epochs):# 前向传播output = model(input_data)# 计算损失loss = criterion(output, target_data)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()data = torch.Tensor([[1], [2], [3]])
prediction = model(data)print(prediction)

可以看到模型输出了三个预测值

注意,这个任务本身没有意义,因为我们的训练集是随机生成的,这里主要学习框架的使用方法

Keras

我们在这里把和上面相同的神经网络结构使用keras框架实现一遍

模型定义

from keras.models import Sequential
from keras.layers import Densemodel = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])

注意这里也是一层输入层,一层隐藏层,一层输出层,和pytorch一样,输入层是隐式的,我们的输入数据就是输入层,上述代码定义了一个隐藏层,输入维度是1,输出维度是32,还定义了一个输出层,输入维度是32,输出维度是1,和pytorch环节的模型结构是一样的 

模型编译

那么在Keras中模型又是怎么编译的呢

model.compile(loss='mse', optimizer='sgd')

非常简单,只需要这一行代码 ,设置损失函数为mse,优化器为随机梯度下降

模型训练

模型的训练也非常简单

# 训练模型
model.fit(input_data, target_data, epochs=100)

 因为我们已经编译好了损失函数和优化器,在fit里只需要输入数据,输出数据和训练轮次这些参数就可以训练了

输入格式

对于Keras模型的输入,我们要把它转化为numpy数组,不然会报错

data = np.array([[1], [2], [3]])

完整代码

from keras.models import Sequential
from keras.layers import Dense
import numpy as np# 定义模型
model = Sequential([Dense(32, input_dim=1, activation='relu'),Dense(1, activation='sigmoid')
])# 创建随机输入数据和目标数据
input_data = np.random.randn(100, 1)  # 100个样本,每个样本有10个特征
target_data = np.random.randn(100, 1)  # 100个样本,每个样本有5个目标值# 编译模型
model.compile(loss='mse', optimizer='sgd')
# 训练模型
model.fit(input_data, target_data, epochs=10)data = np.array([[1], [2], [3]])prediction = model(data)
print(prediction)

可以看到,同样的任务,Keras的代码量小很多

区别与使用场景

Keras代码量少,使用便捷,适用于快速实验和快速神经网络设计

而pytorch由于结构是由类定义的,可以更加灵活地组建神经网络层,这对于要求细节的任务更有利,同时,pytorch还采用动态计算图,使得模型的结构可以在运行时根据输入数据动态调整,但这个特点我还没有接触到,之后可能会详细讲解

结语

Keras和Pytorch都各有各的优点,请读者根据需求选择,同时有些深度学习教程偏向于使用某一种框架,最好都学习一点,以适应不同的场景

 

感谢阅读,觉得有用的话就订阅下本专栏吧 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/231582.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务链路追踪组件SkyWalking实战

概述 微服务调用存在的问题 串联调用链路,快速定位问题;理清服务之间的依赖关系;微服务接口性能分析;业务流程调用处理顺序; 全链路追踪:对请求源头到底层服务的调用链路中间的所有环节进行监控。 链路…

基于Java web的多功能游戏大厅系统的开发与实现

摘 要 目前,国内游戏市场上的网络游戏有许多种类,游戏在玩法上也越来越雷同,形式越来越单调。这种游戏性系统给玩家带来的成就感虽然是无穷的,但是也有随之而来的疲惫感,尤其是需要花费大量的时间和精力,这…

ESP32-Web-Server编程- WebSocket 编程

ESP32-Web-Server编程- WebSocket 编程 概述 在前述 ESP32-Web-Server 实战编程-通过网页控制设备的 GPIO 中,我们创建了一个基于 HTTP 协议的 ESP32 Web 服务器,每当浏览器向 Web 服务器发送请求,我们将 HTML/CSS 文件提供给浏览器。 使用…

从0开始学习JavaScript--JavaScript 中 `let` 和 `const` 的区别及最佳实践

在JavaScript中,let 和 const 是两个用于声明变量的关键字。尽管它们看起来很相似,但它们之间有一些重要的区别。本篇博客将深入探讨 let 和 const 的用法、区别,并提供一些最佳实践,以确保在代码中正确使用它们。 let 和 const …

U-boot(七):U-boot移植

本文主要探讨基于210官方U-boot源码移植。 移植基础 tar -jxvf android_uboot_smdkv210.tar.bz2cd u-boot-samsung-devrm -rf onenand_ipl onenand_bl1 lib_avr32 lib_blackfin lib_i386 lib_m68k lib_mips lib_microblaze lib_nios lib_nios2 lib_ppc lib_sh lib_sparccd bo…

11.兔子生崽问题【2023.11.26】

1.问题描述 有一对兔子&#xff0c;从出生后第3个月起每个月都生一对兔子&#xff0c;小兔子长到第三个月后每个月又生一对兔子&#xff0c;假如兔子都不死&#xff0c;问 第二十个月的兔子对数为多少对&#xff1f; 2.解决思路 3.代码实现 #include<stdio.h> int mai…

Linux:windows 和 Linux 之间文本格式转换

背景 在 Windows 上编辑的文件&#xff0c;放到 Linux 平台&#xff0c;有时会出现奇怪的问题&#xff0c;其中有一个是 ^M 引起的&#xff0c;例如这种错误&#xff1a; /bin/bash^M: bad interpreter 这个问题相信大家也碰到过&#xff0c;原因是 Windows 和 Linux 关于换行的…

Javase | Java常用类 (不断补充中...)

目录: 1.Object类2.String类3.StringBuffer类4.Math类5.Random类6.包装类(不断补充中...) 1.Object类 Object类是Java语言中的所有类的超类&#xff0c;即所有类的根。它中描述的所有方法&#xff0c;所有类都可以使用。 equals( ) : 指示其他某个对象与此对象“是否相等” (比…

PC端数据列表有头像显示头像,没有头像显示名字的第一个字

PC端数据列表有头像显示头像&#xff0c;没有头像显示名字的第一个字 .charAt(0) 是 JavaScript 字符串对象的方法&#xff0c;用于获取字符串的第一个字符。 字符串中的字符位置是从 0 开始的&#xff0c;所以.charAt(0) 就表示获取字符串的第一个字符。 <el-table ref&qu…

layui提示框没有渲染bug解决

bug&#xff1a;使用layui时或许是依赖导入又或是ideal和浏览器缓存问题导致前面明明正常的页面显示&#xff0c;后面出现提示框没有css样式&#xff0c;弹出框没有背景css 效果如下 解决后 解决方法 在你的代码中引入layer.js 我这是jsp页面 <script type"text/jav…

PHP微信UI在线聊天系统源码 客服私有即时通讯系统 附安装教程

DuckChat是一套完整的私有即时通讯解决方案&#xff0c;包含服务器端程序和各种客户端程序&#xff08;包括iOS、Android、PC等&#xff09;。通过DuckChat&#xff0c;站点管理员可以快速在自己的服务器上建立私有的即时通讯服务&#xff0c;用户可以使用客户端连接至此服务器…

Parasoft:正确的静态应用程序安全测试 (SAST) 解决方案

随着软件开发从Web应用扩展到工业物联网&#xff08;IIoT&#xff09;设备&#xff0c;静态应用安全测试&#xff08;SAST&#xff09;越来越有必要从根本上帮助确保软件的功能安全。根据 Forrester Research的研究&#xff0c;网络攻击是近两年安全漏洞的主要来源。因此&#…