基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图

基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图

  • 一.局部效果图
  • 二.运行训练过程,拦截算子,生成调用关系信息
  • 三.可视化,生成SVG图像

想知道Megatron-DeepSpeed训练过程中各模块之间的调用关系。torch_dispatch机制可以拦截算子,inspect又能获取到调用栈(文件,类名,函数,行号).基于这些信息可以生成调用关系,最后用graphviz生成SVG图像。该思路也可以用来画其它pytorch工程的调用关系图

1.为了减少图像宽度,一行显示一级文件路径
2.没有显示具体的ATen算子。因为边太乱

一.局部效果图

在这里插入图片描述

二.运行训练过程,拦截算子,生成调用关系信息

# 前面构建模型的代码省略
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
from dataclasses import dataclass
from typing import Any
import pickle@dataclass
class _ProfilerState:cls: Anyobject: Any = Noneclass TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parent        self.global_index=0        self.nodes=set()self.edges=set()def __del__(self):self.rank = torch.distributed.get_rank()graph={"nodes":self.nodes,"edges":self.edges}with open(f"call_graph_{self.rank}.pkl","wb") as f:pickle.dump(graph,f)def is_keep(self,node):# if node.function.find("wrapper")>=0:#     return False# if node.function.find("_call_impl")>=0:#     return Falsereturn Truedef __torch_dispatch__(self, func, types, args=(), kwargs=None):self.global_index+=1self.rank = torch.distributed.get_rank() func_packet = func._overloadpacket       if kwargs is None:kwargs = {}if self.rank==0:stacks=[i for i in inspect.stack() if self.is_keep(i)]stacks_sz=len(stacks)for idx in range(stacks_sz-1,1,-1):if "self" in stacks[idx].frame.f_locals:class_name = stacks[idx].frame.f_locals["self"].__class__.__name__else:class_name=""this_node=f"{stacks[idx].filename}:[{class_name}]:{stacks[idx].function}"if "self" in stacks[idx-1].frame.f_locals:class_name = stacks[idx-1].frame.f_locals["self"].__class__.__name__else:class_name=""                                    next_node=f"{stacks[idx-1].filename}:[{class_name}]:{stacks[idx-1].function}"self.nodes.add(this_node)self.nodes.add(next_node)self.edges.add(f"{this_node}->{next_node}")# if stacks_sz>1:#     if "self" in stacks[1].frame.f_locals:#         class_name = stacks[1].frame.f_locals["self"].__class__.__name__#     else:#         class_name=""                #     this_node=f"{stacks[1].filename}:[{class_name}]:{stacks[1].function}"#     next_node=f"{func_packet.__name__}"#     self.nodes.add(this_node)   #     self.nodes.add(next_node)            #     self.edges.add(f"{this_node}->{next_node}")ret= func(*args, **kwargs)return retclass TorchDumper:_CURRENT_Dumper = Nonedef __init__(self,schedule: Any):self.p= _ProfilerState(schedule)def __enter__(self):assert TorchDumper._CURRENT_Dumper is NoneTorchDumper._CURRENT_Dumper = selfif self.p.object is None:o = self.p.cls(self)o.__enter__()self.p.object = oelse:self.p.object.step()return selfdef __exit__(self, exc_type, exc_val, exc_tb):TorchDumper._CURRENT_Dumper = Noneif self.p.object is not None:self.p.object.__exit__(exc_type, exc_val, exc_tb)del self.p.object  #序列化保存def main():with TorchDumper(TorchDumpDispatchMode):#训练入口pretrain(train_valid_test_datasets_provider,model_provider,forward_step,extra_args_provider=llama_argument_handler,args_defaults={"tokenizer_type": "GPT2BPETokenizer"},)if __name__ == "__main__":main()

三.可视化,生成SVG图像

