news 2026/4/23 12:31:28

Day 44 预训练模型与迁移学习

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 44 预训练模型与迁移学习

在深度学习领域,从零开始训练一个高性能模型通常需要海量数据(如 ImageNet 的 120 万张图片)和昂贵的计算资源。对于大多数实际应用场景,我们更倾向于使用迁移学习 (Transfer Learning)

本篇笔记将结合 Day 44 的代码,深入剖析如何利用预训练的ResNet18模型,在CIFAR-10数据集上实现 86%+ 的高准确率。我们将重点拆解代码实现的每一个细节。


一、 数据准备:为模型提供高质量“燃料”

数据增强是提升模型泛化能力的关键。在迁移学习中,由于模型参数量较大(ResNet18 约 1100 万参数),而在小数据集(CIFAR-10 仅 5 万张训练图)上容易过拟合,因此强力的数据增强尤为重要。

1. 训练集增强策略 (代码详解)

train_transform = transforms.Compose([ # 1. 随机裁剪 (RandomCrop) # 先在图像四周填充 4 个像素的 0 (padding=4),图像变大 (40x40) # 然后随机裁剪出 32x32 的区域。 # 作用:让模型学习到物体在不同位置的特征,模拟物体平移。 transforms.RandomCrop(32, padding=4), # 2. 随机水平翻转 (RandomHorizontalFlip) # 以 50% 的概率水平翻转图像。 # 作用:模拟物体朝向的变化(如车头朝左或朝右),增加数据多样性。 transforms.RandomHorizontalFlip(), # 3. 颜色抖动 (ColorJitter) # 随机调整亮度(brightness)、对比度(contrast)、饱和度(saturation) 和色相(hue)。 # 作用:模拟不同光照条件下的物体,让模型对颜色变化不敏感。 transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 4. 随机旋转 (RandomRotation) # 在 -15度 到 +15度 之间随机旋转。 # 作用:模拟拍摄角度的微小偏差。 transforms.RandomRotation(15), # 5. 转为 Tensor # 将 PIL Image (0-255) 转换为 Tensor (0.0-1.0),并调整维度顺序 (HWC -> CHW)。 transforms.ToTensor(), # 6. 标准化 (Normalize) # 使用 CIFAR-10 数据集的均值和标准差进行归一化:(x - mean) / std # 作用:加速收敛,使数据分布更符合模型假设。 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ])

2. 测试集处理

测试集只需进行必要的格式转换和标准化,严禁使用随机增强操作,以确保评估结果的稳定性与真实性。

test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ])

二、 模型构建:ResNet18 的“换头”手术

我们使用torchvision.models提供的 ResNet18。由于预训练模型是在 ImageNet(1000 类)上训练的,我们需要修改其输出层(Head)以适配 CIFAR-10(10 类)。

代码实现与解析

from torchvision.models import resnet18 import torch.nn as nn def create_resnet18(pretrained=True, num_classes=10): # 1. 加载预训练模型 # pretrained=True: 自动下载并加载在 ImageNet 上训练好的权重。 # 这些权重包含了提取通用视觉特征(边缘、纹理、形状)的能力。 model = resnet18(pretrained=pretrained) # 2. 修改全连接层 (Head) # model.fc 是 ResNet 的最后一层全连接层。 # in_features: 获取原全连接层的输入维度(ResNet18 为 512)。 in_features = model.fc.in_features # 3. 替换为新的全连接层 # 新层初始化时权重是随机的,输出维度设为 num_classes (10)。 # 注意:这一层没有预训练权重,需要从头训练。 model.fc = nn.Linear(in_features, num_classes) # 4. 转移到 GPU return model.to(device)

三、 训练策略:冻结与解冻 (Freeze & Unfreeze)

这是迁移学习中最核心的技巧。为了防止新初始化的全连接层(随机权重)产生的巨大梯度破坏预训练好的骨干网络(Backbone),我们通常采用分阶段训练

1. 冻结控制函数

def freeze_model(model, freeze=True): """ 控制模型参数的冻结与解冻 freeze=True: 冻结卷积层,只训练全连接层 freeze=False: 解冻所有层,进行全局微调 """ # 遍历模型的所有参数(权重和偏置) for name, param in model.named_parameters(): # 我们始终要训练全连接层 (fc),所以只冻结非 fc 层 if 'fc' not in name: # requires_grad=False 表示该参数不计算梯度,也不会被优化器更新 param.requires_grad = not freeze # (可选) 打印当前冻结状态,方便调试 frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) print(f"当前冻结参数量: {frozen_params}") return model

2. 分阶段训练逻辑

我们在train_with_freeze_schedule函数中实现了这一逻辑:

  • 阶段一 (Epoch 0 ~ freeze_epochs-1)
    • 调用freeze_model(model, freeze=True)
    • 此时,只有全连接层在更新。骨干网络充当一个固定的特征提取器。
    • 目的:让全连接层的权重快速收敛到合理范围。
  • 阶段二 (Epoch >= freeze_epochs)
    • 调用freeze_model(model, freeze=False)
    • 解冻所有参数
    • 降低学习率:通常将学习率降低 10 倍(如从 1e-3 降到 1e-4),以免破坏预训练的特征。
    • 目的:让整个网络针对 CIFAR-10 的特征进行微调 (Fine-tuning),进一步提升性能。
