创建一个具有三级嵌套的模型,结构如图:
import torch
import torch.nn as nn# 定义子子模块
class SubSubModule(nn.Module):def __init__(self):super(SubSubModule, self).__init__()self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)def forward(self, x):return self.conv(x)# 定义子模块
class SubModule(nn.Module):def __init__(self):super(SubModule, self).__init__()self.sub_sub_module = SubSubModule() # 实例化子子模块self.pool = nn.MaxPool2d(2)def forward(self, x):x = self.sub_sub_module(x) # 使用子子模块x = torch.relu(x)x = self.pool(x)return x# 定义主模块
class MainModule(nn.Module):def __init__(self):super(MainModule, self).__init__()self.sub_module = SubModule() # 实例化子模块self.fc = nn.Linear(3 * 16 * 16, 10) # 假设输入图像大小为 32x32def forward(self, x):x = self.sub_module(x) # 使用子模块x = x.view(x.size(0), -1) # 展平特征图x = self.fc(x)return x# 实例化主模块
model = MainModule()# 打印模型结构
print(model)
使用print直接打印
直接使用print函数打印,会以整个模型为单位打印
# 实例化主模块
model = MainModule()# 打印模型结构
print(model)
使用named_children()函数打印模型的子模块
named_children()返回的是仅是模型的子模块,字字模块并不返回,也就是次级模块
#打印模型的子模块
for name, module in model.named_children():print(name, module)
使用named_modules函数打印模型的子模块
named_children()会遍历模型中的所有模块,从主模块到子模块到子子模块到子子...子模块,每一个模块都会打印出来
#打印模型的所有模块
for name, module in model.named_modules():print(name, module)
使用named_parameters()函数打印模型的可学习参数
#打印模型的可学习参数
for name, param in model.named_parameters():print(name, param.size())