多进程最佳实践¶
torch.multiprocessing 是 Python 的
multiprocessing 模块的直接替代品。它支持完全相同的操作,
但对其进行了扩展,因此所有通过
multiprocessing.Queue 发送的张量,其数据将被移动到共享内存中,并且只会发送一个句柄给另一个进程。
注意
当一个 Tensor 被发送到另一个进程时,Tensor 数据会被共享。如果 torch.Tensor.grad 不是 None,它也会被共享。在没有 torch.Tensor.grad 字段的 Tensor 被发送到其他进程后,它会创建一个标准的进程特定的 .grad Tensor,这个不会像 Tensor 的数据那样自动在所有进程中共享。
这使得可以实现各种训练方法,如 Hogwild、A3C 或任何其他需要异步操作的方法。
CUDA在多进程中的应用¶
CUDA运行时环境不支持fork启动方法;使用CUDA在子进程中需要spawn或forkserver启动方法。
注意
开始方法可以通过创建一个上下文并使用
multiprocessing.get_context(...) 或直接使用
multiprocessing.set_start_method(...) 来设置。
与CPU张量不同,发送进程需要保持原始张量,只要接收进程保留张量的副本。其实现是在内部完成的,但要求用户遵循最佳实践以确保程序正确运行。例如,发送进程必须保持活动状态,只要消费者进程对张量有引用,并且如果消费者进程通过致命信号异常退出,引用计数将无法保护你。请参阅 此部分。
另请参阅:使用 nn.parallel.DistributedDataParallel 而不是 multiprocessing 或 nn.DataParallel
最佳实践和技巧¶
避免和解决死锁¶
在新进程启动时可能会出现很多问题,最常见的死锁原因是后台线程。如果有任何线程持有锁或导入模块,并且调用了fork,那么子进程很可能会处于损坏状态并发生死锁或以其他方式失败。请注意,即使你没有这样做,Python 内置库也会这样做 - 不用再看别的,看看multiprocessing。
multiprocessing.Queue 实际上是一个非常复杂的类,它会生成多个用于序列化、发送和接收对象的线程,这些线程也可能导致上述问题。如果你发现自己处于这种情况,尝试使用一个不使用任何额外线程的multiprocessing.queues.SimpleQueue。
我们正在尽最大努力让您轻松,并确保这些死锁不会发生,但有些事情是我们无法控制的。如果您有一段时间无法应对的问题,请尝试在论坛上联系我们,我们会看看是否可以解决这个问题。
重用通过队列传递的缓冲区¶
请记住,每次将一个 Tensor 放入一个
multiprocessing.Queue 中时,都需要将其移动到共享内存中。
如果它已经是共享的,则是一个无操作,否则将会产生额外的内存复制,这可能会减慢整个过程。即使你有一个进程池向单个进程发送数据,也要让它将缓冲区发回 - 这几乎是免费的,并且在发送下一批数据时可以避免复制。
异步多进程训练(例如Hogwild)¶
使用 torch.multiprocessing,可以异步训练模型,参数要么一直共享,要么定期同步。在第一种情况下,我们建议发送整个模型对象,而在后一种情况下,我们建议只发送
state_dict()。
我们建议使用 multiprocessing.Queue 来在进程之间传递各种 PyTorch 对象。例如,在使用 fork 启动方法时,可以继承已经在共享内存中的张量和存储,但这非常容易出错,应谨慎使用,仅限高级用户。队列虽然有时不是最优雅的解决方案,但在所有情况下都能正常工作。
警告
你应该小心使用全局语句,这些语句没有被if __name__ == '__main__'保护。如果使用了不同于fork的启动方法,它们将在所有子进程中执行。
Hogwild¶
具体的Hogwild实现可以在示例仓库中找到, 但为了展示代码的整体结构,下面还有一个最小的示例:
import torch.multiprocessing as mp
from model import MyModel
def train(model):
# Construct data_loader, optimizer, etc.
for data, labels in data_loader:
optimizer.zero_grad()
loss_fn(model(data), labels).backward()
optimizer.step() # This will update the shared parameters
if __name__ == '__main__':
num_processes = 4
model = MyModel()
# NOTE: this is required for the ``fork`` method to work
model.share_memory()
processes = []
for rank in range(num_processes):
p = mp.Process(target=train, args=(model,))
p.start()
processes.append(p)
for p in processes:
p.join()