DGL中NN模块的构造函数

在这里插入图片描述
上图引用自:dgl用户文档第三章(nn模块编写)

"""构造函数完成以下几个任务:
1、设置选项。
2、注册可学习的参数或者子模块。
3、初始化参数。"""
import torch.nn as nn
from dgl.utils import expand_as_pair
import dgl.nn
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
"""
在构造函数中,用户首先需要设置数据的维度。
对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 
对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。
对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型
包括 mean、 sum、 max 和 min。一些模块可能会使用更加复杂的聚合函数,
比如 lstm。
"""
"""注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linear、 nn.LSTM 等。"""class SAGE(nn.Module):def __init__(self, in_feats, out_feats, aggregator_type,bias=True, norm=None, activation=None):super(SAGE, self).__init__()# 获取源节点和目标节点的输入特征维度self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)# 输出特征维度self._out_feats = out_featsself._aggre_type = aggregator_typeself.norm = normself.activation = activation# 聚合类型:mean、pool、lstm、gcnif aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))if aggregator_type == 'pool':self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)if aggregator_type == 'lstm':self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)if aggregator_type in ['mean', 'pool', 'lstm']:self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)self.reset_parameters()#  构造函数的最后调用了 reset_parameters() 进行权重初始化。def reset_parameters(self):"""重新初始化可学习的参数"""gain = nn.init.calculate_gain('relu')if self._aggre_type == 'pool':nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)if self._aggre_type == 'lstm':self.lstm.reset_parameters()if self._aggre_type != 'gcn':nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2def forward(self, graph, feat):    #SAGEConv示例中的 forward() 函数# 输入图对象的规范检测with graph.local_scope():# 指定图类型,然后根据图类型扩展输入特征feat_src, feat_dst = expand_as_pair(feat, graph)# 消息传递和聚合if self._aggre_type == 'mean':graph.srcdata['h'] = feat_srcgraph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))h_neigh = graph.dstdata['neigh']elif self._aggre_type == 'gcn':check_eq_shape(feat)graph.srcdata['h'] = feat_srcgraph.dstdata['h'] = feat_dstgraph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))# 除以入度degs = graph.in_degrees().to(feat_dst)h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)elif self._aggre_type == 'pool':graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))h_neigh = graph.dstdata['neigh']else:raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))# GraphSAGE中gcn聚合不需要fc_selfif self._aggre_type == 'gcn':rst = self.fc_neigh(h_neigh)else:rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)# 更新特征作为输出# 激活函数if self.activation is not None:rst = self.activation(rst)# 归一化if self.norm is not None:rst = self.norm(rst)return rst
"""
在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,
DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:
1、检测输入图对象是否符合规范。
2、消息传递和聚合。
3、聚合后,更新特征作为输出。forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 
比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入
度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是
,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward() 函数的输
出不会全为0。在这种情况下,无需进行此类检验。
DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图和子图块。聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息
传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的
消息传递代码 里所介绍的性能优化。聚合后,更新特征作为输出
forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/222731.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

优思学院|5S不只是清洁,但却离不开清洁!

很多说5S不止是清洁和搞卫生那么简单,相信有正规地学习过5S的人都应该深切了解。 不过,5S之中的确包括了清理、清洁的步骤,5S,也被称为“五常法则”或“五常法”,它包含了: 整理(SEIRI&#x…

数据结构(超详细讲解!!)第二十五节 线索二叉树

1.线索二叉树的定义和结构 问题的提出: 通过遍历二叉树可得到结点的一个线性序列,在线性序列中,很容易求得某个结点的直接前驱和后继。但是在二叉树上只能找到结点的左孩子、右孩子,结点的前驱和后继只有在遍历过程中才能得到…

Python基础:字符串(String)详解

1. 字符串定义 在Python中,字符串是一种数据类型,用于表示文本数据。字符串是由字符组成的序列,可以包含字母、数字、符号和空格等字符。在Python中,你可以使用单引号()或双引号("&#xf…

四、Lua循环

文章目录 一、while(循环条件)二、for(一)数值for(二)泛型for(三)repeat util 既然同为编程语言,那么控制逻辑里的循环就不能缺少,它可以帮助我们实现有规律的重复操作,而…

可移动框 弹窗 可拖拽的组件

电脑端: <template><divv-if"show"ref"infoBox"mousedown.stop"mouseDownHandler"class"info-box":style"styleObject"><slot></slot></div> </template> <script> export defa…

电脑如何定时关机?

电脑如何定时关机&#xff1f;我承认自己是个相当粗心的人&#xff0c;尤其是在急于离开时经常会忘记关闭电脑&#xff0c;结果就是电量耗尽&#xff0c;导致电脑自动关机。而且&#xff0c;在我使用电脑的时候&#xff0c;经常需要进行软件下载、更新等任务。如果我一直坐等任…

387. 字符串中的第一个唯一字符

这篇文章会收录到 :算法通关村第十二关-白银挑战字符串经典题目-CSDN博客 387. 字符串中的第一个唯一字符 描述 : 给定一个字符串 s &#xff0c;找到 它的第一个不重复的字符&#xff0c;并返回它的索引 。如果不存在&#xff0c;则返回 -1 。 题目 : 387. 字符串中的第一…

NFTScan | 11.20~11.26 NFT 市场热点汇总

欢迎来到由 NFT 基础设施 NFTScan 出品的 NFT 生态热点事件每周汇总。 周期&#xff1a;2023.11.20~ 2023.11.26 NFT Hot News 01/ OKX Ordinals 市场 API 完成升级 11 月 21 日&#xff0c;OKX Ordinals 市场 API 现已完成升级&#xff0c;新增支持按币种单价查询、排序&…

代码签名证书是如何保护软件?

随着互联网的普及和技术的发展&#xff0c;软件开发已经成为了一个非常重要的行业。然而&#xff0c;软件安全问题也日益凸显&#xff0c;恶意软件、病毒、木马等威胁着用户的数据安全和隐私。为了确保软件的安全和可靠性&#xff0c;开发者需要采取一系列措施来保护他们的产品…

西北大学计算机844考研-23年计网计算题详细解析

西北大学计算机844考研-23年计网计算题详细解析 1.计算无传输差错状态下停止—等待ARQ协议效率,电磁波传播速率为2*10^8m/s&#xff0c;链路长为2000m&#xff0c;帧长度为1000比特&#xff0c;计算传输速率10kbps及10Mbps时的协议效率&#xff08;即信道利用率&#xff09; …

什么是美颜sdk?视频直播美颜sdk技术深度剖析

美颜sdk可以通过实时处理图像&#xff0c;提升主播或用户在视频直播中的外观。通过美颜sdk接口调用可以轻松实现美颜效果。美颜sdk的核心目标是在保持图像真实性的同时&#xff0c;为用户创造出最理想的美化效果。 一、美颜sdk的技术实现 1.面部识别技术&#xff1a;美颜sdk…

tabs切换,组件库framework7

IOS和安卓兼容的背景下&#xff0c; 可以使用&#xff1a;framework7.io文档 效果展示&#xff1a; 代码&#xff1a; <!-- Top Tabs --> <div class"tabs tabs-top"><div class"tab tab1 active">...</div><div class"…