news 2026/4/23 9:18:38

线性拟合模型

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
线性拟合模型

线性拟合模型

一、数据准备部分

importnumpyasnpimportkerasimportmatplotlib.pyplotasplt train_X=np.asarray([30.0,40.0,60.0,80.0,100.0,120.0,140.0])train_Y=np.asarray([320.0,360.0,400.0,455.0,490.0,546.0,580.0])train_X/=100.0train_Y/=100.0
  • train_Xtrain_Y是人工构造的训练数据(x 和 y)。

  • 除以 100 是为了归一化(Normalization),将数据范围从 [30-140] 和 [320-580] 缩放到 [0.3-1.4] 和 [3.2-5.8]),有助于神经网络更快收敛。

  • 这是典型的监督学习回归问题:输入 x → 预测 y。

二、可视化函数

defplot_points(x,y,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')plt.scatter(x,y)plt.show()defplot_line(W,b,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')x=np.linspace(0.0,2.0,num=100)y=W*x+b plt.plot(x,y)plt.show()
  • plot_points:画散点图,展示原始数据。

  • plot_line:根据斜率W和截距b画出拟合直线。

三、模型构建

model=keras.models.Sequential()model.add(keras.layers.Dense(units=1,input_dim=1))
  • 只有一层:Dense全连接层
  • units=1:只有一个神经元(输出一个值)
  • input_dim=1:输入数据是一维的(一个特征)
  • 相当于数学公式:y = Wx + b,其中:
    • W:权重(weight),相当于斜率
    • b:偏置(bias),相当于截距

四、编译模型

model.compile(optimizer='sgd',loss='mean_squared_error')
  • optimizer='sgd':使用随机梯度下降优化器
    • SGD是最基础、最经典的优化算法
    • 相比adam,SGD更简单,适合这种简单线性问题
  • loss='mean_squared_error':使用均方误差作为损失函数
    • 计算公式:MSE = Σ(y_pred - y_true)² / n
    • 这是回归问题最常用的损失函数

五、训练模型

history=model.fit(x=train_X,y=train_Y,batch_size=1,epochs=10)
  • batch_size=1批大小为1(在线学习/随机梯度下降)
    • 每看一个样本就更新一次权重
    • 梯度更新频繁,波动较大
    • 内存占用小,适合小数据集
  • epochs=10:训练10轮
    • 把7个样本反复训练10遍
    • 总共训练 7 × 10 = 70 次更新

注意history会记录训练过程中的loss变化,可以用于后续分析

六. 结果可视化

plot_line(model.get_weights()[0][0][0],model.get_weights()[1][0],title_name='Current_Model')
  • model.get_weights()[0]:获取权重W(斜率)
    • [0][0][0]是因为权重的形状是(1,1),需要索引到具体数值
  • model.get_weights()[1]:获取偏置b(截距)
    • [0]是因为偏置的形状是(1,),需要索引到具体数值

这个模型在做什么?

1. 数学本质

这个模型其实就是用神经网络的方式来实现最小二乘法线性回归

  • 要找一条直线y = Wx + b
  • 让这条直线最接近所有数据点
  • "接近"的标准是:均方误差最小

2. 训练过程(SGD)

初始化:W=随机值,b=随机值for10:for每个样本(x_i,y_i):1.计算预测值:y_pred=W*x_i+b2.计算误差:error=y_pred-y_i3.计算梯度:dW=2*error*x_i# 对W的梯度db=2*error# 对b的梯度4.更新参数:W=W-learning_rate*dW b=b-learning_rate*db

完整代码:

importnumpyasnpimportkerasimportmatplotlib.pyplotasplt train_X=np.asarray([30.0,40.0,60.0,80.0,100.0,120.0,140.0])train_Y=np.asarray([320.0,360.0,400.0,455.0,490.0,546.0,580.0])train_X/=100.0train_Y/=100.0#用于对数据点进行可视化defplot_points(x,y,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')plt.scatter(x,y)plt.show()defplot_line(W,b,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')x=np.linspace(0.0,2.0,num=100)y=W*x+b plt.plot(x,y)plt.show()plot_points(train_X,train_Y,title_name='Training Points')#建立线性拟合模型,由斜率和偏移两个参数构成,相当于神经元数为1的一层全连接model=keras.models.Sequential()model.add(keras.layers.Dense(units=1,input_dim=1))#成本函数采用均差误差,优化方法使用随机梯度下降model.compile(optimizer='sgd',loss='mean_squared_error')#模型迭代10个轮次,用单样本的方式进行优化history=model.fit(x=train_X,y=train_Y,batch_size=1,epochs=10)plot_line(model.get_weights()[0][0][0],model.get_weights()[1][0],title_name='Current_Model')

