图神经网络:(图像分割)3D人物图像分割

文章说明:
1)参考资料:PYG的文档。超链。斯坦福大学的机器学习课程。超链。(应该要挂梯子)。博客原文。超链。(应该要挂梯子)。原文理论参考文献。超链。提取码8848。
2)我在百度网盘上传这篇文章jupyter notebook和预训练模型。超链。提取码8848.
3)博主水平不高,如有错误,还望批评指正
一些建议:注重理论建议直接去看文献;注重实践建议直接去看代码。他的代码会有详细注释,但实际没啥用,如果不看原文参考文献。建议手敲一遍代码,会对理解很有帮助。变量名字取得很好,如果有图神经基础,不看文献也是可以。

文章目录

  • 前言1:硬件问题
  • 前言2:有关综述
  • 数据描述
  • 数据下载
  • 任务描述
  • 代码演示

前言1:硬件问题

如果电脑不是很好,并不建议自己训练。我的电脑不是很好,训练大概有20分钟。最后电脑特别的烫,感觉对电脑很不好。我的电脑配置如下(应该是看这个,对于硬件我不清楚)。直接下载预训练的模型就好。在这里插入图片描述

前言2:有关综述

对于一般图像分割以及图像分类任务,卷积神经网络取得巨大成功。但是卷积神经网络不能处理不规则的数据结构。我们希望推广卷积神经网络到不规则数据结构。卷积神经网络博主不很了解,不所以作过多评价。图神经网络为解决问题,应孕而生。我们使用3D点云进行演示。

数据描述

我们使用两个矩阵表示数据:十分简单,看图易懂。图片自源博客。我们需要一个矩阵存储n个点的位置。我们需要一个矩阵存储点间的边关系(3点确定一个平面,这就解释为什么是3个点了)。
在这里插入图片描述

数据下载

超链。

任务描述

正如标题:一个简单分类任务。我们需要对3D点云进行分类。头部点云,躯干点云,左臂点云,左手点云,右臂点云,右手点云,左大腿点云,左小腿点云,左脚点云,右大腿点云,右小腿点云,右脚点云。

代码演示

import torch
device='cuda' if torch.cuda.is_available() else 'cpu'

路径有关注意事项1:下载数据之后不要进行解压,放在一个文件之中就可以了。
路径有关注意事项2:复制文件地址需要进行修改,可能这跟操作系统有关但是我不清楚,我就只说我的。直接复制是这样"C:\Users\19216\Desktop\project\3DImage_Classification_And_Segmentation",我们需要更改所有"\“变为”/"。

root="C:/Users/19216/Desktop/project/3DImage_Classification_And_Segmentation"

以下定义数据变换。

from torch_geometric.transforms import BaseTransform
from torch_geometric.data import Data
#BaseTransform的构造十分简单,建议自己去看源码
class NormalizeUnitSphere(BaseTransform):#静态方法,不依赖类(加了这个应该就不用加self了)@staticmethoddef _re_center(x):centroid=torch.mean(x,dim=0)return x-centroid@staticmethoddef _re_scale_to_unit_length(x):max_dist=torch.max(torch.norm(x,dim=1))return x/max_dist#类的默认调用方法def __call__(self,data:Data):if data.x is not None:data.x=self._re_scale_to_unit_length(self._re_center(data.x))return data#就是打印类的名字def __repr__(self):return "{}()".format(self.__class__.__name__)
from torch_geometric.transforms import Compose,FaceToEdge
pre_transform=Compose([FaceToEdge(remove_faces=False),NormalizeUnitSphere()])

以下加载变换数据。

from pathlib import Path
import trimesh
def load_mesh(mesh_filename:Path):mesh=trimesh.load_mesh(mesh_filename,process=False)vertices=torch.from_numpy(mesh.vertices).to(torch.float)faces=torch.from_numpy(mesh.faces).t().to(torch.long).contiguous()return vertices,faces
from torch_geometric.data import InMemoryDataset,extract_zip
from functools import lru_cache
import numpy as np

关于这部分的代码,必须看这,看了你就知道了吧。这里代码逻辑是挺有意思的,由于篇幅原因读者自行研究。我来讲下逻辑,不一定正确哈。首先train_data申请调用SegmentationFaust。父类立马开始调用四个方法(如果没有直接跳过) raw_file_names(),processed_file_names(),download(),process()。具体到这里就只有processed_file_names()、process()。父类发现文件夹中没有processed_file_names()的对应文件,立即用process()处理数据生成processed_file_names()的对应文件。然后赋值[“training.pt”,“test.pt”]给self.processed_paths。最后子类开始运作读取数据并且赋值。所有数据在第一步处理好了。

