可视化卷积核参数对理解卷积神经网络的工作原理、优化模型性能、提高模型泛化能力有一定帮助作用。
下面以resnet18为例,可视化了部分卷积核参数。
import torchvision from matplotlib import pyplot as plt import torchmodel = torchvision.models.resnet18(pretrained=True) #model = torchvision.models.efficientnet_b0(pretrained=True) num = 1 # 遍历模型的每一层 for name, module in model.named_modules():# 判断是否为卷积层if isinstance(module, torch.nn.Conv2d):# 输出卷积层名称和权重print(f"layer {name} : {module.weight.data.shape}")_,_,H,W = module.weight.data.shapeif H >=3 and W >=3:plt.subplot(5,4,num)data = module.weight.data.numpy()plt.imshow(data[0,0,:,:]) #太多了,只显示一个卷积核num+=1plt.show()
结果如下: