news 2026/6/10 15:52:11

深度学习实验——PyTorch实现CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
深度学习实验——PyTorch实现CIFAR10彩色图片识别
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

文章目录

  • 1. 简介
  • 2. 环境
  • 3. 数据集介绍
  • 4. 代码实现
    • 4.1 前期准备
      • 4.1.1 导入库 & GPU设置
      • 4.1.2 数据下载和数据集划分
      • 4.1.3 数据可视化
    • 4.2 模型构建
    • 4.3 模型训练
      • 4.3.1 设置超参数 & 编写训练和测试函数
      • 4.3.2 正式训练
  • 5. 结果可视化

1. 简介

利用Pytorch构建CNN模型以用于识别彩色图片

2. 环境

  • 语言环境:Python 3.12.7
  • 编译器:Jupyter Notebook
  • 深度学习环境:torch—2.8.0 + cu126 / torchvision—0.23.1+cu126

3. 数据集介绍

CIFAR-10数据集,又称加拿大高等研究院数据集是一个常用于训练机器学习和计算机视觉算法的图像集合。它是最广泛使用的机器学习研究数据集之一。CIFAR-10数据集包含60,000张32×32像素的彩色图像,分为10个不同的类别。

4. 代码实现

4.1 前期准备

4.1.1 导入库 & GPU设置

importtorchimporttorch.nnasnnimportmatplotlib.pyplotaspltimporttorchvisionimportnumpyasnpimporttorch.nn.functionalasFfromtorchinfoimportsummaryimportwarningsfromdatetimeimportdatetime warnings.filterwarnings("ignore")plt.rcParams['font.sans-serif']=['SimHei']plt.rcParams['axes.unicode_minus']=Falseplt.rcParams['figure.dpi']=100device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")device

4.1.2 数据下载和数据集划分

先使用torchvision的datasets下载CIFAR10数据集,并划分好训练集与测试集。

train_ds=torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)test_ds=torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)


然后使用DataLoader()加载数据,并设置好基本的batch_size。

batch_size=32train_dl=torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True)test_dl=torch.utils.data.DataLoader(test_ds,batch_size=batch_size)imgs,labels=next(iter(train_dl))imgs.shape

4.1.3 数据可视化

使用transpose()对NumPy数组进行轴变换,将轴的顺序从PyTorch存储图像的(C, H, W)格式转换为(H, W, C)格式,使得数据格式更适合Matplotlib imshow() 函数可视化和处理。

plt.figure(figsize=(20,5))fori,imgsinenumerate(imgs[:20]):npimg=imgs.numpy().transpose((1,2,0))plt.subplot(2,10,i+1)plt.imshow(npimg,cmap=plt.cm.binary)plt.axis('off')

4.2 模型构建

这个模型专门为32×32像素的CIFAR-10图像设计(10个类别),包含3个卷积层和2个全连接层。
首先通过三个卷积层逐级提取图像特征:第一层将RGB三通道转换为64个特征图,第二层保持64个特征图进行深度特征提取,第三层进一步扩展到128个特征图以捕获更复杂的模式,每个卷积层后都使用2×2最大池化层逐步降低空间分辨率。然后网络将三维特征图展平为一维向量,通过两个全连接层进行分类决策:第一层将512维特征压缩到256维并应用ReLU激活函数,第二层输出最终的10个类别分数。

num_classes=10classModel(nn.Module):def__init__(self):super().__init__()self.conv1=nn.Conv2d(3,64,kernel_size=3)self.pool1=nn.MaxPool2d(kernel_size=2)self.conv2=nn.Conv2d(64,64,kernel_size=3)self.pool2=nn.MaxPool2d(kernel_size=2)self.conv3=nn.Conv2d(64,128,kernel_size=3)self.pool3=nn.MaxPool2d(kernel_size=2)self.fc1=nn.Linear(512,256)self.fc2=nn.Linear(256,num_classes)defforward(self,x):x=self.pool1(F.relu(self.conv1(x)))x=self.pool2(F.relu(self.conv2(x)))x=self.pool3(F.relu(self.conv3(x)))x=torch.flatten(x,start_dim=1)x=F.relu(self.fc1(x))x=self.fc2(x)returnx model=Model().to(device)summary(model)

4.3 模型训练

4.3.1 设置超参数 & 编写训练和测试函数

训练函数train在每个批次中执行前向传播计算预测值,使用交叉熵损失评估误差,通过反向传播计算梯度并利用SGD优化器更新模型参数,同时统计训练准确率和损失;测试函数test则在禁用梯度计算的模式下进行前向传播,评估模型在验证集上的表现而不更新权重,最终返回模型在测试数据上的平均准确率和损失,两个函数共同构成了一个典型的有监督深度学习训练评估循环。

loss_fn=nn.CrossEntropyLoss()learn_rate=1e-2opt=torch.optim.SGD(model.parameters(),lr=learn_rate)deftrain(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)num_batches=len(dataloader)train_loss,train_acc=0,0forX,yindataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()train_loss+=loss.item()train_acc/=size train_loss/=num_batchesreturntrain_acc,train_lossdeftest(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)test_loss,test_acc=0,0withtorch.no_grad():forimgs,targetindataloader:imgs,target=imgs.to(device),target.to(device)target_pred=model(imgs)loss=loss_fn(target_pred,target)test_loss+=loss.item()test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()test_acc/=size test_loss/=num_batchesreturntest_acc,test_loss

4.3.2 正式训练

epochs=10train_loss=[]train_acc=[]test_loss=[]test_acc=[]forepochinrange(epochs):model.train()epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template=('Epoch:{:2d}, train_acc:{:.1f}%, train_loss:{:.3f}, test_acc:{:.1f}%, test_loss:{:.3f}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss))print('Done')

5. 结果可视化

current_time=datetime.now()epochs_range=range(epochs)plt.figure(figsize=(12,3))plt.subplot(1,2,1)plt.plot(epochs_range,train_acc,label='Training Accuracy')plt.plot(epochs_range,test_acc,label='Test Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.xlabel(current_time)plt.subplot(1,2,2)plt.plot(epochs_range,train_loss,label='Training Loss')plt.plot(epochs_range,test_loss,label='Test Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/9 14:43:21

【医疗数据PHP备份终极指南】:9大策略确保数据零丢失

第一章:医疗数据备份的核心挑战与PHP角色在现代医疗信息系统中,数据的完整性、安全性和可恢复性至关重要。患者病历、诊断记录和治疗方案等敏感信息一旦丢失或泄露,可能造成不可挽回的后果。因此,构建高效可靠的医疗数据备份机制成…

作者头像 李华
网站建设 2026/6/8 20:19:09

Cangaroo开源CAN总线分析工具终极指南

Cangaroo开源CAN总线分析工具终极指南 【免费下载链接】cangaroo 项目地址: https://gitcode.com/gh_mirrors/ca/cangaroo Cangaroo是一款功能强大的开源CAN总线分析软件,专为汽车电子、工业控制和嵌入式系统开发设计。作为专业的CAN总线调试工具&#xff0…

作者头像 李华
网站建设 2026/6/7 3:02:33

GraphQL的PHP字段别名使用全解析(性能优化与编码规范)

第一章:GraphQL的PHP字段别名概述在构建现代Web API时,GraphQL因其灵活的数据查询能力而广受欢迎。当使用PHP实现GraphQL服务时,字段别名(Field Aliasing)是一项关键功能,允许客户端在查询中为返回的字段指…

作者头像 李华
网站建设 2026/6/10 14:39:57

沪上装修公司前十名避坑指南,2025年家悦可可装饰帮你筛靠谱名单

为什么“沪上装修公司前十名”成了搜索热词?在上海,装修一套房子动辄几十万,工期动辄三个月,谁都不想“踩坑”。于是,很多业主在动工前都会把“沪上装修公司前十名”敲进搜索框,希望用一份“榜单”快速锁定…

作者头像 李华
网站建设 2026/6/10 4:51:23

开发者必看:如何通过LLama-Factory在Ollama中部署自定义微调模型

如何通过 LLama-Factory 在 Ollama 中部署自定义微调模型 在大语言模型(LLM)日益渗透各行各业的今天,越来越多开发者不再满足于“通用对话”能力。他们真正关心的是:如何让一个像 Llama-3 这样的开源模型,变成懂金融、…

作者头像 李华