【复杂网络建模】——使用PyTorch和DGL库实现图神经网络进行链路预测

🤵‍♂️ 个人主页:@Lingxw_w的个人主页

✍🏻作者简介:计算机科学与技术研究生在读
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+ 

目录

1、常见的链路预测方法

2、图神经网络上的链路预测

3、使用PyTorch和DGL库实现图神经网络进行链路预测


链路预测是指在一个给定的网络中,根据已有的网络结构信息,尝试预测两个节点之间是否存在连接或者可能会建立连接的概率。这在社交网络分析、生物信息学、推荐系统等领域中都有广泛的应用。

在复杂网络中,链路预测可以帮助我们理解网络的演化过程、发现隐藏的关系和未知的连接,以及预测未来的网络演化趋势。

1、常见的链路预测方法

  1. 基于相似性的方法:这类方法假设具有相似性的节点之间更有可能存在连接。常见的相似性度量方法包括共同邻居数、Jaccard系数、Adamic/Adar指数等。

  2. 基于路径的方法:这类方法考虑节点之间的路径信息,比如最短路径、随机游走路径等。通过分析节点之间的路径特征,可以预测节点间的连接概率。

  3. 基于机器学习的方法:这类方法使用机器学习算法来建模和预测网络中的链路。常见的机器学习算法包括决策树、随机森林、支持向量机(SVM)、神经网络等。

  4. 基于深度学习的方法:这是近年来兴起的一种方法,使用深度学习模型(如图神经网络)来学习节点的表征,并通过这些表征来进行链路预测。

链路预测并非一种绝对准确的预测方法,因为网络的演化和连接行为具有一定的随机性。 

2、图神经网络上的链路预测

图神经网络(Graph Neural Networks,简称GNN)可以用于链路预测任务。GNN是一类专门用于处理图结构数据的深度学习模型,能够学习节点和边的特征表示,并在此基础上进行预测任务。

步骤:

  1. 图表示构建:首先,将原始的网络数据表示为图结构,其中节点表示网络中的实体(如用户、物品),边表示节点之间的连接关系(如关注、交互)。

  2. 节点表征学习:GNN通过多轮的消息传递和聚合操作,从节点和边的特征中学习节点的表征。这样,每个节点都会得到一个向量表示,用于捕捉其在网络中的特征和上下文信息。

  3. 边预测模型构建:在节点表征学习的基础上,可以构建一个边预测模型来预测节点之间的连接概率。一种常见的方法是使用一个全连接层或多层感知机(MLP)来将节点表征映射到一个预测分数或概率。可以使用二元分类任务来预测节点间是否存在连接,或者使用回归任务来预测连接的强度或权重。

  4. 模型训练和评估:使用已知的网络结构数据进行模型的训练,并通过验证集或交叉验证进行模型的选择和调优。评估时,可以使用一些常见的指标,如准确率、精确度、召回率、F1分数等来评估链路预测的性能。

3、使用PyTorch和DGL库实现图神经网络进行链路预测

导入必要的库,包括PyTorch和DGL。

import torch
import torch.nn as nn
import dgl

定义图神经网络模型 GNNLinkPredict,模型包含两个图卷积层,输入特征维度为2,输出特征维度为1。

# 定义图神经网络模型
class GNNLinkPredict(nn.Module):def __init__(self, in_feats, hidden_size, out_feats):super(GNNLinkPredict, self).__init__()self.conv1 = dgl.nn.GraphConv(in_feats, hidden_size)self.conv2 = dgl.nn.GraphConv(hidden_size, out_feats)def forward(self, g, features):x = torch.relu(self.conv1(g, features))x = torch.relu(self.conv2(g, x))return x

创建示例图数据 g,其中包括5个节点和7条边。定义节点特征 features,每个节点有两个特征值。定义标签 labels,表示边的连接情况。

# 构建示例图数据
# 创建一个有向图
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 4, 3, 4])# 定义节点特征
features = torch.tensor([[0.2, 0.4],[0.3, 0.5],[0.4, 0.6],[0.5, 0.7],[0.6, 0.8]
])# 定义标签(边是否存在连接)
labels = torch.tensor([1, 1, 1, 0, 0, 1, 0], dtype=torch.float32)

