图神经网络GNN(一)GraphEmbedding

DeepWalk


使用随机游走采样得到每个结点x的上下文信息,记作Context(x)。
SkipGram优化的目标函数:P(Context(x)|x;θ)
θ = argmax P(Context(x)|x;θ)
DeepWalk这种GraphEmbedding方法是一种无监督方法,个人理解有点类似生成模型的Encoder过程,下面的代码中,node_proj是一个简单的线性映射函数,加上elu激活函数,可以看作Encoder的过程。Encoder结束后就得到了Embedding后的隐变量表示。其实GraphEmbedding要的就是这个node_proj,但是由于没有标签,只有训练数据的内部特征,怎么去训练呢?这就需要看我们的训练任务了,个人理解,也就是说,这种无监督的embedding后的结果取决于你的训练任务,也就是Decoder过程。Embedding后的编码对Decoder过程越有利,损失函数也就越小,编码做的也就越好。在word2vec中,有两种训练任务,一种是给定当前词,预测其前两个及后两个词发生的条件概率,采用这种训练任务做出的embedding就是skip-gram;还有一种是给定当前词前两个及后两个词,预测当前词出现的条件概率,采用这种训练任务做出的embedding就是CBOW.DeepWalk作者的论文中采用的是skip-gram。故复现也采用skip-gram进行复现。
针对skip-gram对应的训练任务,代码中的node_proj相当于编码器,h_o_1和h_o_2相当于解码器。Encoder和Decoder可以先联合训练,训练结束后,可以只保留Encoder的部分,舍弃Decoder的部分。当再来一个独热编码的时候,可以直接通过node_proj映射,即完成了独热编码的embedding过程。
(本代码假定在当前结点去往各邻接结点的可能性相同,即不考虑边的权重)

import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import random
import torch.nn.functional as F
import networkx as nx
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Categorical
import matplotlib.pyplot as pltclass MyGraph():def __init__(self,device):super(MyGraph, self).__init__()self.G = nx.read_edgelist(path='data/wiki/Wiki_edgelist.txt',create_using=nx.DiGraph(),nodetype=None,data=[('weight',int)])self.adj_matrix = nx.attr_matrix(self.G)self.edges = nx.edges(self.G)self.edges_emb = torch.eye(len(self.G.edges)).to(device)self.nodes_emb = torch.eye(len(self.G.nodes)).to(device)class GraphEmbedding(nn.Module):def __init__(self,nodes_num,edges_num,device,emb_dim = 10):super(GraphEmbedding, self).__init__()self.device = deviceself.nodes_proj = nn.Parameter(torch.randn(nodes_num,emb_dim))self.edges_proj = nn.Parameter(torch.randn(edges_num,emb_dim))self.h_o_1 = nn.Parameter(torch.randn(emb_dim,nodes_num * 2))self.h_o_2 = nn.Parameter(torch.randn(nodes_num * 2,nodes_num))def forward(self,G:MyGraph):self.nodes_proj,self.edges_proj = self.nodes_proj.to(self.device),self.edges_proj.to(device)self.h_o_1,self.h_o_2 = self.h_o_1.to(self.device),self.h_o_2.to(self.device)# Encoderedges_emb,nodes_emb = torch.matmul(G.edges_emb,self.edges_proj),torch.matmul(G.nodes_emb,self.nodes_proj)nodes_emb = F.elu_(nodes_emb)edges_emb,nodes_emb = edges_emb.to(device),nodes_emb.to(device)# Decoderpolicy = self.DeepWalk(G,gamma=5,window=2)outputs = torch.matmul(torch.matmul(nodes_emb[policy[:,0]],self.h_o_1),self.h_o_2)policy,outputs = policy.to(device),outputs.to(device)return policy,outputsdef DeepWalk(self,Graph:MyGraph,gamma:int,window:int,eps=1e-9):# Calculate transpose matrixadj_matrix = torch.tensor(Graph.adj_matrix[0], dtype=torch.float32)for i in range(adj_matrix.shape[0]):adj_matrix[i,:] /= (torch.sum(adj_matrix[i]) + eps)adj_nodes = Graph.adj_matrix[1].copy()random.shuffle(adj_nodes)nodes_idx, route_result = [],[]for node in adj_nodes:node_idx = np.where(np.array(Graph.adj_matrix[1]) == node)[0].item()node_list = self.Random_Walk(adj_matrix,window=window,node_idx=node_idx)route_result.append(node_list)return torch.tensor(route_result)def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]node_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_listdef HMM_process(self,adj_matrix:torch.Tensor,node_idx:int,eps=1e-9):pi = torch.zeros((1, adj_matrix.shape[0]), dtype=torch.float32)pi[:,node_idx] = 1.0pi = torch.matmul(pi,adj_matrix)pi = pi.squeeze(0) / (torch.sum(pi) + eps)return piif __name__ == "__main__":epochs = 200device = torch.device("cuda:1")cross_entrophy_loss = CrossEntropyLoss().to(device)Graph = MyGraph(device)Embedding = GraphEmbedding(nodes_num=len(Graph.G.nodes), edges_num=len(Graph.G.edges),device=device).to(device)optimizer = torch.optim.Adam(Embedding.parameters(),lr=1e-5)scheduler=CosineAnnealingLR(optimizer,T_max=50,eta_min=0.05)loss_list = []epoch_list = [i for i in range(1,epochs+1)]for epoch in range(epochs):policy,outputs = Embedding(Graph)outputs = outputs.unsqueeze(1).repeat(1,policy.shape[-1]-1,1).reshape(-1,outputs.shape[-1])optimizer.zero_grad()loss = cross_entrophy_loss(outputs, policy[:,1:].reshape(-1))loss.backward()optimizer.step()scheduler.step()loss_list.append(loss.item())print(f"Loss : {loss.item()}")plt.plot(epoch_list,loss_list)plt.xlabel('Epoch')plt.ylabel('CrossEntrophyLoss')plt.title('Loss-Epoch curve')plt.show()

