Transformer中的FeedForward

Transformer中的FeedForward

flyfish

class PoswiseFeedForwardNet(nn.Module):def __init__(self, d_ff=2048):super(PoswiseFeedForwardNet, self).__init__()# 定义一维卷积层 1,用于将输入映射到更高维度self.conv1 = nn.Conv1d(in_channels=d_embedding, out_channels=d_ff, kernel_size=1)# 定义一维卷积层 2,用于将输入映射回原始维度self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_embedding, kernel_size=1)# 定义层归一化self.layer_norm = nn.LayerNorm(d_embedding)def forward(self, inputs): #------------------------- 维度信息 -------------------------------- # inputs [batch_size, len_q, embedding_dim]#----------------------------------------------------------------                       residual = inputs  # 保留残差连接 # 在卷积层 1 后使用 ReLU 激活函数 output = nn.ReLU()(self.conv1(inputs.transpose(1, 2))) #------------------------- 维度信息 -------------------------------- # output [batch_size, d_ff, len_q]#----------------------------------------------------------------# 使用卷积层 2 进行降维 output = self.conv2(output).transpose(1, 2) #------------------------- 维度信息 -------------------------------- # output [batch_size, len_q, embedding_dim]#----------------------------------------------------------------# 与输入进行残差链接,并进行层归一化output = self.layer_norm(output + residual) #------------------------- 维度信息 -------------------------------- # output [batch_size, len_q, embedding_dim]#----------------------------------------------------------------return output # 返回加入残差连接后层归一化的结果

PoswiseFeedForwardNe继承自PyTorch的nn.Module类。该网络包含两个一维卷积层和一个层归一化操作。在前向传播过程中,该网络首先对输入进行卷积操作,然后通过ReLU激活函数进行非线性变换,接着进行降维操作,并与原始输入进行残差连接,最后通过层归一化得到输出。

具体来说,函数初始化时,根据参数d_ff设置两个一维卷积层的输出维度。在前向传播过程中,输入的维度为[batch_size, len_q, embedding_dim],首先将输入保留为残差连接的副本。然后,通过一维卷积层1将输入映射到更高维度,并在卷积层1后使用ReLU激活函数。接着,通过一维卷积层2将输出映射回原始维度。最后,将卷积层2的输出与原始输入进行残差连接,并通过层归一化操作得到最终输出。
在这里插入图片描述
上述部分在整体的位置
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/520108.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

List之ArrayList、LinkedList深入分析

集合 Java 集合, 也叫作容器,主要是由两大接口派生而来:一个是 Collection接口,主要用于存放单一元素;另一个是 Map 接口,主要用于存放键值对。对于Collection 接口,下面又有三个主要的子接口&…

Spring揭秘:ApplicationContextAware应用场景及实现原理!

内容概要 ApplicationContextAware接口能够轻松感知并在Spring中获取应用上下文,进而访问容器中的其他Bean和资源,这增强了组件间的解耦,了代码的灵活性和可扩展性,是Spring框架中实现高级功能的关键接口之一。 核心概念 它能用…

【深度学习笔记】优化算法——梯度下降

梯度下降 🏷sec_gd 尽管梯度下降(gradient descent)很少直接用于深度学习, 但了解它是理解下一节随机梯度下降算法的关键。 例如,由于学习率过大,优化问题可能会发散,这种现象早已在梯度下降中…

遥感与ChatGPT:科研中的强强联合

随着科技的飞速发展,人工智能(AI)已逐渐渗透到各个领域,为传统行业带来了前所未有的变革。其中,遥感技术作为观测和解析地球的重要手段,正逐渐与AI技术相结合,为地球科学研究与应用提供了全新的…

14:00面试,15:00就出来了,问的问题过于变态了。。。

从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到2月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%…

【操作系统概念】 第4章:线程

文章目录 0.前言4.1 概述4.1.1 多线程编程的优点 4.2 多线程模型4.2.1 多对一模型4.2.2 一对一模型4.2.3 多对多模型 4.3 线程库4.4 多线程问题4.4.1 系统调用fork()和exec()4.4.2 取消4.4.3 信号处理4.4.4 线程池4.4.5 线程特定数据 0.前言 第3章讨论的进程模型假设每个进程是…

Linux——进程控制(三)进程程序替换

目录 前言 一、进程程序替换 二、execl 三、多进程版execl 四、exec相关函数 1.execlp 2.execv 3.execvp 五、替换自己写的程序 六、替换其他语言程序 七、execle 前言 之前,我们学习了进程的fork创建,进程的等待,执行的代码都是…

解决:ModuleNotFoundError: No module named ‘paddle‘

错误显示: 原因: 环境中没有‘paddle’的python模块,但是您在尝试导入 解决方法: 1.普通方式安装: pip install paddlepaddle #安装命令 2.镜像源安装 pip install paddlepaddle -i https://pypi.tuna.tsinghua.e…

第18课:让客户看了就满意的商业软文是如何练成的?

选品上的注意事项 结合影视热点 通过追影视热点,找出能够跟产品贴合的点。在前面先道出痛点,痛点越深刻,用户对产品的过度才会更自然。 用户体验 真实体验才能真正写得出来。 结合时事热点 用的少,赶上了用就会效果很好&#xf…

C# Mel-Spectrogram 梅尔频谱

目录 介绍 Main features Philosophy of NWaves 效果 项目 代码 下载 C# Mel-Spectrogram 梅尔频谱 介绍 利用NWaves实现Mel-Spectrogram 梅尔频谱 NWaves github 地址:https://github.com/ar1st0crat/NWaves NWaves is a .NET DSP library with a lot …

Postman 接口自动化测试教程:入门介绍和从 0 到 1 搭建 Postman 接口自动化测试项目

关于Postman接口自动化测试的导引,全面介绍入门基础和从零开始搭建项目的步骤。学习如何有效地使用Postman进行API测试,了解项目搭建的基础结构、环境设置和测试用例的编写。无论您是新手还是经验丰富的测试人员,这篇教程都将为您提供清晰的指…

代码第二十四天-寻找旋转排序数组中的最小值Ⅱ

寻找旋转排序数组中的最小值Ⅱ 题目要求 解题思路 二分法 当遇到两个left、right两个位置值相同时候&#xff0c;可以选择将 right right-1 代码 class Solution:def findMin(self, nums: List[int]) -> int:left,right0,len(nums)-1while left<right:mid(leftright…