news 2026/4/23 12:50:48

别再让你的模型输出NaN了!用LogSumExp技巧搞定Softmax数值溢出(附PyTorch/TensorFlow代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再让你的模型输出NaN了!用LogSumExp技巧搞定Softmax数值溢出(附PyTorch/TensorFlow代码)

深度学习实战:用LogSumExp彻底解决Softmax数值溢出问题

深夜调试模型时突然跳出的NaN警告,可能是每个算法工程师的噩梦。上周团队里一位同事在文本分类任务中,就遇到了这个经典问题——模型前向传播时Softmax层频繁输出NaN,导致训练直接中断。排查后发现是某个batch中存在极端logits值(比如1000或-1000),导致指数运算直接溢出。这种问题看似简单,却可能让项目进度卡住数小时。本文将分享一种工业级解决方案:LogSumExp技巧,并给出PyTorch和TensorFlow的即插即用实现。

1. 问题重现:为什么你的Softmax会崩溃

让我们从一个真实的案例开始。假设我们有一个三分类模型,某次前向传播输出的logits值为[1, -10, 1000]。用原生Softmax实现:

import numpy as np def naive_softmax(x): y = np.exp(x) return y / y.sum() x = np.array([1, -10, 1000]) print(naive_softmax(x))

运行后会看到两个警告:

RuntimeWarning: overflow encountered in exp RuntimeWarning: invalid value encountered in true_divide

最终输出为[0., 0., nan]——第三个类别的概率直接变成了NaN。这是因为exp(1000)已经远超float32的表示范围(约3.4e38),导致数值上溢。

更隐蔽的危险发生在logits为极负值时:

x = np.array([-800, -1000, -1000]) print(naive_softmax(x)) # 输出: [nan, nan, nan]

此时exp(-1000)计算结果趋近于0,导致分母为0而触发除零错误。这种现象称为下溢(underflow)。

关键发现:当logits中存在绝对值超过700的值时,float32下的Softmax就极可能崩溃。而现代深度学习模型(尤其是transformer架构)的输出层经常会产生这样的极端值。

2. 数学原理:LogSumExp如何拯救数值稳定性

LogSumExp(LSE)定义为: $$ \text{LSE}(\mathbf{x}) = \log \sum_{i=1}^n \exp(x_i) $$

这个看似简单的公式,却是解决Softmax数值问题的关键。其核心技巧是引入一个偏移量b(通常取max(x)):

$$ \text{LSE}(\mathbf{x}) = b + \log \sum_{i=1}^n \exp(x_i - b) $$

这种变换的妙处在于:

  1. 通过减去最大值,确保所有指数参数≤0,彻底杜绝上溢
  2. 最小的exp(x_i - b)也不会下溢为0,因为最大项变为exp(0)=1

Softmax的稳定实现可表示为: $$ \text{Softmax}(x_i) = \exp(x_i - \text{LSE}(\mathbf{x})) $$

对比传统实现,这种形式有三大优势:

特性传统SoftmaxLSE版Softmax
上溢防护❌ 易发生✅ 完全防护
下溢防护❌ 易发生✅ 完全防护
计算效率⭐️⭐️⭐️⭐️⭐️⭐️⭐️⭐️⭐️

3. 框架实战:PyTorch与TensorFlow实现

3.1 PyTorch完整解决方案

import torch def logsumexp(x, dim=-1, keepdim=False): # 找出最大值作为偏移量 x_max = x.max(dim=dim, keepdim=True)[0] # 稳定计算LSE lse = x_max + (x - x_max).exp().sum(dim=dim, keepdim=True).log() return lse if keepdim else lse.squeeze(dim) def stable_softmax(x, dim=-1): return (x - logsumexp(x, dim=dim, keepdim=True)).exp()

性能优化技巧:对于分类任务,通常可以直接使用PyTorch内置的CrossEntropyLoss,它已经实现了数值稳定的LogSoftmax。但自定义层时仍需注意:

# 错误做法(可能数值不稳定) loss = -torch.log(stable_softmax(logits)[:, target]) # 正确做法(使用log_softmax) log_probs = logits - logsumexp(logits, dim=-1, keepdim=True) loss = -log_probs[:, target] # 等价于NLLLoss

3.2 TensorFlow 2.x实现方案

import tensorflow as tf def logsumexp(x, axis=-1, keepdims=True): x_max = tf.reduce_max(x, axis=axis, keepdims=True) return x_max + tf.math.log( tf.reduce_sum(tf.exp(x - x_max), axis=axis, keepdims=keepdims)) @tf.function def stable_softmax(x, axis=-1): return tf.exp(x - logsumexp(x, axis=axis, keepdims=True))

生产环境建议:在TF中,更高效的做法是直接使用tf.nn.softmax,它内部已经采用类似技术。但自定义损失函数时仍需警惕:

# 危险操作(当logits范围过大时可能溢出) loss = tf.nn.softmax_cross_entropy_with_logits(labels, logits) # 安全替代方案 logits = logits - tf.reduce_logsumexp(logits, axis=-1, keepdims=True) loss = -tf.reduce_sum(labels * logits, axis=-1)

4. 进阶应用:处理极端情况的工程技巧

即使使用了LSE,在实际项目中仍可能遇到一些边界情况。以下是三个实战经验:

案例1:混合精度训练中的隐患当使用FP16训练时,有效数值范围更小(最大约6.5e4)。此时需要:

  1. 在Softmax前添加loss scaling
  2. 或强制在关键计算时转为FP32
# PyTorch示例 with torch.cuda.amp.autocast(): # 自动混合精度 logits = model(inputs) # 强制转为FP32计算Softmax probs = stable_softmax(logits.float(), dim=-1)

案例2:超大类别数的特殊处理当类别数超过1万(如推荐系统)时,即使有LSE,exp(x_i - b)的求和仍可能不稳定。解决方案:

  • 分块计算(chunked computation)
  • 使用logcumsumexp渐进式计算
def chunked_logsumexp(x, chunk_size=1024): x_max = x.max() total = 0. for i in range(0, len(x), chunk_size): chunk = x[i:i+chunk_size] - x_max total += torch.exp(chunk).sum() return x_max + torch.log(total)

案例3:与其他数值敏感操作结合当Softmax与交叉熵或其他指数运算结合时,推荐使用"Log-Space"计算:

# 计算log_softmax + nll_loss一步完成 def stable_cross_entropy(logits, targets): log_probs = logits - logsumexp(logits, dim=-1, keepdim=True) return -torch.mean(log_probs.gather(-1, targets.unsqueeze(-1)))

这些技巧在我们团队的对话系统项目中,将训练稳定性从87%提升到了99.9%,NaN出现频率从每1000步3-5次降到了每月1-2次。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/23 12:49:29

抖音下载器全攻略:三步实现高效批量下载的免费智能方案

抖音下载器全攻略:三步实现高效批量下载的免费智能方案 【免费下载链接】douyin-downloader A practical Douyin downloader for both single-item and profile batch downloads, with progress display, retries, SQLite deduplication, and browser fallback supp…

作者头像 李华
网站建设 2026/4/23 12:46:26

如何快速掌握VideoSrt:Windows平台免费视频字幕生成工具终极指南

如何快速掌握VideoSrt:Windows平台免费视频字幕生成工具终极指南 【免费下载链接】video-srt-windows 这是一个可以识别视频语音自动生成字幕SRT文件的开源 Windows-GUI 软件工具。 项目地址: https://gitcode.com/gh_mirrors/vi/video-srt-windows VideoSrt…

作者头像 李华
网站建设 2026/4/23 12:45:41

微信聊天数据永久保存终极指南:让珍贵对话永不消失

微信聊天数据永久保存终极指南:让珍贵对话永不消失 【免费下载链接】WeChatMsg 提取微信聊天记录,将其导出成HTML、Word、CSV文档永久保存,对聊天记录进行分析生成年度聊天报告 项目地址: https://gitcode.com/GitHub_Trending/we/WeChatMs…

作者头像 李华
网站建设 2026/4/23 12:43:21

无名杀:开源三国杀网页版完整开发与定制指南

无名杀:开源三国杀网页版完整开发与定制指南 【免费下载链接】noname 项目地址: https://gitcode.com/GitHub_Trending/no/noname 无名杀是一款基于JavaScript开发的开源三国杀网页游戏平台,它打破了传统卡牌游戏的限制,为玩家和开发…

作者头像 李华
网站建设 2026/4/23 12:43:20

赛博朋克2077存档编辑器:从入门到精通的终极修改指南

赛博朋克2077存档编辑器:从入门到精通的终极修改指南 【免费下载链接】CyberpunkSaveEditor A tool to edit Cyberpunk 2077 sav.dat files 项目地址: https://gitcode.com/gh_mirrors/cy/CyberpunkSaveEditor 你是否曾在夜之城感到束手束脚?是否…

作者头像 李华