Skip to content

PyTorch 限制显存使用文档

总结

在 PyTorch 中,可以通过 torch.cuda.set_per_process_memory_fraction 函数限制单个进程的 GPU 显存使用量。该函数接受两个参数:显存使用比例(0 到 1 之间的浮点数)和目标 GPU 设备的索引。通过设置合适的比例,可以有效地控制每个进程占用的显存量,避免显存溢出(OOM)问题,特别是在多用户或多个任务共享同一张 GPU 卡的场景下。使用时需注意,该函数限制的是进程级别的显存,与 TensorFlow 的显存限制机制类似。

详细展开

函数介绍

torch.cuda.set_per_process_memory_fraction 是 PyTorch 提供的一个用于限制单个进程 GPU 显存使用的函数。其函数原型为:

torch.cuda.set_per_process_memory_fraction(fraction, device)
  • fraction:浮点数,表示允许该进程使用的显存比例,取值范围为 0 到 1。例如,设置为 0.5 表示该进程最多可使用一半的显存。
  • device:整数或 torch.device 对象,指定目标 GPU 设备的索引。

使用示例

以下是一个简单的使用示例,演示如何限制 0 号 GPU 设备的显存使用量为总显存的一半:

import torch

# 限制 0 号设备的显存使用量为总显存的 0.5
torch.cuda.set_per_process_memory_fraction(0.5, 0)

# 清空显存缓存
torch.cuda.empty_cache()

# 获取总显存大小
total_memory = torch.cuda.get_device_properties(0).total_memory

# 使用 0.499 的显存
tmp_tensor = torch.empty(int(total_memory * 0.499), dtype=torch.int8, device='cuda')

# 清空该显存
del tmp_tensor
torch.cuda.empty_cache()

# 下面的语句会触发显存 OOM 错误,因为刚好触碰到了上限
# torch.empty(total_memory // 2, dtype=torch.int8, device='cuda')

注意事项

  • 进程级限制:该函数限制的是单个进程的显存使用量,类似于 TensorFlow 的显存限制机制。这意味着在多进程环境下,每个进程都需要单独设置显存限制。
  • 显存释放:在限制显存使用后,若需释放显存,可以使用 torch.cuda.empty_cache() 清空缓存的显存,但已分配的显存需要手动删除相关张量后才能释放。
  • 错误处理:当显存使用超过限制时,会触发 CUDA out of memory 错误,错误信息中会包含允许使用的显存大小提示。

背景与应用场景

在实际使用 PyTorch 时,可能会遇到需要限制显存上限的场景,例如:

  • 多用户共享 GPU:多个用户共用一张 GPU 卡,每个用户按一定比例使用显存,以确保公平性和资源合理分配。
  • 多任务共享 GPU:一个用户运行多个任务,每个任务对显存的占用按照一定比例进行限制,以充分利用 GPU 算力且避免任务间显存冲突。

函数原理与实现

PyTorch 的动态构图机制使得其默认按需申请显存,类似于 TensorFlow 中的 allow_growth 模式。然而,这种机制在多用户或多任务共享 GPU 时可能会导致显存被某个任务动态增长占用,从而引发其他任务的显存不足问题。因此,引入 set_per_process_memory_fraction 函数来限制单个进程的显存使用上限。

该函数通过在底层 CUDA 运行时 API 的基础上,结合 PyTorch 自身的显存管理机制(如 Allocator 的内存块管理),实现对单个进程显存使用的限制。当显存申请超过设定的比例时,会直接报显存溢出错误,而不是尝试将超出部分放入主机内存中,因为这样会增加通信开销并降低训练/推理速度。

其他解决方案

除了使用 torch.cuda.set_per_process_memory_fraction 函数外,还可以通过其他方式限制显存使用,例如:

  • 劫持 CUDA API:通过修改或拦截 CUDA 相关的 API 调用,实现对显存使用的限制。
  • 利用 NVML 监控:使用 NVIDIA Management Library(NVML)来监控和限制显存使用情况。

然而,这些方法可能需要更复杂的实现和维护工作,而 set_per_process_memory_fraction 函数提供了一种相对简单且直接的解决方案。