深度学习框架Keras与Pytorch对比

对于许多科学家、工程师和开发人员来说,TensorFlow是他们的第一个深度学习框架。TensorFlow 1.0于2017年2月发布,可以说,它对用户不太友好。

在过去的几年里,两个主要的深度学习库KerasPytorch获得了大量关注,主要是因为它们的使用比较简单。

本文将介绍Keras与Pytorch的4个不同点以及为什么选择其中一个库的原因。

Keras

Keras本身并不是一个框架,而是一个位于其他深度学习框架之上的高级API。目前它支持TensorFlow、Theano和CNTK。

Keras的优点在于它的易用性。这是迄今为止最容易上手并快速运行的框架。定义神经网络是非常直观的,因为使用API可以将层定义为函数。

Pytorch

Pytorch是一个深度学习框架(类似于TensorFlow),由Facebook的人工智能研究小组开发。与Keras一样,它也抽象出了深层网络编程的许多混乱部分。

就高级和低级代码风格而言,Pytorch介于Keras和TensorFlow之间。比起Keras具有更大的灵活性和控制能力,但同时又不必进行任何复杂的声明式编程(declarative programming)。

深度学习的从业人员整天都在纠结应该使用哪个框架。一般来说,这取决于个人喜好。但是在选择Keras和Pytorch时,你应该记住它们的几个方面。

640?wx_fmt=jpeg

(1)定义模型的类与函数

为了定义深度学习模型,Keras提供了函数式API。使用函数式API,神经网络被定义为一系列顺序化的函数,一个接一个地被应用。例如,函数定义层1( function defining layer 1)的输出是函数定义层2的输入。

img_input = layers.Input(shape=input_shape)
x = layers.Conv2D(64, (3, 3), activation='relu')(img_input)    
x = layers.Conv2D(64, (3, 3), activation='relu')(x)    
x = layers.MaxPooling2D((2, 2), strides=(2, 2))(x)

在Pytorch中,你将网络设置为一个继承来自Torch库的torch.nn.Module的类。与Keras类似,Pytorch提供给你将层作为构建块的能力,但是由于它们在Python类中,所以它们在类的init_ ()方法中被引用,并由类的forward()方法执行。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3)self.conv2 = nn.Conv2d(64, 64, 3)self.pool = nn.MaxPool2d(2, 2)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool(F.relu(self.conv2(x)))return x
model = Net()

(2)张量和计算图模型与标准数组的比较

Keras API向普通程序员隐藏了许多混乱的细节。这使得定义网络层是直观的,并且默认的设置通常足以让你入门。

只有当你正在实现一个相当先进或“奇特”的模型时,你才真正需要深入了解底层,了解一些基本的TensorFlow。

棘手的部分是,当你真正深入到较低级别的TensorFlow代码时,所有的挑战就随之而来!你需要确保所有的矩阵乘法都对齐。不要试着想打印出你自己定义的层的输出,因为你只会得到一个打印在你的终端上的没有错误的张量定义。

Pytorch在这些方面更宽容一些。你需要知道每个层的输入和输出大小,但是这是一个比较容易的方面,你可以很快掌握它。你不需要构建一个抽象的计算图,避免了在实际调试时无法看到该抽象的计算图的细节。

Pytorch的另一个优点是平滑性,你可以在Torch张量和Numpy数组之间来回切换。如果你需要实现一些自定义的东西,那么在TF张量和Numpy数组之间来回切换可能会很麻烦,这要求开发人员对TensorFlow会话有一个较好的理解。

Pytorch的互操作实际上要简单得多。你只需要知道两种操作:一种是将Torch张量(一个可变对象)转换为Numpy,另一种是反向操作。

当然,如果你从来不需要实现任何奇特的东西,那么Keras就会做得很好,因为你不会遇到任何TensorFlow的障碍。但是如果你有这个需求,那么Pytorch将会是一个更加好的选择。

(3)训练模型

640?wx_fmt=jpeg

用Keras训练模特超级简单!只需一个简单的.fit(),你就可以直接去跑步了。

history = model.fit_generator(generator=train_generator,epochs=10,validation_data=validation_generator)

在Pytorch中训练模型包括以下几个步骤:

  1. 在每批训练开始时初始化梯度
  2. 前向传播
  3. 反向传播
  4. 计算损失并更新权重
# 在数据集上循环多次
for epoch in range(2):  for i, data in enumerate(trainloader, 0):# 获取输入; data是列表[inputs, labels]inputs, labels = data # (1) 初始化梯度optimizer.zero_grad() # (2) 前向传播outputs = net(inputs)loss = criterion(outputs, labels)# (3) 反向传播loss.backward()# (4) 计算损失并更新权重optimizer.step()

光是训练就需要很多步骤!

我想这种方式你就会知道实际上发生了什么。由于这些模型训练步骤对于训练不同的模型本质上保持不变,所以这些代码实际上完全不必要的。

(4)控制CPU与GPU模式的比较

640?wx_fmt=jpeg

如果你已经安装了tensorflow-gpu,那么在Keras中使用GPU是默认启用和完成的。如果希望将某些操作转移到CPU,可以使用以下代码。

with tf.device('/cpu:0'):y = apply_non_max_suppression(x)

对于Pytorch,你必须显式地为每个torch张量和numpy变量启用GPU。这将使代码变得混乱,如果你在CPU和GPU之间来回移动以执行不同的操作,则很容易出错。

