Pytorch图像去噪实战(二十六):AMP混合精度训练图像去噪模型,提升速度并降低显存占用
一、问题场景:模型训练太吃显存,batch size上不去
在训练 UNet、Restormer、Diffusion 这类图像去噪模型时,经常遇到:
CUDA out of memory尤其是:
- RGB图像训练
- 大patch训练
- Transformer模型
- Diffusion模型
- 多尺度UNet
我一开始的解决方式很粗暴:
- 减小 batch size
- 减小 patch size
- 减少模型通道数
但这样会影响训练稳定性和模型效果。
后来我开始使用 PyTorch AMP 混合精度训练,显存占用明显下降,训练速度也有提升。
二、什么是AMP混合精度训练?
AMP 全称 Automatic Mixed Precision。
简单理解:
部分计算使用 float16,关键计算仍保留 float32。
这样既能提高速度,又能减少显存占用。
Pytorch中