DGL在异构图上的GraphConv模块

回顾同构图GraphConv模块

首先回顾一下同构图中实现GraphConv的主要思路(以GraphSAGE为例):
在初始化模块首先是获取源节点和目标节点的输入维度,同时获取输出的特征维度。根据SAGE论文提出的三种聚合操作,需要获取所使用的聚合类型,方便后面使用Pytorch中的nn模块实现。最后是特征归一化操作。
其具体的代码段为:

获取相关输入特征

        # 获取源节点和目标节点的输入特征维度self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)# 输出特征维度self._out_feats = out_featsself._aggre_type = aggregator_typeself.norm = normself.activation = activation

根据聚合类型选择Pytorch对应的nn模块中的函数

        # 聚合类型:mean、pool、lstm、gcnif aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))if aggregator_type == 'pool':self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)if aggregator_type == 'lstm':self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)if aggregator_type in ['mean', 'pool', 'lstm']:self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)

权重初始化

构造函数的最后调用了 reset_parameters() 进行权重初始化。

def reset_parameters(self):"""重新初始化可学习的参数"""gain = nn.init.calculate_gain('relu')if self._aggre_type == 'pool':nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)if self._aggre_type == 'lstm':self.lstm.reset_parameters()if self._aggre_type != 'gcn':nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2

forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:

  1. 检测输入图对象是否符合规范。
  2. 消息传递和聚合
  3. 聚合后,更新特征作为输出。

检测输入图对象的规范性

# 输入图对象的规范检测
with graph.local_scope():# 指定图类型,然后根据图类型扩展输入特征feat_src, feat_dst = expand_as_pair(feat, graph)

对于expand_as_pair()函数,其实现的操作是如果输入的特征不是一对的话(源节点和目标节点),就根据图Graph将特征变成一对,但要求图必须是一个block,其对应的源码为:

def expand_as_pair(input_, g=None):"""Return a pair of same element if the input is not a pair.如果输入不是一对,则返回相同元素的一对。If the graph is a block, obtain the feature of destination nodes from the source nodes.如果图是块,则从源节点中获取目的节点的特征。Parameters----------input_ : Tensor, dict[str, Tensor], or their pairsThe input featuresg : DGLGraph or NoneThe graph.If None, skip checking if the graph is a block.Returns-------tuple[Tensor, Tensor] or tuple[dict[str, Tensor], dict[str, Tensor]]The features for input and output nodes输入和输出节点的特性"""if isinstance(input_, tuple):return input_elif g is not None and g.is_block:if isinstance(input_, Mapping):input_dst = {k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))for k, v in input_.items()}else:input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())return input_, input_dstelse:return input_, input_

消息传递和聚合

聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() APIDGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。

        # 消息传递和聚合if self._aggre_type == 'mean':graph.srcdata['h'] = feat_srcgraph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))h_neigh = graph.dstdata['neigh']elif self._aggre_type == 'gcn':check_eq_shape(feat)graph.srcdata['h'] = feat_srcgraph.dstdata['h'] = feat_dstgraph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))# 除以入度degs = graph.in_degrees().to(feat_dst)h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)elif self._aggre_type == 'pool':graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))h_neigh = graph.dstdata['neigh']else:raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

如果是gcn聚合方式的话还需要用到它自身的特征,但是SAGE不需要,它只需要聚合邻居的特征,这里通过一条判断语句加以区分:

        # GraphSAGE中gcn聚合不需要fc_selfif self._aggre_type == 'gcn':rst = self.fc_neigh(h_neigh)else:rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

更新特征

聚合后,更新特征作为输出——forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

        # 更新特征作为输出# 激活函数if self.activation is not None:rst = self.activation(rst)# 归一化if self.norm is not None:rst = self.norm(rst)return rst

异构图GraphConv模块

DGL提供了 HeteroGraphConv,用于定义异构图上GNN模块。 实现逻辑与消息传递级别的API multi_update_all() 相同,它包括:

  • 每个关系上的DGL NN模块。
  • 聚合来自不同关系上的结果。
    其对应的数学公式为:(r表示关系)

在这里插入图片描述

__ init __函数

异构图的卷积操作接受一个字典类型参数 mods。这个字典的键为关系名,值为作用在该关系上NN模块对象。参数 aggregate 则指定了如何聚合来自不同关系的结果。

class HeteroGraphConv(nn.Module):def __init__(self, mods, aggregate='sum'):super(HeteroGraphConv, self).__init__()self.mods = nn.ModuleDict(mods)if isinstance(aggregate, str):# 获取聚合函数的内部函数self.agg_fn = get_aggregate_fn(aggregate)else:self.agg_fn = aggregate

nn.ModuleDict() 用于保存字典中的子模块。Pytorch官方也给出了对应的示例:

class MyModule(nn.Module):def __init__(self):super().__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})self.activations = nn.ModuleDict([['lrelu', nn.LeakyReLU()],['prelu', nn.PReLU()]])def forward(self, x, choice, act):x = self.choices[choice](x)x = self.activations[act](x)return x

forward函数

对于前向传播函数,除了需要输入图和输入张量以外,它还需要2个额外的字典参数mod_argsmod_kwargs。这2个字典与 self.mods 具有相同的键,值则为对应NN模块自定义参数
forward() 函数的输出结果也是一个字典类型的对象。其键为 nty,其值为每个目标节点类型 nty 的输出张量的列表, 表示来自不同关系的计算结果HeteroGraphConv 会对这个列表进一步聚合,并将结果返回给用户。聚合操作主要是:

if g.is_block:src_inputs = inputsdst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:src_inputs = dst_inputs = inputsfor stype, etype, dtype in g.canonical_etypes:rel_graph = g[stype, etype, dtype]if rel_graph.num_edges() == 0:continueif stype not in src_inputs or dtype not in dst_inputs:continuedstdata = self.mods[etype](rel_graph,(src_inputs[stype], dst_inputs[dtype]),*mod_args.get(etype, ()),**mod_kwargs.get(etype, {}))outputs[dtype].append(dstdata)

输入 g 可以是异构图或来自异构图的子图区块。和普通的NN模块一样,forward() 函数需要分别处理不同的输入图类型

上述代码中的for循环为处理异构图计算的主要逻辑

  • 首先我们遍历图中所有的关系(通过调用 canonical_etypes)。
  • 通过关系名,我们可以使用g[ stype, etype, dtype ]的语法将只包含该关系的子图( rel_graph )抽取出来。
  • 对于二分图,输入特征将被组织为元组 (src_inputs[stype], dst_inputs[dtype])
  • 接着调用用户预先注册在该关系上的NN模块,并将结果保存在outputs字典中。

最后,HeteroGraphConv 会调用用户注册的 self.agg_fn 函数聚合来自多个关系的结果。

rsts = {}
for nty, alist in outputs.items():if len(alist) != 0:rsts[nty] = self.agg_fn(alist, nty)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/215776.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HCIA-RS基础:动态路由协议基础

摘要:本文介绍动态路由协议的基本概念,为后续动态路由协议原理课程提供基础和引入。主要讲解常见的动态路由协议、动态路由协议的分类,以及路由协议的功能和自治系统的概念。文章旨在优化标题吸引力,并通过详细的内容夯实读者对动…

人工智能教程(二):人工智能的历史以及再探矩阵

目录 前言 更多矩阵的知识 Pandas 矩阵的秩 前言 在上一章中,我们讨论了人工智能、机器学习、深度学习、数据科学等领域的关联和区别。我们还就整个系列将使用的编程语言、工具等做出了一些艰难的选择。最后,我们还介绍了一点矩阵的知识。在本文中&am…

蓝桥杯物联网竞赛_STM32L071_4_按键控制

原理图: 当按键S1按下PC14接GND,为低电平 CubMX配置: Keil配置: main函数: while (1){/* USER CODE END WHILE */OLED_ShowString(32, 0, "hello", 16);if(Function_KEY_S1Check() 1){ OLED_ShowString(16, 2, &quo…

shrio----(1)基础

文章目录 前言 一、Shrio1、什么是shiro2、为什么使用shrio 二、主要类2.1、Subject2.2、SecurityManager2.3、Realms 三、认证授权3.1、认证(Authentication)3.2、授权(authorization)四、入门示例参考文章 前言 简单入门介绍 一、Shrio http://shir…

动态规划 之 钢条切割

自顶向下递归实现(Recursive top-down implementation) 程序CUT-ROD对等式(14.2)进行了实现,伪代码如下: CUT-ROD(p, n)if n 0return 0q -∞for i 1 to nq max{q, p[i] CUT-ROD(p, n - i)}return q上面解决中重复对一个子结构问题重复求解了&#…

java中关键字 volatile 和 synchronized 有什么区别

java中 volatile 和 synchronized 有什么区别?

nodejs微信小程序+python+PHP-青云商场管理系统的设计与实现-安卓-计算机毕业设计

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

微机原理_3

一、单项选择题(本大题共15小题,每小题3分,共45分。在每小题给出的四个备选项中,选出一个正确的答案,请将选定的答案填涂在答题纸的相应位置上。) 在 8086 微机系统中,完成对指令译码操作功能的部件是()。 A. EU B. BIU C. SRAM D. DRAM 使计算机执行某…

企业海外分部,如何实现安全稳定的跨境网络互连?

如今,众多企业广泛采取数字化业务系统,如OA、ERP及CRM等,来提升其业务运营效率。同时,私有云与公有云混合架构也逐渐普及化。 具体来说,很多企业选择将研发系统部署在公司本地的私有云环境,以此确保数据安全…

nvm安装及使用

文章目录 一、[介绍](https://github.com/nvm-sh/nvm)1.1、卸载node1.1.1、从控制面板的程序卸载node1.1.2、删除node的安装目录1.1.3、查找.npmrc文件删除1.1.4、逐一删除下列文件1.1.5、删除node环境变量1.1.6、验证是否卸载成功 二、安装2.1、window系统2.2、mac系统2.2.1、…

LeetCode二叉树小题目

Q1将有序数组转换为二叉搜索树 题目大致意思就是从一个数组建立平衡的二叉搜索树。由于数组以及进行了升序处理,我们只要考虑好怎么做到平衡的。平衡意味着左右子树的高度差不能大于1。由此我们可以想着是否能用类似二分递归来解决。 如果left>right,直接返回nul…

web前端之引入svg图片、html引入点svg文件、等比缩放、解决裁剪问题、命名空间、object标签、阿里巴巴尺量图、embed标签、iframe标签

MENU 前言直接在页面编写svg使用img标签引入通过css引入使用object标签引入其他标签参考资料 前言 web应用开发使用svg图片的方式,有如下几种方式 1、直接在页面编写svg 2、使用img标签引入 3、通过css引入 4、使用object标签引入 直接在页面编写svg 在html页面直接…