# coding=utf-8import os
from graphviz import Digraph,Graph
import pickle
import random
from distinctipy import distinctipydef generate_colors(N):'''生成N种有区别度的颜色'''result=[]for red, green, blue in distinctipy.get_colors(N):result.append("#{:02X}{:02X}{:02X}".format(int(red*255), int(green*255), int(blue*255)))return resultdef replace_name(name):'''修改节点名字(缩短,添加换行)'''if name.find("__torch_dispatch__")>=0:return Nonename=name.replace("/home/user/Megatron-DeepSpeed/","")name=name.replace("/home/anaconda3/envs/dev/lib/python3.10/site-packages/","")name=name.replace("/home/user/deepspeed/","")name=name.replace("/home/anaconda3/envs/dev/","")name=name.replace("/",r"\n")name=name.replace(":",r"\n")return name# 1.加载HOOK生成的调用关系文件
rank=0
with open(f"call_graph_{rank}.pkl","rb") as f:data=pickle.load(f)# 2.构建图,设置属性
dot = Digraph()
dot.node_attr = {"shape": "plaintext"}
dot.attr('graph', layout='dot')
dot.graph_attr.update(sep='4.0', ratio='compress')node_desc_id_map={}  #节点名与描述的关系映射表
src_node_color={}    #节点颜色映射表(同一个节点输出的边颜色一样)colors = generate_colors(10)
colors_sz=len(colors)fontsize="16"        #节点字体大小
penwidth="2.0"       #边宽度# 3.添加节点
for idx,v in enumerate(data["nodes"]):v=replace_name(v)if v is None:continuenode_desc_id_map[v]=f"{idx}"if v.find("megatron")>=0:dot.node(f"{idx}",v,style='filled',color='#73FBFD',fontsize=fontsize)elif v.find("deepspeed")>=0:dot.node(f"{idx}",v,style='filled',color='#FA8D89',fontsize=fontsize)else:dot.node(f"{idx}",v,style='filled',color='#C0C0C0',fontsize=fontsize)src_node_color[v]=colors[idx%colors_sz]# 4.添加边
for edge in data["edges"]:from_node,to_node=edge.split("->")from_node=replace_name(from_node)to_node=replace_name(to_node)if all([from_node,to_node]):color=src_node_color[from_node]dot.edge(node_desc_id_map[from_node], node_desc_id_map[to_node],color=color,penwidth=penwidth)# 5.保存SVG
save_path='megatron_deepspeed_callgraph'
dot.render(save_path,format='svg', view=False)# 6.修改背景色为灰色
import xml.etree.ElementTree as ET
svg_tree = ET.parse(f'{save_path}.svg')
root = svg_tree.getroot()
element = root.find(".//{http://www.w3.org/2000/svg}polygon")
element.set('fill', 'gray')
svg_tree.write(f'{save_path}.svg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/685634.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《引爆流量获客技术》实操方法,手把手教你搭建盈利流量池

[1]-先导课.mp4 [2]-第1节:设计客户终身价值的方法和买客户思维.mp4 [3]-第2节:【渠道模型】解决谁是我的客户如何找到.mp4 [4]-第3节:【诱饵模型】解决 如何获得更多的客户.mp4 [5]-第4节:【钩子模型】解决让目标客户主动找你…

轮式机器人

迄今为止,轮子一般是移动机器人学和人造交通车辆中最流行的运动机构。它可达到很高的效率, 如图所示, 而且用比较简单的机械就可实现它的制作。 另外,在轮式机器人设计中,平衡通常不是一个研究问题。 因为在所有时间里,轮式机器人一般都被设计成在任何时间里所有轮子均与地接…

milvus元数据在etcd的存储解析

milvus元数据在etcd的存储解析 数据以key-value形式存在。 大致包含如下一些种类: databasecollectionfieldpartitionindexsegment-indexresource_groupsession database 创建一个数据库会产生2个key,但value是相同的。 key规则: 前缀/root-coord/database/db…

01.基本概念

操作系统 为什么要有操作系统? 计算机时一个十分复杂的系统,又cpu、内存、磁盘、IO设备、网络接口等等复杂的硬件组成,人的精力是有限的,不可能了解所有的硬件接口,但是程序可以。 所以我们在计算机上安装了一层软件&…

无人机+光电吊舱:四光(可见光+红外热成像+广角+激光测距)吊舱设计技术详解

无人机与光电吊舱的结合,特别是四光吊舱(包含可见光、红外热成像、广角和激光测距技术)的应用,为无人机提供了强大的侦察和测量能力。以下是对四光吊舱设计技术的详解: 1. 可见光技术:可见光相机是吊舱中最…

福昕PDF阅读器取消手型工具鼠标点击翻页

前言: 本文介绍如何关闭福昕PDF阅读器取消手型工具鼠标点击翻页,因为这样真的很容易误触发PDF翻页,使用起来让人窝火。 引用: NA 正文: 新版的福昕PDF阅读器默认打开了“使用手型工具阅读文章”这个勾选项&#x…

IPO压力应变桥信号处理系列隔离放大器 差分信号隔离转换0-10mV/0-20mV/0-±10mV/0-±20mV转4-20mA/0-5V/0-10V

概述: IPO压力应变桥信号处理系列隔离放大器是一种将差分输入信号隔离放大、转换成按比例输出的直流信号混合集成厚模电路。产品广泛应用在电力、远程监控、仪器仪表、医疗设备、工业自控等行业。该模块内部嵌入了一个高效微功率的电源,向输入端和输出端…

CSS跳动文字

<div class"loading-mask"><div class"loading-text"><span style"--i:1">加</span><span style"--i:2">载</span><span style"--i:3">中</span><span style"--i:…

【Ubuntu18.04+melodic】抓取环境设置

UR5_gripper_camera_gazebo&#xff08;无moveit&#xff09; 视频讲解 B站-我要一米八了-抓取不止&#xff01;Ubuntu 18.04下UR5机械臂搭建Gazebo环境&#xff5c;开源分享 运行步骤 1.创建工作空间 catkin_make2.激活环境变量 source devel/setup.bash3.1 rviz下查看模…

Java入门基础学习笔记10——变量

变量的学习路径&#xff1a; 认识变量->为什么要用变量&#xff1f;->变量有啥特点&#xff1f;->变量有啥应用场景&#xff1f; 什么是变量&#xff1f; 变量是用来记住程序要处理的数据的。 变量的定义格式&#xff1a; 数据类型 变量名称 数据&#xff1b; 数…

照片不大于200K怎么改?在线图片处理工具的使用方法

现在使用图片的地方特别多&#xff0c;有时候需要图片压缩到200k&#xff0c;因为上传或传输大文件会受到限制&#xff0c;例如通过电子邮件发送、上传到云存储空间等等。在这种情况下&#xff0c;压缩图片大小可以让图片更容易地传输和分享&#xff0c;并且节省存储空间&#…

从需求到实现的关键

版本封面 内容&#xff1a;产品logo&#xff0c;项目名称&#xff0c;所属公司&#xff0c;产品名称&#xff0c;文档类型&#xff0c;版本号&#xff0c;时间&#xff0c;相关人员&#xff08;最好说明下负责人&#xff09;。 作用&#xff1a; 突出重要信息&#xff0c;将…