news 2026/4/23 12:40:13

Day 39 模型可视化与推理

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 39 模型可视化与推理

@浙大疏锦行

一、nn.Module核心自带方法

nn.Module封装了模型的核心逻辑,以下是高频使用的自带方法,按功能分类:

1. 模型状态控制(训练 / 评估模式)

方法作用
model.train()切换为训练模式:启用 Dropout、BatchNorm 等层的训练行为(默认模式)
model.eval()切换为评估模式:关闭 Dropout、固定 BatchNorm 均值 / 方差,用于推理 / 验证
model.training属性,返回布尔值:True= 训练模式,False= 评估模式

示例

import torch import torch.nn as nn class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 16, 3) self.dropout = nn.Dropout(0.5) # 训练时随机失活,评估时关闭 def forward(self, x): x = self.conv(x) x = self.dropout(x) return x model = SimpleCNN() print(model.training) # True(默认训练模式) model.eval() print(model.training) # False(评估模式,dropout失效) model.train() print(model.training) # True(切回训练模式)

2. 设备迁移(CPU/GPU)

方法作用
model.to(device)将模型所有参数 / 缓冲区移到指定设备(cuda/cpu/mps),返回模型实例
model.cuda()快捷方式:移到默认 GPU(等价于model.to('cuda')
model.cpu()快捷方式:移到 CPU(等价于model.to('cpu')

示例

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # 模型移到GPU/CPU # 验证设备 print(next(model.parameters()).device) # 输出:cuda:0 或 cpu

3. 参数管理(查看 / 遍历参数)

方法作用
model.parameters()返回生成器:包含所有可训练参数(nn.Parameter类型)
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.state_dict()返回字典:{参数名:参数值},用于保存模型参数
model.load_state_dict()加载参数字典,用于恢复模型

示例

# 查看所有参数名称和形状 for name, param in model.named_parameters(): print(f"参数名:{name},形状:{param.shape},设备:{param.device}") # 统计总参数量(手动实现,无第三方库时用) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"总参数:{total_params},可训练参数:{trainable_params}")

4. 结构遍历(查看模型层)

方法作用
model.children()返回生成器:仅包含直接子层(如 Sequential 内的第一层),不递归
model.named_children()返回生成器:(层名,子层),仅直接子层
model.modules()返回生成器:递归包含所有层(包括嵌套层)
model.named_modules()返回生成器:(层名,层),递归所有层

示例

# 定义嵌套模型 class NestedModel(nn.Module): def __init__(self): super().__init__() self.block1 = nn.Sequential( nn.Conv2d(3, 16, 3), nn.ReLU() ) self.block2 = nn.Linear(16*30*30, 10) model = NestedModel() # children():仅直接子层(block1、block2) print("=== children() ===") for name, layer in model.named_children(): print(name, layer) # modules():递归所有层(包括Sequential内的Conv2d、ReLU) print("\n=== modules() ===") for name, layer in model.named_modules(): print(name, layer)

5. 前向传播与梯度

方法作用
model.forward(x)手动调用前向传播(不推荐),建议直接model(x)(调用__call__
model(x)等价于model.__call__(x),自动执行 forward + 钩子(hook)逻辑
model.zero_grad()清空所有参数的梯度(训练时反向传播前必须调用)

示例

x = torch.randn(1, 3, 32, 32).to(device) output = model(x) # 推荐:调用__call__,等价于model.forward(x) + 钩子 model.zero_grad() # 清空梯度 output.sum().backward() # 反向传播计算梯度

二、torchsummary库的summary方法

torchsummary是早期轻量库,核心功能是快速打印模型层结构、输出形状、总参数量,仅支持单输入模型,对嵌套模型 / 多输入支持差,维护较少。

1. 安装与基本用法

pip install torchsummary
from torchsummary import summary # 定义模型(输入:3通道32×32图像) class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc1 = nn.Linear(32 * 8 * 8, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.pool(nn.functional.relu(self.conv1(x))) x = self.pool(nn.functional.relu(self.conv2(x))) x = x.view(-1, 32 * 8 * 8) x = nn.functional.relu(self.fc1(x)) x = self.fc2(x) return x # 设备配置 + 模型初始化 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 调用summary:参数(模型,输入形状(通道,高,宽),batch_size可选) summary(model, input_size=(3, 32, 32), batch_size=1)

2. 输出解读

---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [1, 16, 32, 32] 448 MaxPool2d-2 [1, 16, 16, 16] 0 Conv2d-3 [1, 32, 16, 16] 4,640 MaxPool2d-4 [1, 32, 8, 8] 0 Linear-5 [1, 128] 262,272 Linear-6 [1, 10] 1,290 ================================================================ Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ----------------------------------------------------------------

3. 优缺点

优点缺点
极简、无多余依赖仅支持单输入模型
输出简洁、易理解对嵌套模型 / 多分支模型支持差
快速查看参数量 / 形状无批次维度、无内存占用细分
支持 GPU/CPU维护停滞,仅兼容 PyTorch 旧版本

三、torchinfo库的summary方法(推荐)

torchinfotorchsummary的升级版(原torchsummaryX),解决了多输入、嵌套模型、维度展示不清晰的问题,功能更全面,是当前 PyTorch 模型可视化的首选。

1. 安装与基本用法

pip install torchinfo
from torchinfo import summary # 复用上面的SimpleCNN模型 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 核心参数:model, input_size, batch_dim, device, col_width等 summary( model, input_size=(1, 3, 32, 32), # (batch_size, 通道, 高, 宽) batch_dim=0, # 批次维度的位置(默认0) device=device, # 模型设备 col_width=20, # 列宽 col_names=["input_size", "output_size", "num_params", "trainable"], # 显示列 row_settings=["var_names"] # 显示层变量名 )

2. 输出解读

========================================================================================== Layer (type (var_name)) Input Shape Output Shape Param # Trainable ========================================================================================== SimpleCNN (SimpleCNN) [1, 3, 32, 32] [1, 10] -- -- ├─Conv2d (conv1) [1, 3, 32, 32] [1, 16, 32, 32] 448 True ├─MaxPool2d (pool) [1, 16, 32, 32] [1, 16, 16, 16] -- -- ├─Conv2d (conv2) [1, 16, 16, 16] [1, 32, 16, 16] 4,640 True ├─MaxPool2d (pool) [1, 32, 16, 16] [1, 32, 8, 8] -- -- ├─Linear (fc1) [1, 2048] [1, 128] 262,272 True ├─Linear (fc2) [1, 128] [1, 10] 1,290 True ========================================================================================== Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 Total mult-adds (M): 2.15 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ==========================================================================================

四、推理的写法:评估模式

def evaluate_classification(model, dataloader, device): """ 分类模型评估:计算准确率、F1-score(宏平均)、混淆矩阵 """ # 1. 切换到评估模式(必须!) model.eval() # 2. 初始化指标容器 all_preds = [] all_labels = [] # 3. 关闭梯度计算(加速+省显存) with torch.no_grad(): for batch_idx, (x, y) in enumerate(dataloader): # 数据移到设备 x = x.to(device, dtype=torch.float32) y = y.to(device, dtype=torch.long) # 4. 推理(前向传播) outputs = model(x) # 输出:(batch_size, num_classes) preds = torch.argmax(outputs, dim=1) # 取概率最大的类别 # 5. 收集预测结果和真实标签(转回CPU便于计算指标) all_preds.extend(preds.cpu().numpy()) all_labels.extend(y.cpu().numpy()) # 可选:打印进度 if (batch_idx + 1) % 10 == 0: print(f"Batch [{batch_idx+1}/{len(dataloader)}] 完成") # 6. 计算评估指标 accuracy = accuracy_score(all_labels, all_preds) f1_macro = f1_score(all_labels, all_preds, average="macro") # 宏平均F1(适合类别均衡) f1_weighted = f1_score(all_labels, all_preds, average="weighted") # 加权F1(适合类别不均衡) # 7. 打印结果 print("="*50) print(f"分类模型评估结果:") print(f"准确率 (Accuracy): {accuracy:.4f}") print(f"宏平均F1-score: {f1_macro:.4f}") print(f"加权F1-score: {f1_weighted:.4f}") print("="*50) return { "accuracy": accuracy, "f1_macro": f1_macro, "f1_weighted": f1_weighted, "preds": all_preds, "labels": all_labels } # 执行评估 eval_results = evaluate_classification(model, test_loader, device)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/23 11:31:40

Ubuntu编译自定义immortalwrt固件与软件编译

1 前言 istoreos中有许多可安装的软件,但如果自己需要制作一个特定的固件或者编译开源的源码时就需要编译来生成所需软件 2 所需工具 1.Ubuntu系统2.VMware虚拟机3.相应版本的sdk开发包4.ssh连接工具5.git(可选) 3 软件编译 3.1 openwrt…

作者头像 李华
网站建设 2026/4/23 12:29:26

【课程设计/毕业设计】基于springboot果蔬种植销售一体化服务平台的设计与实现果蔬信息、果蔬入库【附源码、数据库、万字文档】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/3/13 11:40:56

【课程设计/毕业设计】基于springboot的非物质文化遗产系统基于springboot非物质文化遗产数字化传承【附源码、数据库、万字文档】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/4/18 11:25:20

校园招聘会组织不再难,统筹安排让就业季更顺畅

✅作者简介:合肥自友科技 📌核心产品:智慧校园平台(包括教工管理、学工管理、教务管理、考务管理、后勤管理、德育管理、资产管理、公寓管理、实习管理、就业管理、离校管理、科研平台、档案管理、学生平台等26个子平台) 。公司所有人员均有多…

作者头像 李华
网站建设 2026/4/23 12:11:30

基于单片机的室内空气质量检测系统(有完整资料)

资料查找方式:特纳斯电子(电子校园网):搜索下面编号即可编号:T4362305C设计简介:本设计是基于STC89C52的室内空气质量监测系统,主要实现以下功能:可通过气体检测传感器监测当前空气质…

作者头像 李华
网站建设 2026/4/23 12:11:27

探索微网新能源经济消纳的共享储能优化配置之路

考虑微网新能源经济消纳的共享储能优化配置 共享储能是可再生能源实现经济消纳的解决方案之一,在适度的投资规模下,应尽力实现储能电站容量功率与消纳目标相匹配。 对此,提出了考虑新能源消纳的共享储能电站容量功率配置方法,针对…

作者头像 李华