Skip to content

PyTorch 混合精度训练及 NaN 问题解决总结

混合精度训练概述

混合精度训练结合了 FP32(单精度浮点数)和 FP16(半精度浮点数)的优势,能在减少显存占用的同时提升训练速度。PyTorch 1.7 及以上版本通过 torch.cuda.amp 提供了原生支持。

示例代码

from torch.cuda.amp import autocast, GradScaler

model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

解决 NaN 问题

  1. 提高精度:若出现 NaN,可将相关代码块的精度从 FP16 提高到 BF16 或 FP32。例如,使用 with autocast(dtype=torch.bfloat16) 包裹易出问题的代码。
  2. 启用 TF32
  3. 在训练开始时添加以下代码启用 TF32:

    import torch.backends.cudnn as cudnn
    import torch.backends.cuda
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    
  4. TF32 兼具 FP32 的范围和 BF16 的精度,可提高训练效率且不易出现 NaN,但显存占用减少效果不如 FP16。

注意事项

  • 调试建议:定位出现问题的代码部分,针对性地调整精度设置。
  • 适用场景:混合精度训练适用于对显存要求高且需要加速训练的场景,但需注意精度调整对模型稳定性的影响。

用户评论

我常用的解决方法是load最后一个ckpt,关闭半精度用fp32 train一个epoch,继续用半精度训练就可以了,来自微软亚研同事的经验

Ref

https://zhuanlan.zhihu.com/p/675987100