news 2026/4/23 22:02:28

PyTorch实现二分类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
PyTorch实现二分类

二分类问题的实现方法,核心是把线性回归的 “连续值输出” 改成 “0/1 类别概率输出”。最基础常用的二分类模型基于逻辑回归(Logistic Regression)。

线性回归实现方式:PyTorch实现线性回归-CSDN博客

二分类本质上也是一种回归(Regression)问题,在上述线性回归的基础上修改就可以实现。下面是线性回归与二分类任务的差异:

环节线性回归(回归任务)二分类(分类任务)
输出目标连续数值(如 y=2x 的预测值)0/1 类别概率(0≤P≤1)
核心激活函数无(直接输出线性结果)Sigmoid(把线性输出映射到 0-1)
损失函数MSELoss(均方误差)BCELoss(二元交叉熵损失)
预测逻辑直接取输出值概率 > 0.5 归为 1 类,≤0.5 归为 0 类

1. 准备数据集(Prepare Dataset)

对比线性回归,数据格式还是 Tensor,但标签y_data是0/1 离散值,这是分类任务的核心特征。

import torch # 构造数据集:特征x(学分),标签y(0=不及格,1=及格) # 样本:[1.0], [2.0], [3.0] → 标签:0,0,1 x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[0], [0], [1]])

2. 设计模型(Design model)

class LogisticRegressionModel(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 核心:线性输出 + Sigmoid激活 → 映射到0-1概率 y_pred = torch.sigmoid(self.linear(x)) return y_pred model = LogisticRegressionModel()
  • Sigmoid函数公式:
  • forward中增加torch.sigmoid()把线性层的 “任意实数输出” 压缩到0~1 区间,这个值就是 “样本属于 1 类的概率”。

3. 构造损失函数(Construct Loss)

criterion = nn.BCELoss(reduction='sum')
  • BCELoss:二元交叉熵损失,是二分类的专用损失。

关于二元交叉熵损失函数的介绍,参考文章PyTorch_conda-CSDN博客中《nn.BCELoss(二元交叉熵损失)》一节。

4. 构造优化器(Construct Optimizer)

optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

5. 训练循环(Training Cycle)

for epoch in range(1000): y_pred = model(x_data) # 前向传播(计算预测值) loss = criterion(y_pred, y_data) # 计算损失值 print(epoch, loss.item()) optimizer.zero_grad() # 梯度清零 loss.backward() # 反向传播计算梯度 optimizer.step() # 更新参数

完整实例

import torch import torch.nn as nn x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[0], [0], [1]]) class LogisticRegressionModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(1, 1) def forward(self, x): return torch.sigmoid(self.linear(x)) model = LogisticRegressionModel() criterion = nn.BCELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for epoch in range(10000): y_pred = model(x_data) loss = criterion(y_pred, y_data) if epoch % 1000 == 0: print(f"Epoch: {epoch}, Loss: {loss.item():.4f}") optimizer.zero_grad() loss.backward() optimizer.step() x_test = torch.Tensor([[4.0]]) y_test_pred = model(x_test) print("\n测试结果:") print('y_pred = ', y_test_pred.data) # 查看模型参数 print(f"\n模型权重:{model.linear.weight.item():.6f}") print(f"模型偏置:{model.linear.bias.item():.6f}")
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/23 9:58:13

《边缘受限设备API客户端轻量化与功能适配实战指南》

不同IoT终端的资源禀赋与业务诉求存在天壤之别,环境感知类终端仅需完成基础数据上报的核心交互,工业现场传感终端则需兼顾指令接收与状态回传,楼宇监测终端还需适配间歇性的断网续传需求,这就决定轻量化设计绝不能采用一刀切的模式,必须基于终端硬件参数台账与业务场景图谱…

作者头像 李华
网站建设 2026/4/23 9:58:34

基于SpringBoot和Vue的实验室预约系统设计与实现

文章目录 详细视频演示项目介绍技术介绍功能介绍核心代码系统效果图源码获取 详细视频演示 文章底部名片,获取项目的完整演示视频,免费解答技术疑问 项目介绍 基于Spring Boot的实验室预约系统采用前后端分离架构,后端以Spring Boot为核心框…

作者头像 李华
网站建设 2026/4/23 11:18:39

从企业能耗集采到区域碳管理-智慧能源平台开发指南

先上干货! 墙内仓库地址(码云):https://gitee.com/guangdong122/energy-management 已同步更新到 github 仓库 温馨提示:文末有资源获取方式~ 能源系统|能源系统源码|企业能源系统|企业能源系统源码|能源监测系统 一…

作者头像 李华
网站建设 2026/4/23 12:12:19

机器学习面试问题及答案

摘要:本文整理了50个机器学习面试问题及答案,涵盖基础概念到高级应用。基础部分包括机器学习定义、监督/无监督学习、过拟合/欠拟合及解决方法、正则化、特征工程等核心概念。中级部分涉及线性回归、逻辑回归、决策树、随机森林等常用算法原理。高级部分…

作者头像 李华
网站建设 2026/4/23 12:14:25

二维钻孔封孔效果模拟案例解析

二维钻孔封孔效果模拟案例 钻孔封孔这事儿听着简单,实际在地下工程里可是个技术活。今天咱们拿MATLAB的PDE工具箱做个二维模拟,看看封孔材料怎么影响密封效果。先别急着关页面,代码部分我尽量说得像唠嗑,保证不催眠。 先整点基础…

作者头像 李华