class SegmentationFaust(InMemoryDataset):map_seg_label_to_id=dict(head=0,torso=1,left_arm=2,left_hand=3,right_arm=4,right_hand=5,left_upper_leg=6,left_lower_leg=7,left_foot=8,right_upper_leg=9,right_lower_leg=10,right_foot=11)           def __init__(self,root,train:bool=True,pre_transform=None):super().__init__(root,pre_transform)path=self.processed_paths[0] if train else self.processed_paths[1]self.data,self.slices=torch.load(path)#将方法转换为属性@propertydef processed_file_names(self)->list:return ["training.pt","test.pt"]@property#结果缓存,提高效率@lru_cache(maxsize=32)def _segmentation_labels(self):path_to_labels=Path(self.root)/"MPI-FAUST"/"segmentations.npz"seg_labels=np.load(str(path_to_labels))["segmentation_labels"]return torch.from_numpy(seg_labels).type(torch.int64)def _mesh_filenames(self):path_to_meshes=Path(self.root)/"MPI-FAUST"/"meshes"#正则匹配return path_to_meshes.glob("*.ply")def _unzip_dataset(self):path_to_zip=Path(self.root)/"MPI-FAUST.zip"extract_zip(str(path_to_zip),self.root,log=False)def process(self):self._unzip_dataset()data_list=[]for mesh_filename in sorted(self._mesh_filenames()):vertices, faces=load_mesh(mesh_filename)data=Data(x=vertices, face=faces)data.segmentation_labels=self._segmentation_labelsif self.pre_transform is not None:data=self.pre_transform(data)data_list.append(data)torch.save(self.collate(data_list[:80]),self.processed_paths[0])torch.save(self.collate(data_list[80:]),self.processed_paths[1])
train_data=SegmentationFaust(root=root,pre_transform=pre_transform)
#输出:
#Processing...
#Done!
test_data=SegmentationFaust(root=root,train=False,pre_transform=pre_transform)
from torch_geometric.loader import DataLoader
train_loader=DataLoader(train_data,shuffle=True)
test_loader=DataLoader(test_data,shuffle=False)
from itertools import tee

这段代码特别抽象,读者自行理解研究(我的意思语法抽象不指代码逻辑)

def pairwise(iterable):a,b=tee(iterable)next(b,None)return zip(a,b)
import torch.nn as nn

这段代码同样抽象,读者自行理解研究(我的意思语法抽象不指代码逻辑)

def get_mlp_layers(channels:list,activation,output_activation=nn.Identity):layers=[]*intermediate_layer_definitions,final_layer_definition=pairwise(channels)for in_ch,out_ch in intermediate_layer_definitions:intermediate_layer=nn.Linear(in_ch,out_ch)layers+=[intermediate_layer,activation()]layers+=[nn.Linear(*final_layer_definition),output_activation()]return nn.Sequential(*layers)
from torch_geometric.nn import MessagePassing
def get_conv_layers(channels:list,conv:MessagePassing,conv_params:dict):conv_layers=[conv(in_ch,out_ch,**conv_params) for in_ch,out_ch in pairwise(channels)]return conv_layers
from torch_geometric.utils import add_self_loops,remove_self_loops
import torch.nn.functional as F

最后介绍参考论文,这里暂时放下不表
以下部分均为模型建立