附解释可视化函数部分
1.散点图
def plot_points(x, y, title_name):

  • 定义一个名为plot_points的函数。

    x:横坐标数据(如你的 train_X)
    y:纵坐标数据(如你的 train_Y)
    title_name:图表的标题(字符串)

​ plt.title(title_name) # 设置图表标题
​ plt.xlabel(‘x’) # 设置x轴标签
​ plt.ylabel(‘y’) # 设置y轴标签
​ plt.scatter(x, y) # 绘制散点图
​ plt.show() # 显示图表

2.直线图
def plot_line(W, b, title_name):
plt.title(title_name) # 设置图表标题
plt.xlabel(‘x’) # 设置x轴标签
plt.ylabel(‘y’) # 设置y轴标签

​ x = np.linspace(0.0, 2.0, num=100) # 生成100个等间距的x值
​ np: numpy模块的别名
​ .linspace(): 生成等差数列(linear space)
​ 参数:
​ 0.0: 起始值(start)
​ 2.0: 结束值(stop)
​ num=100: 生成100个点

​ y = W * x + b # 计算对应的y值

​ plt.plot(x, y) # 绘制折线图(这里是直线)
​ .plot(): 绘制折线图
​ 参数:(x, y)坐标点

​ plt.show() # 显示图表

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/20 17:14:41

昇腾310P平台强化学习训练环境搭建实战:基于Qwen2.5-7B的完整部署流程

目录引一、Docker环境准备1.1 镜像选择与下载1.2 创建容器坑1: 镜像ID混淆**坑2: 容器秒退**1.3 正确的创建方式1.4 进入容器二、Python环境配置2.1 安装Miniconda2.2 激活conda环境2.3 创建Python 3.10环境三、安装PyTorch与昇腾支持3.1 安装PyTorch 2.5.13.2 安装torch-npu四…

作者头像 李华
网站建设 2026/4/16 14:28:32

精密仪器中的微型导轨如何选对润滑脂?

微型导轨是一种高精度、小型化的直线运动元件,具备体积小、负载能力强、摩擦系数低等特点。被广泛应用于精密仪器、医疗设备、半导体设备、机器人等领域。其运行稳定性与寿命高度依赖润滑脂的性能,选型不当易导致磨损加剧、噪音增大甚至故障停机。那么&a…

作者头像 李华
网站建设 2026/4/22 5:18:13

如何选择德诺超声波焊接机才合适?

在选择德诺超声波焊接机时,用户需要关注多个关键因素。首先,设备性能是重中之重,包括功率、频率与工作效率,这直接影响焊接质量。此外,维修服务也是必不可少的,确保在设备出现故障时能够快速恢复生产。与此…

作者头像 李华
网站建设 2026/4/20 2:08:22

不花一分钱广告,月增3000客户?一招让客户主动帮你介绍三个月

私域圈里有个常见困惑:曾经风靡的“推三返一”模式,为什么总是火一阵就凉?不少商家初期靠它快速拉新,用户为了“免费拿”主动分享。可没过俩月,参与度断崖式下滑,最终只剩老板在群里自嗨。其实不是模式不行…

作者头像 李华
网站建设 2026/4/16 7:20:11

JSM9N20C 200V N 沟道 MOSFET

在电力电子技术日新月异的当下,MOSFET 作为承载能量转换与电路控制的核心功率器件,其性能表现直接决定了终端产品的能效水平、运行稳定性与使用寿命。从工业自动化生产线的核心控制单元,到消费电子的高效电源适配器;从新能源领域的…

作者头像 李华
网站建设 2026/4/21 19:25:55

GraniStudio:IO写入例程

1.文件运行 导入工程 双击运行桌面GraniStudio.exe。 通过引导界面导入IO写入例程,点击导入按钮。 打开IO写入例程所在路径,选中IO写入.gsp文件,点击打开,完成导入。 2.功能说明 实现输出IO控制以及读取。 2.1通过初始化IO算子…

作者头像 李华