从零开始学习深度学习库-1:前馈网络

你好!欢迎来到这个系列的第一篇文章,我们将尝试用Python构建自己的深度学习库。在这篇文章中,我们将开始编写一个简单的前馈神经网络。我们将仅在这篇文章中处理前向传播,并在下一篇文章中处理网络的训练。这篇文章将介绍基本的前馈神经网络如何接收输入并从中产生输出。

首先,什么是神经网络?

神经网络是一种机器学习技术,它大致模仿了大脑的模型。和所有机器学习技术一样,它通过包含输入及其对应输出的数据集来学习。神经网络由层组成。每一层都通过权重和偏置与下一层相连。这些权重和偏置被网络用来计算它将给出的输出。在网络训练时,这些权重和偏置会被调整,以便网络根据其训练的数据产生最优输出。
在这里插入图片描述
这张图展示了一个三层的神经网络。连接节点的线条用于表示网络的权重和偏置。

它们是如何工作的?(数学原理)

每一层都有自己的权重和偏置。
权重和偏置最初开始时是一个随机值矩阵。
一个基本的前馈神经网络只包含线性层。
线性层使用以下公式产生它们的输出。

x @ w + b其中...
x 是输入到该层的数据
w 是该层的权重
b 是该层的偏置
(@ 表示矩阵乘法)

每层的输出作为下一层的输入。
注意
如果你不了解矩阵乘法是如何工作的,这个网站很好地解释了这一点。
我们现在只讨论到这里 - 下一篇文章我们将深入研究这些权重和偏置在训练中是如何被纠正的!

激活函数

神经网络的层由节点组成。
激活函数应用于层,以确定哪些节点应该“触发”/“激活”。这种“触发”在人类大脑中也可以观察到,因此在神经网络中引入了激活函数,因为它们大致基于大脑的模型。
激活函数还允许网络模拟非线性数据。没有激活函数,神经网络只会是一个线性回归模型,这意味着它无法模拟大多数现实世界的数据。
有多种激活函数,但以下是最常用的几种…

Sigmoid

Sigmoid函数将输入映射到0到1之间的值,如下图所示。
在这里插入图片描述
在这里插入图片描述
x是输入向量

Relu(修正线性单元)

Relu函数只允许输入向量的正值通过。负值被映射为0。
例如,

