神经网络:卷积介绍及代码实现

在这里插入图片描述

传统卷积运算是将卷积核以滑动窗口的方式在输入图上滑动,当前窗口内对应元素相乘然后求和得到结果,一个窗口一个结果。相乘然后求和恰好也是向量内积的计算方式,所以可以将每个窗口内的元素拉成向量,通过向量内积进行运算,多个窗口的向量放在一起就成了矩阵,每个卷积核也拉成向量,多个卷积核的向量排在一起也成了矩阵,于是,卷积运算转化成了矩阵乘法运算。下图很好地演示了矩阵乘法的运算过程:

im2col

将卷积运算转化为矩阵乘法,从乘法和加法的运算次数上看,两者没什么差别,但是转化成矩阵后,运算时需要的数据被存在连续的内存上,这样访问速度大大提升(cache),同时,矩阵乘法有很多库提供了高效的实现方法,像BLAS、MKL等,转化成矩阵运算后可以通过这些库进行加速。

缺点呢?这是一种空间换时间的方法,消耗了更多的内存——转化的过程中数据被冗余存储。

代码实现

太久没写python代码,面试的时候居然想用c++来实现,其实肯定能实现,但是比起使用python复杂太多了,所以这里使用python中的numpy来实现。

一、滑动窗口版本实现(这个好理解)

import numpy as np# 为了简化运算,默认batch_size = 1
class my_conv(object):def __init__(self, input_data, weight_data, stride, padding = 'SAME'):self.input = np.asarray(input_data, np.float32)self.weights = np.asarray(weight_data, np.float32)self.stride = strideself.padding = paddingdef my_conv2d(self):"""self.input: c * h * w  # 输入的数据格式self.weights: c * h * w"""[c, h, w] = self.input.shape[kc, k, _] = self.weights.shape  # 这里默认卷积核的长宽相等assert c == kc  # 如果输入的channel与卷积核的channel不一致即报错output = []# 分通道卷积,最后再加起来for i in range(c):  f_map = self.input[i]kernel = self.weights[i]rs = self.compute_conv(f_map, kernel)if output == []:output = rselse:output += rsreturn output# padding和rs的宽高计算全部基于rs_h = (h - k + 2p)//s + 1def compute_conv(self, fm, kernel):[h, w] = fm.shape[k, _] = kernel.shapeif self.padding == 'SAME': # 知道rs_hw,求pad_hwrs_h = h // self.striders_w = w // self.stridepad_h = (self.stride * (rs_h - 1) + k - h) // 2pad_w = (self.stride * (rs_w - 1) + k - w) // 2elif self.padding == 'VALID': # 知道pad_hw,求rspad_h = 0pad_w = 0rs_h = (h - k) // self.stride + 1rs_w = (w - k) // self.stride + 1elif self.padding == 'FULL': # 知道pad_hw,求rs_hwpad_h = k - 1pad_w = k - 1rs_h = (h + 2 * pad_h - k) // self.stride + 1rs_w = (w + 2 * pad_w - k) // self.stride + 1padding_fm = np.zeros([h + 2 * pad_h, w + 2 * pad_w], np.float32)padding_fm[pad_h:pad_h+h, pad_w:pad_w+w] = fm  # 完成对fm的zeros paddingrs = np.zeros([rs_h, rs_w], np.float32)for i in range(rs_h):for j in range(rs_w):roi = padding_fm[i*self.stride:(i*self.stride + k), j*self.stride:(j*self.stride + k)]rs[i, j] = np.sum(roi * kernel) # np.asarray格式下的 * 是对应元素相乘return rsif __name__=='__main__':input_data = [[[1, 0, 1, 2, 1],[0, 2, 1, 0, 1],[1, 1, 0, 2, 0],[2, 2, 1, 1, 0],[2, 0, 1, 2, 0],],[[2, 0, 2, 1, 1],[0, 1, 0, 0, 2],[1, 0, 0, 2, 1],[1, 1, 2, 1, 0],[1, 0, 1, 1, 1],],]weight_data = [[[1, 0, 1],[-1, 1, 0],[0, -1, 0],],[[-1, 0, 1],[0, 0, 1],[1, 1, 1],]]conv = my_conv(input_data, weight_data, 1, 'SAME')print(conv.my_conv2d())

二、矩阵乘法版本实现

import numpy as np# 为了简化运算,默认batch_size = 1
class my_conv(object):def __init__(self, input_data, weight_data, stride, padding = 'SAME'):self.input = np.asarray(input_data, np.float32)self.weights = np.asarray(weight_data, np.float32)self.stride = strideself.padding = paddingdef my_conv2d(self):"""self.input: c * h * w  # 输入的数据格式self.weights: c * h * w"""[c, h, w] = self.input.shape[kc, k, _] = self.weights.shape  # 这里默认卷积核的长宽相等assert c == kc  # 如果输入的channel与卷积核的channel不一致即报错# rs_h与rs_w为最后输出的feature map的高与宽if self.padding == 'SAME':pad_h = (self.stride * (h - 1) + k - h) // 2pad_w = (self.stride * (w - 1) + k - w) // 2rs_h = hrs_w = welif self.padding == 'VALID':pad_h = 0pad_w = 0rs_h = (h - k) // self.stride + 1rs_w = (w - k) // self.stride + 1elif self.padding == 'FULL':pad_h = k - 1pad_w = k - 1rs_h = (h + 2 * pad_h - k) // self.stride + 1rs_w = (w + 2 * pad_w - k) // self.stride + 1# 对输入进行zeros padding,注意padding后依然是三维的pad_fm = np.zeros([c, h+2*pad_h, w+2*pad_w], np.float32)pad_fm[:, pad_h:pad_h+h, pad_w:pad_w+w] = self.input# 将输入和卷积核转化为矩阵相乘的规格mat_fm = np.zeros([rs_h*rs_w, kc*k*k], np.float32)mat_kernel = self.weightsmat_kernel.shape = (kc*k*k, 1) # 转化为列向量row = 0   for i in range(rs_h):for j in range(rs_w):roi = pad_fm[:, i*self.stride:(i*self.stride+k), j*self.stride:(j*self.stride+k)]mat_fm[row] = roi.flatten()  # 将roi扁平化,即变为行向量row += 1# 卷积的矩阵乘法实现rs = np.dot(mat_fm, mat_kernel).reshape(rs_h, rs_w) return rsif __name__=='__main__':input_data = [[[1, 0, 1, 2, 1],[0, 2, 1, 0, 1],[1, 1, 0, 2, 0],[2, 2, 1, 1, 0],[2, 0, 1, 2, 0],],[[2, 0, 2, 1, 1],[0, 1, 0, 0, 2],[1, 0, 0, 2, 1],[1, 1, 2, 1, 0],[1, 0, 1, 1, 1],],]weight_data = [[[1, 0, 1],[-1, 1, 0],[0, -1, 0],],[[-1, 0, 1],[0, 0, 1],[1, 1, 1],]]conv = my_conv(input_data, weight_data, 1, 'SAME')print(conv.my_conv2d())

