Generic Join Context Manager
The generic join context manager facilitates distributed training on uneven
inputs. This page outlines the API of the relevant classes: Join,
Joinable, and JoinHook. For a tutorial, see
Distributed Training with Uneven Inputs Using the Join Context Manager.
class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)

This class defines the generic join context manager, which allows custom hooks to be called after a process joins. These hooks should shadow the collective communications of non-joined processes to prevent hanging and erroring and to ensure algorithmic correctness. Refer to JoinHook for details about the hook definition.

Warning: The context manager requires each participating Joinable to call the method notify_join_context() before its own per-iteration collective communications to ensure correctness.

Warning: The context manager requires that every participating Joinable reports the same process group through join_process_group. If there are multiple Joinable objects, then the join_device of the first is used. The process group and device information is used for checking for non-joined processes and for notifying processes to throw an exception if throw_on_early_termination is enabled, both of which use an all-reduce.

Parameters
- joinables (List[Joinable]) – a list of the participating Joinable objects; their hooks are iterated over in the given order.
- enable (bool) – a flag enabling uneven input detection; setting it to False disables the context manager’s functionality and should only be done when the user knows the inputs will not be uneven (default: True).
- throw_on_early_termination (bool) – a flag controlling whether to throw an exception upon detecting uneven inputs (default: False).
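The effect of `enable` and of the extra keyword arguments can be seen in a minimal single-process sketch. `Tagged`, `TagHook`, and the `tag` keyword are hypothetical names introduced only for illustration, and the single-process `gloo` group backed by a `HashStore` is an assumption made so that the snippet runs without a launcher. It relies on `Join` forwarding its `**kwargs` to each joinable's `join_hook()`.

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook


class TagHook(JoinHook):
    """Hypothetical hook that only records that its post-hook ran."""

    def __init__(self, log, tag):
        self.log = log
        self.tag = tag

    def post_hook(self, is_last_joiner: bool):
        self.log.append(self.tag)


class Tagged(Joinable):
    """Hypothetical joinable with no collectives of its own."""

    def __init__(self, log):
        super().__init__()
        self.log = log

    def join_hook(self, **kwargs) -> JoinHook:
        # Keyword arguments passed to Join(...) arrive here.
        return TagHook(self.log, kwargs.get("tag", "default"))

    @property
    def join_device(self) -> torch.device:
        return torch.device("cpu")

    @property
    def join_process_group(self):
        return dist.group.WORLD


# Assumption: one process, CPU-only gloo backend, in-memory store.
dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
log = []
try:
    with Join([Tagged(log)], tag="enabled"):
        pass  # no inputs; the post-hook still runs on exit
    with Join([Tagged(log)], enable=False, tag="disabled"):
        pass  # context manager disabled; no hooks run
finally:
    dist.destroy_process_group()
```

After this runs, `log` holds only the entry from the enabled context, since a disabled context manager exits without running any hooks.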
 
Example:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # On each spawned worker
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # Rank 1 gets one more input than rank 0
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # All ranks reach here without hanging/erroring
static notify_join_context(joinable)

Notifies the join context manager that the calling process has not yet joined. Then, if throw_on_early_termination=True, checks whether uneven inputs have been detected (i.e. whether one process has already joined) and throws an exception if so.

This method should be called from a Joinable object before its per-iteration collective communications. For example, this should be called at the beginning of the forward pass in DistributedDataParallel.

Only the first Joinable object passed into the context manager performs the collective communications in this method; for the others, this method does nothing.
 
class torch.distributed.algorithms.Joinable

This defines an abstract base class for joinable classes. A joinable class (inheriting from Joinable) should implement join_hook(), which returns a JoinHook instance, in addition to the join_device and join_process_group properties, which return device and process group information, respectively.

abstract property join_device
Returns the device from which to perform collective communications needed by the join context manager implementation itself.

abstract property join_process_group
Returns the process group for the collective communications needed by the join context manager itself.
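A joinable class that satisfies these requirements might look like the sketch below. `Counter`, `CounterJoinHook`, and `run_single_process` are illustrative names, not part of the library, and the single-process `gloo` group with a `HashStore` is an assumption chosen so the example runs on CPU without a launcher. The per-iteration method calls `Join.notify_join_context()` before its own all-reduce, as the context manager requires.

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook


class CounterJoinHook(JoinHook):
    """Hypothetical hook that shadows Counter's per-iteration all-reduce."""

    def __init__(self, counter):
        self.counter = counter

    def main_hook(self):
        # Contribute zero so that non-joined ranks' all-reduce can complete.
        t = torch.zeros(1, device=self.counter.join_device)
        dist.all_reduce(t, group=self.counter.join_process_group)

    def post_hook(self, is_last_joiner: bool):
        self.counter.last_joiner = is_last_joiner


class Counter(Joinable):
    """Hypothetical joinable that counts inputs seen across all ranks."""

    def __init__(self, device, process_group):
        super().__init__()
        self._device = device
        self._process_group = process_group
        self.count = torch.zeros(1, device=device)
        self.last_joiner = None

    def __call__(self):
        # Must come before this class's own collective communication.
        work = Join.notify_join_context(self)
        if work is not None:
            work.wait()  # defensive: finish the notification first
        t = torch.ones(1, device=self._device)
        dist.all_reduce(t, group=self._process_group)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        return CounterJoinHook(self)

    @property
    def join_device(self) -> torch.device:
        return self._device

    @property
    def join_process_group(self):
        return self._process_group


def run_single_process(num_inputs: int) -> Counter:
    # Assumption: one process, CPU-only gloo backend, in-memory store.
    dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    try:
        counter = Counter(torch.device("cpu"), dist.group.WORLD)
        with Join([counter]):
            for _ in range(num_inputs):
                counter()
        return counter
    finally:
        dist.destroy_process_group()


counter = run_single_process(3)
```

With a single process there is never a non-joined peer, so only the post-hook runs here. In a multi-process run, ranks that exhaust their inputs early would execute `main_hook()` repeatedly until the remaining ranks join.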
 
- 
abstract property 
class torch.distributed.algorithms.JoinHook

This defines a join hook, which provides two entry points in the join context manager: a main hook, which is called repeatedly while there exists a non-joined process, and a post-hook, which is called once all processes have joined.

To implement a join hook for the generic join context manager, define a class that inherits from JoinHook and override main_hook() and post_hook() as appropriate.
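The shape of such a subclass can be sketched without any distributed setup. `RecordingJoinHook` is a hypothetical class that only records how it is driven; a real hook would issue collectives in `main_hook()` that mirror one iteration of the non-joined processes. The sketch assumes the `post_hook(is_last_joiner)` signature of the PyTorch base class, and the calls at the bottom stand in for what the context manager would do on exit.

```python
from torch.distributed.algorithms.join import JoinHook


class RecordingJoinHook(JoinHook):
    """Hypothetical hook that records calls instead of communicating."""

    def __init__(self):
        self.main_calls = 0
        self.post_calls = []

    def main_hook(self):
        # Called repeatedly while some process has not yet joined.
        self.main_calls += 1

    def post_hook(self, is_last_joiner: bool):
        # Called once after every process has joined.
        self.post_calls.append(is_last_joiner)


# Drive the hook by hand to mimic two shadowed iterations and the final call.
hook = RecordingJoinHook()
hook.main_hook()
hook.main_hook()
hook.post_hook(is_last_joiner=True)
```

Overriding only one of the two methods is also possible, since the base class provides both entry points.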