目录

多进程最佳实践

torch.multiprocessing 是 Python 的 multiprocessing 模块的直接替代品。它支持完全相同的操作, 但对其进行了扩展,因此所有通过 multiprocessing.Queue 发送的张量,其数据将被移动到共享内存中,并且只会发送一个句柄给另一个进程。

注意

当一个 Tensor 被发送到另一个进程时,Tensor 数据会被共享。如果 torch.Tensor.grad 不是 None,它也会被共享。在没有 torch.Tensor.grad 字段的 Tensor 被发送到其他进程后,它会创建一个标准的进程特定的 .grad Tensor,这个不会像 Tensor 的数据那样自动在所有进程中共享。

这使得可以实现各种训练方法,如 Hogwild、A3C 或任何其他需要异步操作的方法。

CUDA在多进程中的应用

CUDA运行时环境不支持fork启动方法;使用CUDA在子进程中需要spawnforkserver启动方法。

注意

开始方法可以通过创建一个上下文并使用 multiprocessing.get_context(...) 或直接使用 multiprocessing.set_start_method(...) 来设置。

与CPU张量不同,发送进程需要保持原始张量,只要接收进程保留张量的副本。其实现是在内部完成的,但要求用户遵循最佳实践以确保程序正确运行。例如,发送进程必须保持活动状态,只要消费者进程对张量有引用,并且如果消费者进程通过致命信号异常退出,引用计数将无法保护你。请参阅 此部分

另请参阅:使用 nn.parallel.DistributedDataParallel 而不是 multiprocessing 或 nn.DataParallel

最佳实践和技巧

避免和解决死锁

在新进程启动时可能会出现很多问题,最常见的死锁原因是后台线程。如果有任何线程持有锁或导入模块,并且调用了fork,那么子进程很可能会处于损坏状态并发生死锁或以其他方式失败。请注意,即使你没有这样做,Python 内置库也会这样做 - 不用再看别的,看看multiprocessingmultiprocessing.Queue 实际上是一个非常复杂的类,它会生成多个用于序列化、发送和接收对象的线程,这些线程也可能导致上述问题。如果你发现自己处于这种情况,尝试使用一个不使用任何额外线程的SimpleQueue

我们正在尽最大努力让您轻松,并确保这些死锁不会发生,但有些事情是我们无法控制的。如果您有一段时间无法应对的问题,请尝试在论坛上联系我们,我们会看看是否可以解决这个问题。

重用通过队列传递的缓冲区

请记住,每次将一个 Tensor 放入一个 multiprocessing.Queue 中时,都需要将其移动到共享内存中。 如果它已经是共享的,则是一个无操作,否则将会产生额外的内存复制,这可能会减慢整个过程。即使你有一个进程池向单个进程发送数据,也要让它将缓冲区发回 - 这几乎是免费的,并且在发送下一批数据时可以避免复制。

异步多进程训练(例如Hogwild)

使用 torch.multiprocessing,可以异步训练模型,参数要么一直共享,要么定期同步。在第一种情况下,我们建议发送整个模型对象,而在后一种情况下,我们建议只发送 state_dict()

我们建议使用 multiprocessing.Queue 来在进程之间传递各种 PyTorch 对象。例如,在使用 fork 启动方法时,可以继承已经在共享内存中的张量和存储,但这非常容易出错,应谨慎使用,仅限高级用户。队列虽然有时不是最优雅的解决方案,但在所有情况下都能正常工作。

警告

你应该小心使用全局语句,这些语句没有被if __name__ == '__main__'保护。如果使用了不同于fork的启动方法,它们将在所有子进程中执行。

Hogwild

具体的Hogwild实现可以在示例仓库中找到, 但为了展示代码的整体结构,下面还有一个最小的示例:

import torch.multiprocessing as mp
from model import MyModel

def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters

if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源