# 伪代码演示阶段切换逻辑 if epoch == freeze_epochs: print(">>> 解冻所有层,开始全局微调!") model = freeze_model(model, freeze=False) # 降低学习率,精细调整 optimizer.param_groups[0]['lr'] = 1e-4

四、 完整训练流程详解

以下是整合了上述所有模块的训练循环核心代码,每一行都有详细注释:

def train_one_epoch(model, loader, criterion, optimizer, device): model.train() # 切换到训练模式(启用 Dropout 和 BatchNorm 更新) running_loss = 0.0 correct = 0 total = 0 for batch_idx, (data, target) in enumerate(loader): # 1. 数据迁移到 GPU data, target = data.to(device), target.to(device) # 2. 梯度清零 (标准步骤) optimizer.zero_grad() # 3. 前向传播 output = model(data) # 4. 计算损失 loss = criterion(output, target) # 5. 反向传播 # 如果是冻结阶段,只有 fc 层的参数会有梯度 loss.backward() # 6. 参数更新 optimizer.step() # --- 统计指标 --- running_loss += loss.item() _, predicted = output.max(1) # 获取预测类别 total += target.size(0) correct += predicted.eq(target).sum().item() return running_loss / len(loader), 100. * correct / total

五、 实验现象与经验总结

在运行 Day 44 的代码时,你会观察到几个有趣的现象:

  1. 起步即巅峰
    • 即使在冻结阶段(前 5 个 Epoch),准确率也能迅速达到 70% 左右。这归功于 ResNet 强大的特征提取能力。
  2. 解冻后的飞跃
    • 第 6 个 Epoch(解冻瞬间),准确率通常会有一个明显的提升,因为卷积层开始适应新数据集的特征分布(如 CIFAR-10 的低分辨率)。
  3. 训练集 vs 测试集准确率倒挂
    • 现象:训练准确率 (Training Acc) 往往低于测试准确率 (Test Acc)。
    • 原因:训练集使用了强力数据增强(裁剪、旋转、变色),模型看到的是“变态”难度的图片;而测试集是“标准”图片。模型就像是在负重训练(训练集),考试(测试集)时自然觉得轻松。
  4. 最终性能
    • 经过 40 个 Epoch 的训练,ResNet18 在 CIFAR-10 上通常能达到86% - 90%的准确率,远超普通 CNN 的表现。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/21 19:48:36

Langchain-Chatchat关系图谱构建:揭示知识点之间的关联网络

Langchain-Chatchat关系图谱构建:揭示知识点之间的关联网络 在企业知识管理日益复杂的今天,一个常见却棘手的问题是:员工明明拥有数百份制度文档、操作手册和项目记录,但在面对“跨部门报销流程”或“绩效考核与晋升机制的联动规…

作者头像 李华
网站建设 2026/4/21 4:32:42

36、玩转媒体收藏:Windows Media Player 使用全攻略

玩转媒体收藏:Windows Media Player 使用全攻略 1. 管理媒体收藏 当你想要管理媒体收藏时,可点击媒体播放器功能任务栏中的“媒体库”按钮。此时屏幕会分成两个窗格,左侧是分类,右侧是单个歌曲。右侧窗格中显示的歌曲取决于你点击的分类。例如,点击“所有音乐”,右侧窗…

作者头像 李华
网站建设 2026/4/16 12:41:35

37、用Windows Movie Maker 2制作家庭电影

用Windows Movie Maker 2制作家庭电影 1. Windows Movie Maker简介 每一部电影或电视剧都是由一系列场景组织成的故事。Windows Movie Maker 是一款能让你以类似方式创建专业级视频的程序,你可以将家庭电影中的精彩场景,甚至从网络下载的视频片段组合起来。你制作的电影可以…

作者头像 李华
网站建设 2026/4/10 15:41:56

Gatus配置终极指南:从零开始构建企业级监控系统

Gatus配置终极指南:从零开始构建企业级监控系统 【免费下载链接】gatus ⛑ Automated developer-oriented status page 项目地址: https://gitcode.com/GitHub_Trending/ga/gatus 还在为服务频繁宕机而头疼?想找一个既简单又强大的监控工具&#…

作者头像 李华
网站建设 2026/4/22 10:58:08

Langchain-Chatchat方言识别尝试:粤语、四川话能否听懂?

Langchain-Chatchat方言识别尝试:粤语、四川话能否听懂? 在企业智能问答系统日益普及的今天,一个看似简单却极具现实挑战的问题浮出水面:当员工用一口地道的四川话问“报销流程咋个搞?”或用粤语嘀咕“我哋份合同有冇…

作者头像 李华
网站建设 2026/4/21 15:24:50

豆包手机正在重新定义规则

字节跳动的“豆包手机”低调上线。 首批仅几万台,官方定性为“技术预览版”,看起来像是一次小规模的硬件尝试。然而,剥开它“中兴代工、3499元售价”的普通外壳,你会发现这其实是一枚投向移动互联网深水区的核弹。 它不仅仅是一…

作者头像 李华