PyTorch 混合精度训练及 NaN 问题解决总结
混合精度训练概述
混合精度训练结合了 FP32(单精度浮点数)和 FP16(半精度浮点数)的优势,能在减少显存占用的同时提升训练速度。PyTorch 1.7 及以上版本通过 torch.cuda.amp 提供了原生支持。
示例代码
from torch.cuda.amp import autocast, GradScaler
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
解决 NaN 问题
- 提高精度:若出现 NaN,可将相关代码块的精度从 FP16 提高到 BF16 或 FP32。例如,使用
with autocast(dtype=torch.bfloat16)包裹易出问题的代码。 - 启用 TF32:
-
在训练开始时添加以下代码启用 TF32:
-
TF32 兼具 FP32 的范围和 BF16 的精度,可提高训练效率且不易出现 NaN,但显存占用减少效果不如 FP16。
注意事项
- 调试建议:定位出现问题的代码部分,针对性地调整精度设置。
- 适用场景:混合精度训练适用于对显存要求高且需要加速训练的场景,但需注意精度调整对模型稳定性的影响。
用户评论
我常用的解决方法是load最后一个ckpt,关闭半精度用fp32 train一个epoch,继续用半精度训练就可以了,来自微软亚研同事的经验