class FeatureSteeredConvolution(MessagePassing):def __init__(self,in_channels:int,out_channels:int,num_heads:int,ensure_trans_invar:bool=True,bias:bool=True,with_self_loops:bool=True):super().__init__(aggr="mean")self.in_channels=in_channels;self.out_channels=out_channels;self.num_heads=num_heads;self.with_self_loops=with_self_loopsself.linear=torch.nn.Linear(in_features=in_channels,out_features=out_channels*num_heads,bias=False)self.u=torch.nn.Linear(in_features=in_channels,out_features=num_heads,bias=False)self.c=torch.nn.Parameter(torch.Tensor(num_heads))if not ensure_trans_invar:self.v=torch.nn.Linear(in_features=in_channels,out_features=num_heads,bias=False)else:self.register_parameter("v",None)if bias:self.bias=torch.nn.Parameter(torch.Tensor(out_channels))else:self.register_parameter("bias",None)self.reset_parameters()def reset_parameters(self):torch.nn.init.uniform_(self.linear.weight)torch.nn.init.uniform_(self.u.weight)torch.nn.init.normal_(self.c,mean=0.0,std=0.1)if self.v is not None:torch.nn.init.uniform_(self.v.weight)if self.bias is not None:torch.nn.init.normal_(self.bias,mean=0.0,std=0.1)def forward(self,x,edge_index):if self.with_self_loops:edge_index,_=remove_self_loops(edge_index)edge_index,_=add_self_loops(edge_index=edge_index,num_nodes=x.shape[0])out=self.propagate(edge_index,x=x)return out if self.bias is None else out+self.biasdef _compute_attention_weights(self,x_i,x_j):if x_j.shape[-1]!=self.in_channels:raise ValueError(f"Expected input features with {self.in_channels} channels."f"Instead received features with {x_j.shape[-1]} channels.")if self.v is None:attention_logits=self.u(x_i-x_j)+self.celse:attention_logits=self.u(x_i)+self.b(x_j)+self.creturn F.softmax(attention_logits,dim=1)def message(self,x_i,x_j):attention_weights=self._compute_attention_weights(x_i,x_j)x_j=self.linear(x_j).view(-1,self.num_heads,self.out_channels)return (attention_weights.view(-1,self.num_heads,1)*x_j).sum(dim=1)
class GraphFeatureEncoder(torch.nn.Module):def __init__(self,in_features,conv_channels,num_heads,apply_batch_norm:int=True,ensure_trans_invar:bool=True,bias:bool=True,with_self_loops:bool=True):super().__init__()self.apply_batch_norm=apply_batch_norm;conv_params=dict(num_heads=num_heads,ensure_trans_invar=ensure_trans_invar,bias=bias,with_self_loops=with_self_loops)conv_layers=get_conv_layers(channels=[in_features]+conv_channels,conv=FeatureSteeredConvolution,conv_params=conv_params)self.conv_layers=nn.ModuleList(conv_layers)*first_conv_channels,final_conv_channel=conv_channelsself.batch_layers=[None for _ in first_conv_channels]if apply_batch_norm:self.batch_layers=nn.ModuleList([nn.BatchNorm1d(channel) for channel in first_conv_channels])def forward(self,x,edge_index):*first_conv_layers,final_conv_layer=self.conv_layersfor conv_layer,batch_layer in zip(first_conv_layers,self.batch_layers):x=conv_layer(x,edge_index)x=F.relu(x)if batch_layer is not None:x=batch_layer(x)return final_conv_layer(x,edge_index)
class MeshSeg(torch.nn.Module):def __init__(self,in_features,encoder_features,conv_channels,encoder_channels,decoder_channels,num_heads,num_classes,apply_batch_norm=True):super().__init__()self.input_encoder=get_mlp_layers(channels=[in_features]+encoder_channels,activation=nn.ReLU)self.gnn=GraphFeatureEncoder(in_features=encoder_features,conv_channels=conv_channels,num_heads=num_heads,apply_batch_norm=apply_batch_norm)*_,final_conv_channel=conv_channelsself.final_projection=get_mlp_layers([final_conv_channel]+decoder_channels+[num_classes],activation=nn.ReLU)def forward(self,data):x,edge_index=data.x,data.edge_indexx=self.input_encoder(x)x=self.gnn(x,edge_index)return self.final_projection(x)

设定参数

model_params=dict(in_features=3,encoder_features=16,conv_channels=[32,64,128,64],encoder_channels=[16],decoder_channels=[32],num_heads=12,num_classes=12,apply_batch_norm=True)
net=MeshSeg(**model_params).to(device)
best_test_acc=0.0;num_epochs=50;lr=0.001;optimizer=torch.optim.Adam(net.parameters(),lr=lr);loss_fn=torch.nn.CrossEntropyLoss()

开始训练