在这里插入图片描述

Node2Vec

在这里插入图片描述
在这里插入图片描述
修改Random_Walk函数如下:

    def Random_Walk(self,adj_matrix:torch.Tensor,window:int,node_idx:int):node_list = [node_idx]for i in range(window):pi = self.HMM_process(adj_matrix,node_idx)if torch.sum(pi) == 0:pi += 1 / pi.shape[0]if i > 0:v,t = node_list[-1],node_list[-2]x_list = torch.nonzero(adj_matrix[v]).squeeze(-1)for x in x_list:if t == x:  # 0pi[x] *= 1/self.pelif adj_matrix[t][x] == 1:  # 1pi[x] *= 1else:   # 2pi[x] *= 1/self.qnode_idx = Categorical(pi).sample().item()node_list.append(node_idx)return node_list

结果如下,这里令p=2,q=3,即1/p=0.5,1/q=0.33,会相对保守周围。结果似乎好了那么一点点。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/124105.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

对图像中边、线、点的检测(支持平面/鱼眼/球面相机)附源码

前言 图像的线段检测是计算机视觉和遥感技术中的一个基本问题,可广泛应用于三维重建和 SLAM 。虽然许多先进方法在线段检测方面表现出了良好的性能,但对未去畸变原始图像的线 段检测仍然是一个具有挑战性的问题。此外,对于畸变和无畸变的图像都缺乏统一的…

Git使用【中】

欢迎来到Cefler的博客😁 🕌博客主页:那个传说中的man的主页 🏠个人专栏:题目解析 🌎推荐文章:题目大解析3 目录 👉🏻分支管理分支概念git branch(查看/删除分…

树的存储结构以及树,二叉树,森林之间的转换

目录 1.双亲表示法 2.孩子链表 3.孩子兄弟表示法 4.树与二叉树的转换 (1)树转换为二叉树 (2)二叉树转换成树 5.二叉树与森林的转化 (1)森林转换为二叉树 以下树为例 1.双亲表示法 双亲表示法定义了…

CompletableFuture 异步编排

目录 CompletableFuture 的详解代码测试配置类的引入Demo1Demo2CompletableFuture的async后缀函数与不带async的函数的区别ThreadPoolTaskExecutor 和 ThreadPoolExecutor 的区别Spring 线程池的使用业务使用多线程的原因场景一:场景二:FutureTask介绍线程池为什么要使用阻塞队…

nodejs+vue校园跑腿系统elementui

购物车品结算,管理个人中心,订单管理,接单处理,商品维护,用户管理,系统管理等功育食5)要求系统运行可靠、性能稳定、界面友好、使用方便。 第三章 系统分析 10 3.1需求分析 10 3.2可行性分析 10 3.2.1技术…

正则表达式 Regular Expression学习

该文章内容为以下视频的学习笔记: 10分钟快速掌握正则表达式_哔哩哔哩_bilibili正则表达式在线测试工具:https://regex101.com/, 视频播放量 441829、弹幕量 1076、点赞数 19330、投硬币枚数 13662、收藏人数 26242、转发人数 2768, 视频作者 奇乐编程学…

Spring MVC 十:异常处理

异常是每一个应用必须要处理的问题。 Spring MVC项目,如果不做任何的异常处理的话,发生异常后,异常堆栈信息会直接抛出到页面。 比如,我们在Controller写一个异常: GetMapping(value"/hello",produces{&qu…

【数据结构与算法】- 数组

数组 1.1 数组的定义1.2 数组的创建1.3 数组在内存中的情况2.1 初始化数组2.2 插入元素2.3 删除元素2.4 读取元素2.5 遍历数组 1.1 数组的定义 数组中的是在内存中是连续存储的,内存是由一个个内存单元组成的,每一个内存单元都有自己的地址,…

优先级队列

优先级队列 一端进,另一端出 按优先级出队 普通队列 一端进,另一端出 FIFO 我们做个约定:数字大的,优先级越高 ,优先出队 无序数组实现 要点 入队保持顺序 出队前找到优先级最高的出队,相当于一次选择排序…

Tensorflow2 GPU 安装方法

一、Tensorflow2 GPU 安装方法 1. 首先安装Anaconda3环境2. 在Anaconda Prompt 中安装tensorflow23. 验证GPU是否可以使用4. 错误解决 1. 首先安装Anaconda3环境 https://www.anaconda.com/ 2. 在Anaconda Prompt 中安装tensorflow2 conda update conda conda create -n ten…

Unity把UGUI再World模式下显示到相机最前方

Unity把UGUI再World模式下显示到相机最前方 通过脚本修改Shader 再VR里有时候要把3D的UI显示到相机最前方,加个UI相机会坏事,可以通过修改unity_GUIZTestMode来解决。 测试用例 测试用例如下: 场景包含一个红色的盒子,一个UI…

github搜索技巧

指定语言 language:java 比如我要找用java写的含有blog的内容 搜索项目名称包含关键词的内容 vue in:name 其他如项目描述跟项目文档,如下 组合使用 vue in:name,description,readme 根据Star 或者fork的数量来查找 总结 springboot vue stars:>1000 p…