查看神经网络中间层特征矩阵及卷积核参数

可视化feature maps以及kernel weights,使用alexnet模型进行演示。

1. 查看中间层特征矩阵

alexnet模型,修改了向前传播

import torch
from torch import nn
from torch.nn import functional as F# 对花图像数据进行分类
class AlexNet(nn.Module):def __init__(self,num_classes=1000,init_weights=False, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.conv1 = nn.Conv2d(3,48,11,4,2)self.pool1 = nn.MaxPool2d(3,2)self.conv2 = nn.Conv2d(48,128,5,padding=2)self.pool2 = nn.MaxPool2d(3,2)self.conv3 = nn.Conv2d(128,192,3,padding=1)self.conv4 = nn.Conv2d(192,192,3,padding=1)self.conv5 = nn.Conv2d(192,128,3,padding=1)self.pool3 = nn.MaxPool2d(3,2)self.fc1 = nn.Linear(128*6*6,2048)self.fc2 = nn.Linear(2048,2048)self.fc3 = nn.Linear(2048,num_classes)# 是否进行初始化# 其实我们并不需要对其进行初始化,因为在pytorch中,对我们对卷积及全连接层,自动使用了凯明初始化方法进行了初始化if init_weights:self._initialize_weights()def forward(self,x):outputs = []  # 定义一个列表,返回我们要查看的哪一层的输出特征矩阵x = self.conv1(x)outputs.append(x)x = self.pool1(F.relu(x,inplace=True))x = self.conv2(x)outputs.append(x)x = self.pool2(F.relu(x,inplace=True))x = self.conv3(x)outputs.append(x)x = F.relu(x,inplace=True)x = F.relu(self.conv4(x),inplace=True)x = self.pool3(F.relu(self.conv5(x),inplace=True))x = x.view(-1,128*6*6)x = F.dropout(x,p=0.5)x = F.relu(self.fc1(x),inplace=True)x = F.dropout(x,p=0.5)x = F.relu(self.fc2(x),inplace=True)x = self.fc3(x)# for name,module in self.named_children():#     x = module(x)#     if name == ["conv1","conv2","conv3"]:#         outputs.append(x)return outputs# 初始化权重def _initialize_weights(self):for m in self.modules():if isinstance(m,nn.Conv2d):# 凯明初始化 - 何凯明nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m,nn.Linear):nn.init.normal_(m.weight, 0,0.01)  # 使用正态分布给权重赋值进行初始化nn.init.constant_(m.bias,0)

拿到向前传播的结果,对特征图进行可视化,这里,我们使用训练好的模型,直接加载模型参数。

注意,要使用与训练时相同的数据预处理。

import matplotlib.pyplot as plt
from torchvision import transforms
import alexnet_model
import torch
from PIL import Image
import numpy as np
from alexnet_model import AlexNet# AlexNet 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)image = Image.open("./images/yjx.jpg")
image = transform(image)
image = image.unsqueeze(0)with torch.no_grad():output = model(image)for feature_map in output:# (N,C,W,H) -> (C,W,H)im = np.squeeze(feature_map.detach().numpy())# (C,W,H) -> (W,H,C)im = np.transpose(im,[1,2,0])plt.figure()# 展示当前层的前12个通道for i in range(12):ax = plt.subplot(3,4,i+1) # i+1: 每个图的索引plt.imshow(im[:,:,i],cmap='gray')plt.show()

结果:

在这里插入图片描述


2. 查看卷积核参数

import matplotlib.pyplot as plt
import numpy as np
import torchfrom AlexNet.model import AlexNet# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)weights_keys = model.state_dict().keys()
for key in weights_keys:if "num_batches_tracked" in key:continueweight_t = model.state_dict()[key].numpy()weight_mean = weight_t.mean()weight_std = weight_t.std(ddof=1)weight_min = weight_t.min()weight_max = weight_t.max()print("mean is {}, std is {}, min is {}, max is {}".format(weight_mean, weight_std, weight_min, weight_max))weight_vec = np.reshape(weight_t,[-1])plt.hist(weight_vec,bins=50)plt.title(key)plt.show()

结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/413336.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

day2:TCP、UDP网络通信模型

思维导图 机械臂实现 #include <head.h> #define SER_POTR 8899 #define SER_IP "192.168.125.223" int main(int argc, const char *argv[]) {//创建套接字int cfdsocket(AF_INET,SOCK_STREAM,0);if(cfd-1){perror("");return -1;}//链接struct so…

书生·浦语大模型--第三节课笔记--基于 InternLM 和 LangChain 搭建你的知识库

文章目录 大模型开发范式RAGLangChain框架&#xff1a;构建向量数据库构建检索问答链优化建议web 部署 实践部分环境配置 大模型开发范式 LLM的局限性&#xff1a;时效性&#xff08;最新知识&#xff09;、专业能力有限&#xff08;垂直领域&#xff09;、定制化成本高&#…

响应式Web开发项目教程(HTML5+CSS3+Bootstrap)第2版 例4-4 label

代码 <!doctype html> <html> <head> <meta charset"utf-8"> <title>label</title> </head><body> 性别: <label for"male">男</label> <input type"radio" name"sex&quo…

vue2踩坑之项目:vue2+element实现前端导出

1.安装插件依赖 npm i --save xlsx0.17.0 file-saver2.0.5 2.单页面引入 前端导出插件 import FileSaver from "file-saver"; import * as XLSX from "xlsx"; //html <el-form-item><el-button type"primary" plain size"mini&quo…

三角形任意一外角大于不相邻的任意一内角

一.代数证明 ∵ 对与△ A C B 中 ∠ c 外接三角形是 ∠ B C D ∵对与△ACB中∠c外接三角形是∠BCD ∵对与△ACB中∠c外接三角形是∠BCD ∴ ∠ B C D π − ∠ C ∴∠BCD\pi-∠C ∴∠BCDπ−∠C ∵ ∠ A ∠ B ∠ C π ∵∠A∠B∠C\pi ∵∠A∠B∠Cπ ∴ ∠ B C D ∠ A ∠…

我在Vscode学OpenCV 图像处理五(直方图处理)

直方图是一种统计图&#xff0c;显示了图像中每个灰度级别&#xff08;或颜色通道&#xff09;的像素数量。通过分析图像的直方图&#xff0c;可以获得关于图像对比度、亮度和颜色分布等方面的重要信息。 直方图处理 一、直方图的意义二、绘制直方图2.1 直接使用Matplotlib.pyp…

关于Access中列的冻结的知识,看这篇就够了

在Microsoft Access中&#xff0c;有一个名为“冻结”的功能&#xff0c;使用户可以在滚动到另一个区域时保持数据表的某个区域可见。 可以使用冻结功能冻结数据表中的表、查询、窗体、视图或存储过程中的一个或多个字段。你冻结的字段将移动到数据表的左侧位置。 如何在Micr…

【云原生系列】容器安全

容器之所以广受欢迎&#xff0c;是因为它能简化应用或服务及其所有依赖项的构建、封装与推进&#xff0c;而且这种简化涵盖整个生命周期&#xff0c;跨越不同的工作流和部署目标。然而&#xff0c;容器安全依然面临着一些挑战。虽然容器有一些固有的安全优势&#xff08;包括增…

滚动菜单ListView

activity_main.xml <include layout"layout/title"/> 引用上章自定义标题栏 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app&qu…

PGSQL主键序列

PostgreSQL和 MySQL数据库还是有一定的区别。 下面了解一下 PGSQL的主键序列。 一、主键 1、系统自带主键序列 在 PostgreSQL 中&#xff0c;GENERATED BY DEFAULT 和 GENERATED ALWAYS 是用于定义自动生成的列&#xff08;Generated Column&#xff09;的选项。一般可作用…

Mybatis 分页插件 PageHelper

今天记录下 Mybatis 分页插件 pageHelper 的使用。 背景 有一个员工表(employee)&#xff0c;现在要使用 pageHelper 插件实现员工的分页查询。 员工表 create table employee (id bigint auto_increment comment 主键primary key,name varchar(32) not …

Flutter编译报错Connection timed out: connect

背景&#xff1a;用Android Studo 创建了Flutter项目&#xff0c;编译运行报错java.net.ConnectException: Connection timed out: connect 我自己的环境&#xff1a; windows11 Android Studio Flutter 截图如下&#xff1a; 将错误日志展开之后&#xff1a; Exception…