- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
一、前期准备
importtorchimporttorch.nnasnnimporttorch.nn.functionalasFimporttorch.optimasoptimfromtorchvisionimportdatasets,transformsfromtorch.utils.dataimportDataset,DataLoaderimportosimportglobfromPILimportImageimportnumpyasnpimportmatplotlib.pyplotaspltimportwarnings warnings.filterwarnings("ignore")device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")二、字典定义
CHARS=['京','沪','津','渝','冀','晋','蒙','辽','吉','黑','苏','浙','皖','闽','赣','鲁','豫','鄂','湘','粤','桂','琼','川','贵','云','藏','陕','甘','青','宁','新','0','1','2','3','4','5','6','7','8','9','A','B','C','D','E','F','G','H','J','K','L','M','N','P','Q','R','S','T','U','V','W','X','Y','Z']char2code={c:ifori,cinenumerate(CHARS)}CLASS_NUM=len(CHARS)三、自定义数据集
classMyDataset(Dataset):def__init__(self,img_paths,transform=None):self.img_paths=img_paths self.transform=transformdef__len__(self):returnlen(self.img_paths)def__getitem__(self,idx):path=self.img_paths[idx]try:image=Image.open(path).convert('RGB')except:returntorch.zeros(3,224,224),torch.zeros(7,dtype=torch.long)ifself.transform:image=self.transform(image)# 解析标签filename=os.path.basename(path).split('.')[0]label=[]forcharinfilename:ifcharinchar2code:label.append(char2code[char])# 强制确保标签是 7 位,防止报错label=label[:7]iflen(label)<7:label=label+[0]*(7-len(label))returnimage,torch.tensor(label,dtype=torch.long)# 准备数据路径data_dir='./015_licence_plate'all_paths=glob.glob(os.path.join(data_dir,'*.jpg'))+glob.glob(os.path.join(data_dir,'*.png'))train_transforms=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])total_data=MyDataset(all_paths,transform=train_transforms)train_size=int(0.8*len(total_data))test_size=len(total_data)-train_size train_ds,test_ds=torch.utils.data.random_split(total_data,[train_size,test_size])train_loader=DataLoader(train_ds,batch_size=32,shuffle=True)test_loader=DataLoader(test_ds,batch_size=32,shuffle=False)四、搭建模型
classNet(nn.Module):def__init__(self):super(Net,self).__init__()self.conv1=nn.Conv2d(3,16,3,1,1);self.pool=nn.MaxPool2d(2,2)self.conv2=nn.Conv2d(16,32,3,1,1)self.conv3=nn.Conv2d(32,64,3,1,1)self.fc1=nn.Linear(64*28*28,128)self.dropout=nn.Dropout(0.5)# 输出 7 * 65self.fc2=nn.Linear(128,7*CLASS_NUM)defforward(self,x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=self.pool(F.relu(self.conv3(x)))x=x.view(-1,64*28*28)x=F.relu(self.fc1(x))x=self.dropout(x)x=self.fc2(x)# 调整形状为 [Batch, 7, 65]x=x.view(-1,7,CLASS_NUM)returnx model=Net().to(device)criterion=nn.CrossEntropyLoss()optimizer=optim.Adam(model.parameters(),lr=0.001)五、训练与测试
deftrain(model,device,train_loader,optimizer,epoch):model.train()total_loss=0fordata,targetintrain_loader:data,target=data.to(device),target.to(device)optimizer.zero_grad()output=model(data)# 计算Loss需转置loss=criterion(output.transpose(1,2),target)loss.backward()optimizer.step()total_loss+=loss.item()returntotal_loss/len(train_loader)deftest(model,device,test_loader):model.eval()correct=0total=0withtorch.no_grad():fordata,targetintest_loader:data,target=data.to(device),target.to(device)output=model(data)# ACC统计更新部分# 1. 获取预测结果: [batch, 7, 65] -> [batch, 7]predicted_indices=output.argmax(dim=2)# 2. 计算准确率# 只有当一张图片的 7 个字符全部预测正确,才算这张图片对了 (.all(dim=1))# match_matrix = (predicted_indices == target)# correct_plates = match_matrix.all(dim=1)correct+=(predicted_indices==target).all(dim=1).sum().item()total+=target.size(0)acc=correct/totalprint(f"Test Accuracy:{acc:.2%}")returnacc六、结果可视化
if__name__=='__main__':epochs=5#train_loss_list=[]test_acc_list=[]forepochinrange(1,epochs+1):print(f"Epoch{epoch}Running...")loss=train(model,device,train_loader,optimizer,epoch)acc=test(model,device,test_loader)train_loss_list.append(loss)test_acc_list.append(acc)# 画图plt.plot(train_loss_list,label='Train Loss')plt.plot(test_acc_list,label='Test Accuracy')plt.legend()plt.show()七、总结
7.1 任务本质:多标签分类 (Multi-label Classification)
与之前的天气识别(单标签多分类)不同,车牌识别是一个典型的多标签分类任务。
- 输入:一张车牌图片。
- 输出:7 个独立的字符(省份+字母+5位数字/字母)。
- 维度变化:模型的输出不再是
[Batch, Class_Num],而是变成了[Batch, 7, Class_Num]。这意味着我们要对 7 个位置分别进行 65 种字符的预测。
7.2 核心难点:准确率 (Accuracy) 的统计逻辑
这是本周任务的重难点。在多标签任务中,准确率的定义非常严格:
- 逻辑:必须是一张车牌上的7 个字符全部预测正确,该样本才算正确。只要错一个字(比如把 ‘8’ 认成 ‘B’),整张车牌就算识别失败。
- 代码实现:
# 1. 获取最大概率索引 [batch, 7, 65] -> [batch, 7]predicted_indices=output.argmax(dim=2)# 2. 维度比对 (.all(dim=1) 是关键)# 只有当一行的 7 个 bool 值全为 True,结果才为 Truecorrect+=(predicted_indices==target).all(dim=1).sum().item()
7.3 实战踩坑:数据清洗的重要性
在训练过程中,我遇到了RuntimeError: Expected target size [32, 7], got [32, 17]的报错。
- 原因:数据集中存在文件名异常的图片(文件名长度不符合标准的 7 位车牌格式),导致标签长度不一致,无法组成 Batch。
- 解决:在
__getitem__数据读取阶段增加了长度校验与截断逻辑,强制保证输出标签长度为 7。