例如,为了将我们之前的模型转移到GPU上运行,我们需要做以下工作:

#获取GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#传送网络到GPU
net.to(device)# 传送输入和标签到GPU
inputs, labels = data[0].to(device), data[1].to(device)JAVASCRIPT 复制 全屏

Keras在这方面的优势在于它的简单性和良好的默认设置

选择框架的一般建议

我通常给出的建议是从Keras开始。

Keras绝对是最容易使用、理解和快速上手并运行的框架。你不需要担心GPU设置,处理抽象代码,或者做任何复杂的事情。你甚至可以在不接触TensorFlow的任何一行的情况下实现定制层和损失函数。

如果你确实开始深入到深度网络的更细粒度方面,或者正在实现一些非标准的东西,那么Pytorch就是你的首选库。在Keras上实现反而会有一些额外的工作量,虽然不多,但这会拖慢你的进度。使用pytorch能够快速地实现、训练和测试你的网络,并附带易于调试的额外好处!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/308546.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pytorch简介

1.1 Pytorch的历史 PyTorch是一个由Facebook的人工智能研究团队开发的开源深度学习框架。在2016年发布后,PyTorch很快就因其易用性、灵活性和强大的功能而在科研社区中广受欢迎。下面我们将详细介绍PyTorch的发展历程。 在2016年,Facebook的AI研究团队…

广州找工作哪个网站好

吉鹿力招聘网是一个很好的广州找工作网站,它提供了多种类型的招聘信息,包括技工招聘。总之,吉鹿力招聘网是一个有效的招聘网站,可以帮助广州的人们找到合适的工作。 广州找工作上 吉鹿力招聘网 打开 吉鹿力招聘网 “注册账号”&…

服务器运行状况监控工具

服务器运行状况监视提供了每个服务器状态和性能的广泛概述,通过监控服务器指标,如 CPU 使用率、内存消耗、I/O、磁盘使用率、进程等,服务器运行状况监控可以避免服务器停机。 服务器性能监控指标 服务器是网络中最重要的组件之一&#xff0…

基于Spring Cloud + Spring Boot的企业电子招标采购系统源码

随着企业的快速发展,招采管理逐渐成为企业运营中的重要环节。为了满足公司对内部招采管理提升的要求,建立一个公平、公开、公正的采购环境至关重要。在这个背景下,我们开发了一款电子招标采购软件,以最大限度地控制采购成本&#…

【2023年终总结】纵是一路仆仆风尘,也莫忘了仰头

文章目录 1. 写在前面2. 关于生活3. 关于工作4. 关于以后 【作者主页】:吴秋霖 【作者介绍】:Python领域优质创作者、阿里云博客专家、华为云享专家。长期致力于Python与爬虫领域研究与开发工作! 【作者推荐】:对JS逆向感兴趣的朋…

nginx设置跨域访问

目录 一&#xff1a;前端请求 二&#xff1a;后端设置 网站架构前端使用jquery请求&#xff0c;后端使用nginxphp-fpm 一&#xff1a;前端请求 <script> $.getJSON(http://nngzh.youjoy.com/cc.php, { openid: sd, }, function(res) { alert(res); if(res.code 0) …

【windows】在host中设置禁止访问某个网站

1.找到&#xff1a;C:\Windows\System32\drivers\etc目录下的hosts文件。 2.把这个文件复制到桌面上&#xff08;提供更高权限&#xff09; 3.用记事本打开&#xff0c;在末尾添加上127.0.0.1 www.bilibili.com 4.保存后复制回原来的文件夹&#xff0c;替换掉原来的文件 5.…

Postgresql源码(119)PL/pgSQL中ExprContext的生命周期

前言 在PL/pgSQL语言中&#xff0c;执行任何SQL都需要通过SPI调用SQL层解析执行&#xff0c;例如在SQL层执行表达式的入口&#xff1a; static bool exec_eval_simple_expr(PLpgSQL_execstate *estate,PLpgSQL_expr *expr,Datum *result,bool *isNull,Oid *rettype,int32 *re…

独立站如何优化网页加载速度

对于跨境电商独立站而言&#xff0c;流量是跨境电商业务的重中之重&#xff0c;由于独立站并不自带流量&#xff0c;非常依赖于谷歌搜索引擎自然流量&#xff0c;以及付费广告流量。 但随着付费流量价格日益水涨船高&#xff0c;为了摆脱对付费流量的依赖&#xff0c;相信广大…

VSCode远程开发配置

目录 概要远程开发插件安装开始连接SSH无密码登录开发环境配置 概要 现在很多公司都是直接远程到服务器上写代码&#xff0c;使用远程开发&#xff0c;可以在与生产环境相同的环境中开发、测试和部署代码&#xff0c;减少因环境不同而导致的问题。当下VSCode远程开发是支持的比…

Android笔记(二十二):Paging3分页加载库结合Compose的实现网络单一数据源访问

Paging3 组件是谷歌公司推出的分页加载库。个人认为Paging3库是非常强大&#xff0c;但是学习难点比较大的一个库。Paging3组件可用于加载和显示来自本地存储或网络中更大的数据集中的数据页面。此方法可让移动应用更高效地利用网络带宽和系统资源。在具体实现上&#xff0c;Pa…

最新AI系统ChatGPT网站H5系统源码,支持AI绘画,GPT语音对话+ChatFile文档对话总结+DALL-E3文生图

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Ch…