def train(net,train_data,optimizer,loss_fn,device):net.train()cumulative_loss=0.0for data in train_data:data=data.to(device)optimizer.zero_grad()out=net(data)loss=loss_fn(out,data.segmentation_labels.squeeze())loss.backward()cumulative_loss+=loss.item()optimizer.step()return cumulative_loss/len(train_data)
def accuracy(predictions,gt_seg_labels):predicted_seg_labels=predictions.argmax(dim=-1,keepdim=True)if predicted_seg_labels.shape!=gt_seg_labels.shape:raise ValueError("Expected Shapes to be equivalent")correct_assignments=(predicted_seg_labels==gt_seg_labels).sum()num_assignemnts=predicted_seg_labels.shape[0]return float(correct_assignments/num_assignemnts)
def evaluate_performance(dataset,net,device):prediction_accuracies=[]for data in dataset:data=data.to(device)predictions=net(data)prediction_accuracies.append(accuracy(predictions,data.segmentation_labels))return sum(prediction_accuracies)/len(prediction_accuracies)
@torch.no_grad()
def test(net,train_data,test_data,device):net.eval()train_acc=evaluate_performance(train_data,net,device)test_acc=evaluate_performance(test_data,net,device)return train_acc,test_acc
from tqdm import tqdm
with tqdm(range(num_epochs),unit="Epoch") as tepochs:for epoch in tepochs:train_loss=train(net,train_loader,optimizer,loss_fn,device)train_acc,test_acc=test(net,train_loader,test_loader,device)tepochs.set_postfix(train_loss=train_loss,train_accuracy=100*train_acc,test_accuracy=100*test_acc)if test_acc>best_test_acc:best_test_acc=test_acctorch.save(net.state_dict(),root+"/checkpoint_best_colab")

开始画图

def load_model(model_params,path_to_checkpoint,device):try:model=MeshSeg(**model_params)model.load_state_dict(torch.load(str(path_to_checkpoint)),strict=True)model.to(device)return modelexcept RuntimeError as err_msg:raise ValueError(f"Given checkpoint {str(path_to_checkpoint)} could not be loaded. {err_msg}")
def get_best_model(model_params,device):path_to_trained_model=Path(root+"/checkpoint_best_colab")trained_model=load_model(model_params,path_to_trained_model,device)return trained_model
net=get_best_model(model_params,device)
segmentation_colors=dict(head=torch.tensor([255,255,255],dtype=torch.int),torso=torch.tensor([255,255,128],dtype=torch.int),left_arm=torch.tensor([255,255,0],dtype=torch.int),left_hand=torch.tensor([255,128,255],dtype=torch.int),right_arm=torch.tensor([255,128,128],dtype=torch.int),right_hand=torch.tensor([255,128,0],dtype=torch.int),left_upper_leg=torch.tensor([255,0,255],dtype=torch.int),left_lower_leg =torch.tensor([255,0,128],dtype=torch.int),left_foot=torch.tensor([255,0,0],dtype=torch.int),right_upper_leg=torch.tensor([128,255,255],dtype=torch.int),right_lower_leg=torch.tensor([128,255,128],dtype=torch.int),right_foot=torch.tensor([128,255,0],dtype=torch.int)
)
map_seg_id_to_color=dict((_value,segmentation_colors[_key]) for _key,_value in train_data.map_seg_label_to_id.items())
@torch.no_grad()
def visualize_prediction(net,data,device,map_seg_id_to_color):def _map_seg_label_to_color(seg_ids,map_seg_id_to_color):return torch.vstack([map_seg_id_to_color[int(seg_ids[idx])] for idx in range(seg_ids.shape[0])])data=data.to(device)predictions=net(data)predicted_seg_labels=predictions.argmax(dim=-1,keepdim=True)mesh_colors=_map_seg_label_to_color(predicted_seg_labels, map_seg_id_to_color)segmented_mesh=trimesh.base.Trimesh(vertices=data.x.cpu().numpy(),faces=data.face.t().cpu().numpy(),process=False)segmented_mesh.visual.vertex_colors=mesh_colors.cpu().numpy()return segmented_mesh
segmented_meshes=[]
mesh_ids=[0,1,2,3,4,5,6,7,8,9]
for idx,mesh_id in enumerate(mesh_ids):segmented_mesh=visualize_prediction(net,test_data[mesh_id],device,map_seg_id_to_color)segmented_mesh.vertices+=[idx*1.0,0.0,0.0]segmented_meshes.append(segmented_mesh)
scene=trimesh.scene.Scene(segmented_meshes)
scene.show()

在这里插入图片描述
论文部分不想写了。以后再来吧,那就这样吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/7660.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Spring Boot】Spring Boot配置文件详情

前言 Spring Boot是一个开源的Java框架,用于快速构建应用程序和微服务。它基于Spring Framework,通过自动化配置和约定优于配置的方式,使开发人员可以更快地启动和运行应用程序。Spring Boot提供了许多开箱即用的功能和插件,包括嵌…

