3 tensorflow构建模型详解

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
神经网络是一种广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应。

2、密集层

在上一篇我们创建了预测价格模型,其中代码为:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

意思是构建了一个模型,里面添加了一层神经网络,只有一个神经元。(Sequential是顺序的意思,Dense是密集层。)

密集层(也叫全连接层),在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

3、西瓜费用预测模型详解

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=500)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

下面是代码详解:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)设置训练数据
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[1])
])model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')history = model.fit(weight, total_cost, epochs=5)# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]print('w:')
print(w)
print('b: ')
print(b)# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 

4、创建更复杂一点的模型

现实生活中我们要预测的东西影响因素可能有很多个,如房价预测,房价可能受到房屋面积、房间数量等等因素影响。思考一下,下面的神经网络图创建模型时要如何设置参数呢?

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=[3]),tf.keras.layers.Dense(1)
])


 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/154730.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

openEuler社区2023年度满意度调研

Hi,朋友们 一年一度的openEuler社区满意度调研来啦!我们诚邀您参与问卷调研,反馈您在社区的使用情况。我们会持续吸纳建议,为您创造更好的用户体验! 时间:2023年10月27日-2023年12月17日 。 链接&#x…

【Linux】第三站:Linux基本指令(二)

文章目录 一、通配符 *二、man指令三、cp指令1.先给一个文件里面写入数据2. cp指令拷贝普通文件3.cp指令拷贝文件目录4.常用的选项总结 四、mv指令1.mv命令简介2.使用 五、一些插曲1.一些注意事项2.指令的本质3.再谈输出重定向4.追加重定向5.输入重定向 六、cat指令七、more指令…

电脑录像功能在哪?一文帮你轻松破解

“电脑录像功能在哪里呀?最近因工作上的原因,需要使用电脑来录像,但是找了一上午都找不到在哪里,眼看已经快没时间了,现在真的很急,希望大家帮帮我。” 电脑已经成为了人们生活和工作中必不可少的工具&…

【ROS教程demo】用C++创建一个ROS节点,发布指令使得小海龟做圆周运动

ROS创建节点发布命令使得小海龟做圆周运动 1.任务需求2.任务分析2.1发布方topic和msg2.2接收方topic和msg2.3目标明确!3.创建ROS节点3.1创建发布方节点pub_pose3.2创建订阅方节点sub_pose1.任务需求 创建一个节点,在其中实现一个订阅者和一个发布者,完成以下功能: 发布者:…

Mac电脑配置Dart编程环境

1.安装Dart SDK 官网地址:https://dart.dev/get-dart $brew tap dart-lang/dart$brew install dart 安装后,用命令检测一下是否安装正常。 $brew info dart 2.VS Code配置Dart环境 1).安装VS Code 官网地址:https://code.visualstudio.c…

APP推荐:推荐一款免费无广告的本地音乐播放器,手机听歌必备

目录 一、软件简介 二、软件特色 三、软件使用 四、软件下载 相信很多朋友都喜欢听歌,今天给大家推荐一款非常棒的手机本地音乐APP——糖醋音乐,完全无广告、免费听歌,大家只需要把自己需要的歌曲下载到你的手机就可以愉快的听歌了&#…

Linux开发者的利器:深入了解环境开发工具之yum篇

W...Y的主页 😊 代码仓库分享💕 🍔前言:在博主的博客中,Linux系统我们已经将关键指令、权限等等全部了解完了。接下来我们应该学习什么呢?当我们拿起一个手机或电脑,我们最先想到的就是下载QQ、…

SurfaceFliger绘制流程

前景提要: 当HWComposer接收到Vsync信号时,唤醒DisSync线程,在其中唤醒EventThread线程,调用DisplayEventReceiver的sendObjects像BitTub发送消息,由于在SurfaceFlinger的init过程中创建了EventThread线程&#xff0c…

Google Play PolicyBytes 政策更新中文视频 | 2023 年 10 月

Google Play 持续帮助开发者开启成功出海之旅,为用户提供安全优质的应用。也感谢大家与我们携手合作,继续努力将 Google Play 打造为一个安全可信赖的平台。欢迎您观看 Google Play PolicyBytes 中文视频了解 2023 年 10 月政策更新内容,更及…

*VS Code中的Ajax

下载插件并使用 下载插件,开放一个端口给要加载的资源,解决跨域问题,没有后端接收数据,用来做小模块很合适 建立文件夹,文件夹下放入jquery插件和json文件 data.json {"total": 4,"data": [{"name&qu…

POI实现省市级联(二级下拉框)

POI实现省市级联(二级下拉框) POI级联下拉框 直接上代码测试结果参考文章 POI级联下拉框 业务上经常会用到POI做Excel的导出,有时导出需求比较复杂,这里记录一下自己参考网上大神的水月境的博文完成的一个导出Excel省市级连下拉…

Spark On Hive原理和配置

目录 一、Spark On Hive原理 (1)为什么要让Spark On Hive? 二、MySQL安装配置(root用户) (1)安装MySQL (2)启动MySQL设置开机启动 (3)修改MySQL…