划分训练集和测试集,使用布尔类型的掩码 train_masktest_mask 表示。

# 划分训练集和测试集
train_mask = torch.tensor([True, True, True, False, False])
test_mask = torch.tensor([False, False, False, True, True])

创建图神经网络模型实例 model

定义优化器损失函数,这里使用Adam优化器和二分类的交叉熵损失函数。

# 创建图神经网络模型
model = GNNLinkPredict(in_feats=2, hidden_size=16, out_feats=1)# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()

进行模型训练。循环迭代多个epoch,在每个epoch中执行以下步骤

  • 将模型设置为训练模式 model.train()
  • 前向传播计算预测结果 logits
  • 计算预测结果与标签之间的损失。
  • 清空优化器的梯度。
  • 反向传播计算梯度。
  • 更新模型参数。
# 训练模型
for epoch in range(50):model.train()logits = model(g, features)pred = logits.squeeze()loss = criterion(pred[train_mask], labels[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()# 打印训练损失print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")

在测试集上评估模型。将模型设置为评估模式 model.eval(),然后使用训练好的模型对测试集进行预测。通过将预测结果应用sigmoid函数将其映射到0-1之间,并使用四舍五入将其转换为0或1的预测标签。计算预测准确率并输出。

# 在测试集上评估模型
model.eval()
with torch.no_grad():logits = model(g, features)pred = logits.squeeze()pred = torch.sigmoid(pred)  # 使用sigmoid函数将预测值映射到0-1之间pred_labels = torch.round(pred)  # 四舍五入为0或1的预测标签accuracy = (pred_labels[test_mask] == labels[test_mask]).float().mean()print(f"Accuracy: {accuracy.item()}")

汇总的代码:

# https://www.dgl.ai/pages/start.htmlimport torch
import torch.nn as nn
import dgl# 定义图神经网络模型
class GNNLinkPredict(nn.Module):def __init__(self, in_feats, hidden_size, out_feats):super(GNNLinkPredict, self).__init__()self.conv1 = dgl.nn.GraphConv(in_feats, hidden_size)self.conv2 = dgl.nn.GraphConv(hidden_size, out_feats)def forward(self, g, features):x = torch.relu(self.conv1(g, features))x = torch.relu(self.conv2(g, x))return x# 构建示例图数据
# 创建一个有向图
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 4, 3, 4])# 添加自环
g = dgl.add_self_loop(g)# 定义节点特征
features = torch.tensor([[0.2, 0.4],[0.3, 0.5],[0.4, 0.6],[0.5, 0.7],[0.6, 0.8]
])# 定义标签(边是否存在连接)
labels = torch.tensor([1, 1, 1, 0, 0, 1, 0], dtype=torch.float32)# 划分训练集和测试集
train_mask = torch.tensor([True, True, True, False, False, False, False])
test_mask = torch.tensor([False, False, False, True, True, True, True])# 创建图神经网络模型
model = GNNLinkPredict(in_feats=2, hidden_size=16, out_feats=1)# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()# 训练模型
for epoch in range(50):model.train()logits = model(g, features)pred = logits.squeeze()loss = criterion(pred[train_mask], labels[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()# 打印训练损失print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")# 在测试集上评估模型
model.eval()
with torch.no_grad():logits = model(g, features)pred = logits.squeeze()pred = torch.sigmoid(pred)  # 使用sigmoid函数将预测值映射到0-1之间pred_labels = torch.round(pred)  # 四舍五入为0或1的预测标签accuracy = (pred_labels[test_mask] == labels[test_mask]).float().mean()print(f"Accuracy: {accuracy.item()}")

 留下个问题有空再解决。

关于复杂网络建模,我前面写了很多,大家可以学习参考。

【复杂网络建模】——常用绘图软件和库_图论画图软件

【复杂网络建模】——Pytmnet进行多层网络分析与可视化

【复杂网络建模】——Python通过平均度和随机概率构建ER网络

【复杂网络建模】——通过图神经网络来建模分析复杂网络

【复杂网络建模】——Python可视化重要节点识别(PageRank算法)

【复杂网络建模】——基于Pytorch构建图注意力网络模型

【复杂网络建模】——基于微博数据的影响力最大化算法(PageRank)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/5431.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

05-Redis初步使用

关系型数据的ACID特性:事务的四大特性:原子性,一致性,隔离性,持久性 关系型数据库应对的三高问题:高并发,高效率,高扩展 关系型数据库和非关系型数据库 关系型数据库的数据存储在表中,无法应对陡增的数据 非关系型数据库使用键值对的方式进行存储数据:redis可以用作缓存 r…

【运维工程师学习二】OS系统管理

【运维工程师学习二】OS系统管理 1、操作系统管理2、进程管理3、进程的启动4、进程信息的查看4.1、STAT 进程的状态:进程状态使用字符表示的(STAT的状态码),其状态码对应的含义:4.2、ps命令常用用法(方便查看系统进程&…

go语言环境安装

文章目录 环境介绍安装软件包步骤环境变量设置来一个经典的hello worldNice 最近的项目需要用到go来开发了,前几天就已经在看书了,今天是个周末,先在家里的机器上把环境搭好,特此记录一下。 环境介绍 下载地址:https:…

oracle 过滤字段中的中文,不再洋不洋土不土

目录 前言: 一、知己知彼 1.1业务场景 1.2错误案例 二、思路整理 2.1存储长度与字符串长度比较 三、还有没有其他思路 3.1ascii表查找法 3.2正式案例 四、总结 前言: 随着数字化建设的不断深入,企业越来越注重,企业数据治理&am…

算法笔记——排序算法

👌,begin: 排序算法很重要,它可以使数据按照一定的规律进行排序,各个语言的代码都有自己的排序函数,那么排序到底有哪几种方法,✌,如下: 按照效率分类如上图&#xff1a…

【Spring】设计思想

一、Spring 是什么? Spring是一个开源的Java框架,有着活跃而庞大的社区(例如:Apache),Spring 提供了一系列的工具和库,可以帮助开发者构建高效、可靠、易于维护的企业级应用程序。Spring的核心…

短信压力测试系统,支持自定义接口

短信压力测试系统,支持自定义接口 支持卡密充值,短信压力测试系统,解决一切骚扰电话,教程在压缩包里面 可多个服务器挂脚本分担压力,套了cdn导致无法正常执行脚本可以尝试添加白名单 这边建议使用MySQL方式 同服务器下直接配置…

lesson 12 Zigbee绑定通信

目录 Zigbee绑定通信 通信原理 实验过程 实现步骤 实验现象 实验分析 Zigbee绑定通信 通信原理 1、Zigbee一共有五种通信方式:单播、广播、组播、MAC、广播 2、绑定是Zigbee的一种基本通信方式,具体绑定通信又分为三种模式,模式大同…

tomcat概述,优化,多实例部署

目录 一、概述 二、三个容器 1、Web 容器: 2、Servlet 容器: 3、JSP 容器: 三、Tomcat 功能组件结构 四、优化 1、启动速度优化 2、配置参数优化 五、多实例部署 一、概述 Tomcat 是 Java 语言开发的,Tomcat 服务器是一…

如何建立自己的知识体系?202209

知识太多了,无法全部快速吸收进大脑,需要通过特定的方法、技能,在面对大量知识的情况下,快速梳理,构建自己的知识体系。 学习的目标,不仅仅是记忆知识,而是搜索知识、并过滤、洞察、理解、使用…

搭建Hadoop高可用框架分布式集群

搭建Hadoop高可用框架分布式集群 一.基础配置 1.创建虚拟机,修改虚拟机的主机名 2.修改网络配置 master:192.168.6.200 slave1:192.168.6.201 slave2:192.168.6.202 3.互ping测试 4.sudo授权 5.安装vim编辑器 6.配置网络映射 master配置映射 master向slave1传递映…

大数据Doris(五十二):Doris数据导出案例和注意事项

文章目录 Doris数据导出案例和注意事项 一、Doris数据导出到HDFS案例 1、创建Doris表并插入数据 2、创建Export ,数据导出到 HDFS 3、查看任务 4、查看导出结果 二、Doris数据导出到本地案例 1、配置 fe.conf 2、Doris 数据导出到本地 三、注意事项 Doris数据导出案例…