微信小程序 滚动到底部加载新的数据 之后滚动到顶部

1.配置到底部监听 在app.json的window里面加入 里面的300表示距离底部300rpx触发onReachBottom事件 默认50rpx "window": {"onReachBottomDistance": 300}, 2.在数据列表的js页面 /*** 页面上拉触底事件的处理函数*/onReachBottom() {console.log("…

消息中间件面试题详解

RabbitMQ 如何保证消息不丢失 消息的重复消费问题如何解决 rabbitmq中死信交换机(RabbitMQ延迟队列有了解吗) 延迟队列:进入队列的消息会被延迟消费的队列 场景:超时订单,限时优惠,定时发布 延迟队列 …

【Linux】-第一个小程序(进度条)

💖作者:小树苗渴望变成参天大树 🎉作者宣言:认真写好每一篇博客 🎊作者gitee:gitee 💞作者专栏:C语言,数据结构初阶,Linux,C 动态规划算法 如 果 你 喜 欢 作 者 的 文 章 ,就 给 作…

真赞!IDEA中可以这么玩MyBatis,让编码速度飞起!

本篇博客图解 MyBatis Generator 的使用过程&#xff0c;并结合实战说明逆向工程的使用方式。 搭建 MyBatis Generator 插件环境 a. 添加插件依赖 pom.xml <!--mybatis 逆向生成插件--> <plugin><groupId>org.mybatis.generator</groupId><artifac…

iPad平板用的触控笔什么牌子好?主动式电容笔推荐

现在&#xff0c;电容笔已经成为在线办公、在线教育等产业中的热门产品&#xff0c;那么&#xff0c;平替电容笔是否会代替苹果原有的电容笔呢&#xff1f;实际上&#xff0c;你根本不需要花那么多钱去买一个原装的苹果电容笔。一支普通的平替式电容笔只需要一两百元&#xff0…

微分方程应用——笔记整理

首先&#xff0c;根据正常思路走&#xff0c;化简得到式子&#xff1a; 不难发现&#xff0c;设 后面得出该方程的通解&#xff1a; 这里要注意什么等于这个通解 --- z 又因为该曲线过点 所以可以求出c为3 该题虽然简单&#xff0c;但是要注意几个问题&#xff0c;该定…

【封装丨工具类】

封装工具类 封装 Java 工具类1. 使用静态工厂方法或静态方法封装实例2.将工具类中的方法进行分组3. 常用方法定义为静态方法或枚举4. 工厂 | 抽象工厂模式5. 访问数据库5.1 JDBC API &#xff1a;5.2 第三方数据库 封装 Java 工具类 1. 使用静态工厂方法或静态方法封装实例 使…

netty学习(2):多个客户端与服务器通信

1. 基于前面一节netty学习&#xff08;1&#xff09;:1个客户端与服务器通信 只需要把服务器的handler改造一下即可&#xff0c;通过ChannelGroup 找到所有的客户端channel&#xff0c;发送消息即可。 package server;import io.netty.channel.*; import io.netty.channel.gr…

68、基于51单片机语音识别控制小车行走系统设计(程序+原理图+PCB源文件+参考论文+开题报告+任务书+元器件清单等)

摘 要 随着电子工业的发展&#xff0c;具有语音控制功能的小车越来越受到人们的青睐&#xff0c;在人们的日常消费生活中起着不可忽视的作用。目前&#xff0c;声控技术已在很多领域得到使用。 本文对语音控制功能的小车概况做了阐述。在硬件设计方面&#xff0c;本论文以凌阳…

LabVIEW评估儿童的运动认知技能

LabVIEW评估儿童的运动认知技能 以前测量认知运动功能的技术范围从基本和耗时的笔和纸技术&#xff0c;到使用准确但复杂和昂贵的实验室设备。Kinelab的主要要求是提供一个易于配置、坚固且便携的平台&#xff0c;以便在向4-12岁的儿童展示交互式视觉刺激期间快速收集运动学测…

blender 之点云渲染(论文渲图)

blender 之点云渲染&#xff08;论文渲图&#xff09; 一、导入点云1.新建2.导入点云3.位置移动&放大缩小 二、Geometry Nodes实体化点云1.新建节点2.实体化 三、给实体化点云添加材质四、设置渲染引擎更换为Cycles。 五、对准视角1.新建一个球2.创建相机视角跟踪3.将uv球挪…