[[-5, 10]  [15, -10] --> relu --> [[0, 10][15, 0]]

在这里插入图片描述
在这里插入图片描述

Tanh

Tanh与Sigmoid类似,不同之处在于它将输入映射到-1到1之间的值。
在这里插入图片描述

Softmax

Softmax接收一个输入,并将其映射为概率分布(这意味着输出中的所有值之和为1)。
在这里插入图片描述
(z是输入向量,K是输入向量的长度)

编写代码

我们的矩阵操作需要numpy…

import numpy as np

首先,让我们写我们的线性层类

class Linear:def __init__(self, units):# units指定层中有多少节点self.units = unitsself.initialized = Falsedef __call__(self, x):# 如果层之前没有被调用过,初始化权重和偏置if not self.initialized:self.w = np.random.randn(self.input.shape[-1], self.units)self.b = np.random.randn(self.units)self.initialized = Truereturn self.input @ self.w + self.b

示例用法

x = np.array([[0, 1]])
layer = Linear(5)
print (layer(x))# => [[-2.63399933 -1.18289984  0.32129587  0.2903246  -0.2602642 ]]

现在,让我们按照之前给出的公式编写我们所有的激活函数类

class Sigmoid:def __call__(self, x):return 1 / (1 + np.exp(-x))class Relu:def __call__(self, x):return np.maximum(0, x)   class Softmax:def __call__(self, x):return np.exp(x) / np.sum(np.exp(x))   class Tanh:def __call__(self, x):return np.tanh(x)

现在让我们编写一个“model”类,它将作为我们所有层的容器/实际的神经网络类。

class Model:def __init__(self, layers):self.layers = layersdef __call__(self, x):output = xfor layer in self.layers:output = layer(x)return output

将所有这些类保存到layer.py(或您喜欢的任何名称)中。
现在我们可以使用我们迄今为止的小型库来构建一个简单的神经网络。

import layers
import numpy as np# 输入数组
x = np.array([[0, 1], [0, 0], [1, 1], [0, 1]])# 网络使用我们迄今为止设计的所有层
net = layers.Model([layers.Linear(32),layers.Sigmoid(),layers.Linear(16),layers.Softmax(),layers.Linear(8),layers.Tanh(),layers.Linear(4),layers.Relu(),
])print (net(x))
Output:
[[0.         3.87770361 0.17602662 0.        ][0.         3.85640582 0.22373699 0.        ][0.         3.77290517 0.2469388  0.        ][0.         3.87770361 0.17602662 0.        ]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/536600.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习网络安全:记一次某网站渗透测试过程

本文作者: 汇智知了堂信安教学老师——辉哥 一、信息收集 网站界面 网站信息收集 (1)中间件信息 (2)目录扫描 思路:由于是cms的站,针对这种情况,我们可以收集cms的默认目录结构来…

java内部类的作用与优缺点

一、前言 很久没看到java内部类了,今天在审查代码时候,发现了java内部类,主要是内部类还嵌套了内部类。于是记录一下 二、java内部类的作用与优缺点 Java内部类,也称为嵌套类,是定义在另一个类(外部类&am…

1.1计算机系统构成及硬件系统知识(上)

基础知识部分----chap01 主要议题: 数制转换:一般会涉及存取的计算;ip地址中变长子网掩码的计算题;(难度较大) 数的表示:二进制、十六进制; 计算机的组成:考察的较为深入…

【Java语言】遍历List元素时删除集合中的元素

目录 前言 实现方式 1.普通实现 1.1 使用【for循环】 方式 1.2 使用【迭代器】方式 2.jdk1.8新增功能实现 2.1 使用【lambda表达式】方式 2.2 使用【stream流】方式 注意事项 1. 使用【for循环】 方式 2. 不能使用增强for遍历修改元素 总结 前言 分享几种从List中移…

FreeRTOS操作系统学习——任务通知

任务通知介绍 所谓任务通知,也可以反过来通知任务。在以往使用队列,信号量,事件组等等方法时,我们并不知道对方是谁,而在使用任务通知时,可以明确指定通知哪个任务。使用任务通知时,任务结构体…

程序员的三重境界:码农,高级码农、程序员!

见字如面,我是军哥! 掐指一算,我在 IT 行业摸爬滚打 19 年了,见过的程序员至少大好几千,然后真正能称上程序员不到 10% ,绝大部分都是高级码农而已。 今天和你聊聊程序员的三个境界的差异,文章不…

未来已来:科技驱动的教育变革

我们的基础教育数百年来一成不变。学生们齐聚在一个物理空间,听老师现场授课。每节课时长和节奏几乎一致,严格按照课表进行。老师就像“讲台上的圣人”。这种模式千篇一律,并不适用于所有人。学生遇到不懂的问题,只能自己摸索或者…

【数据结构】详解时间复杂度和空间复杂度的计算

一、时间复杂度(执行的次数) 1.1时间复杂度的概念 1.2时间复杂度的表示方法 1.3算法复杂度的几种情况 1.4简单时间复杂度的计算 例一 例二 例三 1.5复杂时间复杂度的计算 例一:未优化冒泡排序时间复杂度 例二:经过优化…

JAVA初阶数据结构链表(2)双向链表( +专栏数据结构练习是完整版)

1.双向链表的结构(双向不带头不循环链表) 需要注意的一点就是,在jdk中的链表就是双向链表 一个节点有三个域 val(数值域) next(地址域) prev(前驱记录前一个节点的地址&#xff09…

Express学习(四)

使用Express写接口 创建基本的服务器 创建API路由模块 编写GET接口 编写POST接口 CORS跨域资源共享 什么是CORS CORS由一系列HTTP响应头组成,这些HTTP响应头决定浏览器是否阻止前端JS代码跨域获取资源。浏览器的同源安全策略默认会阻止网页“跨域”获取资源。但如…

yum安装mysql 数据库tab自动补全

centos7上面没有mysql,它的数据库名字叫做mariadb [rootlocalhost ~]#yum install mariadb-server -y [rootlocalhost ~]#systemctl start mariadb.service [rootlocalhost ~]#systemctl stop firewalld [rootlocalhost ~]#setenforce 0 [rootlocalhost ~]#ss -na…

Three.js点线几何空间图形代码

Three.js点线几何空间图形代码。效果如下 下载地址 Three.js点线几何空间图形代码