参考资料

1、im2col:将卷积运算转为矩阵相乘
2、面试基础–深度学习 卷积及其代码实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/469986.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

波奇学Linux:文件系统

磁盘认识 磁盘被访问的基本单元是扇区-512字节。 磁盘可以看成多个同心圆,每个同心圆叫做磁道,多个扇区组成同心圆。 我们可以把磁盘看做由无数个扇区构成的存储介质。 要把数据存到磁盘,先定位扇区,用哪一个磁头,…

辗转相除法和同余原理

辗转相除法和同余原理 辗转相除法同余原理 辗转相除法 辗转相除法就是用来求出两个数的最大公约数的方法,那么这个方法怎么用呢?举个例子:有两个数,a12,b8,要求这两个数的最大公约数,首先让a%b得到4&#x…

【Linux笔记】进程间通信之管道

一、匿名管道 我们在之前学习进程的时候就知道了一个概念,就是进程间是互相独立的,所以就算是两个进程是父子关系,其中一个进程退出了也不会影响另一个进程。 也因为进程间是互相独立的,所以两个进程间就不能直接的传递信息或者…

使用一根网线,让Ubuntu和正点原子I.MX6ULL开发板互相ping通

1.硬件准备 准备一根网线即可 2. 让windows和I.MX6ULLping通 2.1 找根网线将I.MX6ULL和电脑连起来 2.2 让I.MX6ULL通电运行起来,我这里使用的是正点原子版本的内核、 2.3 进入电脑的网络连接后,按照如下步骤操作 2.4 将ip地址、子网掩码、默认网关…

【数据结构】二叉树的三种遍历

目录 一、数据结构 二、二叉树 三、如何遍历二叉树 一、数据结构 数据结构是计算机科学中用于组织和存储数据的方式。它定义了数据元素之间的关系以及对数据元素的操作。常见的数据结构包括数组、链表、栈、队列、树、图等。 数组是一种线性数据结构,它使用连续…

嵌入式系统的基础知识:了解嵌入式系统的构成和工作原理

(本文为简单介绍,个人观点仅供参考) 嵌入式系统是建立在微处理器基础上的计算机系统,用于对专门的功能进行控制、运算和接口。它结合了硬件和软件,可以提供实时的响应,广泛应用于工业控制、通信、医疗、交通等领域。 嵌入式系统的核心是微处理…

亿级推送,得物是怎么架构的?

说在前面 在40岁老架构师 尼恩的读者交流群(50)中,很多小伙伴拿到一线互联网企业如阿里、网易、有赞、希音、百度、滴滴的面试资格。 最近,尼恩指导一个小伙伴简历,需要织入亮点项目、黄金项目。 前段时间,指导小伙写了一个《高…

###51单片机学习(2)-----如何通过C语言运用延时函数设计LED流水灯

前言:感谢您的关注哦,我会持续更新编程相关知识,愿您在这里有所收获。如果有任何问题,欢迎沟通交流!期待与您在学习编程的道路上共同进步。 目录 一. 延时函数的生成 1.通过延时计算器得到延时函数 2.可赋值改变…

信息学奥赛一本通1314:【例3.6】过河卒(Noip2002)

1314:【例3.6】过河卒(Noip2002) 时间限制: 1000 ms 内存限制: 65536 KB 提交数: 40991 通过数: 17884 【题目描述】 棋盘上A点有一个过河卒,需要走到目标B点。卒行走的规则:可以向下、或者向右。同时在棋盘上的某一点有一个对方…

【Go语言】第一个Go程序

第一个 Go 程序 1 安装 Go Go语言官网:Download and install - The Go Programming Language,提供了安装包以及引导流程。 以 Windows 为例,进入windows安装包下载地址:All releases - The Go Programming Language&#xff0c…

TeamCity创建git项目Timed out 超时的一个解决办法

问题: 当自己: ping github.com从本地推送到远程仓库浏览器浏览www.github.com ——都没有问题 但是在teamcity创建工程的时候就超时: 或者多试几次,终于成功了,然后构建的时候半途超时报错。。。。。 一种解决办…

Mermaid绘制UML图教程

Mermaid 是一种轻量级的图形描述语言,用于绘制流程图、时序图、甘特图等各种图表。它采用简单的文本语法,使得用户能够快速绘制各种复杂图表,而无需深入学习图形绘制工具。 一、安装Mermaid Mermaid 可以在浏览器中直接使用,也可…