torch.distributed.tensor¶
注意
torch.distributed.tensor 当前处于 alpha 状态并正在开发中,我们致力于确保文档中列出的大多数 API 的向后兼容性,但如果有必要,API 可能会发生更改。
PyTorch 分布式张量 (DTensor)¶
PyTorch DTensor 提供简单灵活的张量分片原语,透明地处理分布式逻辑,包括分片存储、算子计算和设备/主机间的集体通信。
DTensor 可用于构建不同的并行解决方案,并在处理多维分片时支持分片状态字典表示。
请查看基于DTensor构建的PyTorch原生并行解决方案的示例:
DTensor 遵循 SPMD(单程序,多数据)编程模型,使用户能够编写分布式程序,就像它是具有相同收敛性质的单设备程序一样。它通过指定DeviceMesh和Placement提供了一种统一的张量分片布局(DTensor 布局):
DeviceMesh表示设备拓扑和集群的通信者,使用一个 n 维数组。Placement描述了逻辑张量在DeviceMesh上的分片布局。 DTensor 支持三种放置类型:Shard,Replicate和Partial。
DTensor 类 API 文档¶
DTensor 是一个 torch.Tensor 的子类。这意味着一旦创建了一个 DTensor,就可以以非常相似的方式使用它,就像使用 torch.Tensor 一样,包括运行不同类型的 PyTorch 操作符,就好像在单个设备上运行它们一样,从而允许 PyTorch 操作符的适当分布式计算。
除了现有的 torch.Tensor 种方法外,它还提供了一组额外的方法来与
torch.Tensor、redistribute 的DTensor布局进行交互,将其转换为新的DTensor,在所有设备上获取完整的张量内容等。
- class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)¶
DTensor(分布式张量)是torch.Tensor的一个子类,它为使用多设备torch.Tensor编程提供了单设备类似的抽象。它通过DeviceMesh和以下类型的Placement描述分布式张量分片布局(DTensor 布局):Shard: 张量在第dim维度上分片,在第DeviceMesh维度的设备上Replicate: 张量在DeviceMesh维设备上复制Partial: 张量在DeviceMesh维度的设备上等待减少
在调用PyTorch算子时,
DTensor会覆盖PyTorch算子以执行分片计算并在必要时进行通信。除了算子计算外,DTensor还会根据算子语义正确地转换或传播位置(DTensor布局)并生成新的DTensor输出。为了确保调用PyTorch操作时
DTensor分片计算的数值正确性,DTensor要求操作符的每个张量参数都是DTensor。- Return type
- property device_mesh: DeviceMesh¶
与该DTensor对象关联的
DeviceMesh属性。注意
device_mesh是一个只读属性,不能被设置。
- static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source]¶
从每个排名的本地torch.Tensor创建一个
DTensor,根据指定的device_mesh和placements。- Parameters
本地张量 (torch.Tensor) – 每个rank上的本地torch.Tensor。
设备网格 (
DeviceMesh,可选) – 放置张量的设备网格,如果未指定,则必须在设备网格上下文管理器下调用,默认值:None位置 (List[
Placement],可选) - 描述如何在DeviceMesh上放置本地torch.Tensor的位置,必须与device_mesh.ndim具有相同数量的元素。
- Keyword Arguments
run_check (bool, 可选) – 以额外通信为代价,跨 ranks 执行完整性检查,验证每个本地张量的元信息以确保正确性。如果有
Replicate在placements中,则设备网格维度的第一个 rank 的数据将广播到其他 ranks。默认值:False形状 (torch.Size, 可选) – 一个指定构建在 local_tensor 之上的 DTensor 的大小的整数列表。请注意,如果
local_tensor在各个等级上的形状不同,则需要提供此参数。如果没有提供,将假设给定的分布式张量在各等级之间均匀分片,并据此计算shape。默认值:Nonestride (元组, 可选) – 一个整数列表,用于指定 DTensor 的步长。 如果未提供,则假设给定的分布式张量在各个进程之间均匀切分,并计算
stride。默认值:None
- Returns
一个
DTensor对象- Return type
注意
当
run_check=False时,确保传入的本地张量在各个排名之间是正确的(即张量为Shard(dim)位置进行分片或为Replicate()位置进行复制)是用户的责任。 如果不正确,则创建的DTensor的行为将是未定义的。注意
from_local是可微分的,创建的对象的 requires_grad 将取决于 local_tensor 是否设置了 requires_grad。
- full_tensor(*, grad_placements=None)[source]¶
返回此 DTensor 的完整张量。它将执行必要的集合操作,从其设备网格中的其他排名收集本地张量,并将它们拼接在一起。这是以下代码的语法糖:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()- Keyword Arguments
grad_placements (List[
Placement], optional) – 描述了此函数返回的完整张量的梯度布局的未来布局。 full_tensor 将DTensor转换为完整的torch.Tensor,并且返回的torch.tensor可能在后续代码中不能作为原始复制的DTensor布局使用。此参数是用户可以给autograd的提示,以防返回张量的梯度布局与原始复制的DTensor布局不匹配。如果未指定,则假定完整张量的梯度布局为复制布局。- Returns
一个
torch.Tensor对象,表示此DTensor的完整张量。- Return type
注意
full_tensor是可微分的。
- property placements: Tuple[Placement, ...]¶
此 DTensor 的 placements 属性描述了该 DTensor 在其设备网格中的布局。
注意
placements是一个只读属性,不能被设置。
- redistribute(device_mesh=None, placements=None, *, async_op=False)[source]¶
redistribute执行必要的集体操作,将当前的 DTensor 从其当前位置重新分布到新的位置,或者从当前的 DeviceMesh 转换到新的 DeviceMesh。也就是说,我们可以通过为 DeviceMesh 的每个维度指定 Replicate 位置,将一个分片的 DTensor 转换为一个复制的 DTensor。在从当前位置重新分布到一个设备网状维度的新位置时,我们将执行以下操作,包括通信集合操作或本地操作:
Shard(dim)->Replicate():all_gatherShard(src_dim)->Shard(dst_dim):all_to_allReplicate()->Shard(dim): 局部切分(即torch.chunk)Partial()->Replicate():all_reducePartial()->Shard(dim):reduce_scatter
redistribute会正确地计算出在1-D或N-D设备网格上创建的DTensor所需的重新分布步骤。- Parameters
设备网格 (
DeviceMesh,可选) – 放置张量的设备网格。如果没有指定,则使用当前张量的设备网格。 默认值:无位置 (List[
Placement],可选) – 新的位置描述了如何将DTensor放置到设备网格中,必须 与device_mesh.ndim具有相同数量的元素。 默认值:在所有网格维度上复制
- Keyword Arguments
async_op (bool, 可选) – 是否异步执行DTensor重分布操作。默认值:False
- Returns
一个
DTensor对象- Return type
注意
redistribute是可微分的,这意味着用户无需担心重新分配操作的反向公式。注意
redistribute当前仅支持在同一DeviceMesh上重新分布DTensor, 请提交问题以获取在不同DeviceMesh之间重新分布DTensor的支持。
- to_local(*, grad_placements=None)[source]¶
获取此张量对象在当前设备上的本地张量。对于分片操作,它返回逻辑张量视图的局部片段;对于复制操作,它返回当前设备上的副本。
- Keyword Arguments
grad_placements (List[
Placement], optional) – 描述了此函数返回的张量的梯度布局的未来布局。 to_local 将DTensor转换为本地张量,返回的本地张量在后续代码中可能无法再用作原始DTensor布局。此参数是用户可以给自动微分提供的一种提示,以防返回张量的梯度布局与原始DTensor布局不匹配。如果未指定,默认假设梯度布局与原始DTensor相同,并据此进行梯度计算。- Returns
一个
torch.Tensor或AsyncCollectiveTensor对象。它表示当前排名上的本地张量。当返回一个AsyncCollectiveTensor对象时,意味着本地张量尚未准备好(即通信尚未完成)。在这种情况下,用户需要调用wait来等待本地张量准备好。- Return type
注意
to_local是可微分的,返回的本地张量的requires_grad取决于 DTensor 是否需要计算梯度。
设备网格作为分布式通信器¶
DeviceMesh 是从 DTensor 演变而来,作为描述集群设备拓扑的抽象,并表示多维通信器(基于 ProcessGroup)。有关如何创建/使用设备网格的详细信息,请参阅 设备网格配方。
张量放置类型¶
DTensor 支持以下类型的 Placement 在每个 DeviceMesh 维度上:
- class torch.distributed.tensor.placement_types.Shard(dim)[source]¶
The
Shard(dim)placement describes the DTensor sharding on tensor dimensiondimover a correspondingDeviceMeshdimension, where each rank on the DeviceMesh dimension only holds a shard/piece of the global Tensor. TheShard(dim)placement follows thetorch.chunk(dim)semantic, where the last few shards on the DeviceMesh dimension might be empty when the tensor dimension is not evenly divisible on the DeviceMesh dimension. TheShardplacement can be used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.)- Parameters
dim (int) – 描述DTensor在其对应的设备网格维度上被分片的张量维度。
警告
在设备网格维度上对张量尺寸不能被整除的张量维度进行分片目前处于实验阶段,可能会发生变化。
- class torch.distributed.tensor.placement_types.Replicate[source]¶
“
Replicate()放置”描述了DTensor在相应维度上的复制,在DeviceMesh维度中的每个排名都持有一个全局张量的副本。Replicate放置可以被所有DTensor API使用(即distribute_tensor,DTensor.from_local等)。
- class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source]¶
放置
Partial(reduce_op)描述了在指定DeviceMesh维度上待进行归约的 DTensor,其中每个 DeviceMesh 维度上的排名都持有全局张量的部分值。用户可以使用redistribute将PartialDTensor 重新分布到指定DeviceMesh维度上的Replicate或Shard(dim)放置,这会在后台触发必要的通信操作(即allreduce,reduce_scatter)。- Parameters
reduce_op (str, 可选) – 用于将部分DTensor生成复制/分片DTensor的约简操作。仅支持逐元素约简操作,包括:“sum”,“avg”,“product”,“max”,“min”,默认值为“sum”。
注意
The
Partial位置可以作为 DTensor 操作符的结果生成,并且只能由DTensor.from_localAPI 使用。
创建DTensor的不同方法¶
- There’re three ways to construct a
DTensor: distribute_tensor()创建了一个DTensor,该对象来自每个进程中的逻辑或“全局”torch.Tensor。这可以用于拆分叶子torch.Tensor(即模型参数/缓冲区和输入)。DTensor.from_local()创建了一个DTensor,它基于每个进程的本地torch.Tensor,可以用来创建DTensor,这些torch.Tensor是非叶子张量(即前向/后向传播中的中间激活张量)。DTensor 提供了专门的张量工厂函数(例如
empty(),ones(),randn()等) 以允许通过直接指定DeviceMesh和Placement来创建不同的DTensor。与distribute_tensor()相比,这可以直接在设备上实现分片内存,而不是在初始化逻辑张量内存后进行分片。
从逻辑 torch.Tensor 创建 DTensor¶
SPMD(单程序,多数据)编程模型在torch.distributed中启动多个进程
(即通过torchrun)来执行相同的程序,这意味着程序中的模型会首先在不同的进程中初始化
(即模型可能会在CPU上初始化,或者在元设备上初始化,或者如果内存足够的话直接在GPU上初始化)。
DTensor 提供了一个distribute_tensor() API,可以将模型权重或张量分片到DTensor中,
在每个进程中都会从“逻辑”张量创建一个DTensor。这将使生成的DTensor符合单设备语义,这对于数值正确性至关重要。
- torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None)¶
将一个叶子
torch.Tensor(即 nn.Parameter/缓冲区)根据指定的device_mesh按照placements进行分发。device_mesh和placements的秩必须相同。要分发的tensor是逻辑或“全局”张量,API 将使用 DeviceMesh 维度中第一个秩的tensor作为真实来源以保持单设备语义。如果要在 Autograd 计算中间构造 DTensor,请改用DTensor.from_local()。- Parameters
张量 (torch.Tensor) – 要分发的torch.Tensor。请注意,如果您想在一个维度上分割张量,而该维度不能被该网格维度中的设备数量整除,我们将使用
torch.chunk语义来分割张量并分散碎片。这种不均匀的分割行为是实验性的,并且可能会发生变化。设备网格 (
DeviceMesh,可选) – 分布张量的设备网格,如果没有指定,则必须在设备网格上下文管理器下调用,默认值:Noneplacements (List[
Placement], 可选) – 描述如何将张量放置在 DeviceMesh 上的排列方式,必须与device_mesh.ndim具有相同数量的元素。如果未指定,默认将在device_mesh的第一个维度上复制张量,从 device_mesh 的每个维度的第一个排名开始。
- Returns
A
DTensororXLAShardedTensorobject.- Return type
注意
当使用
xla个设备类型初始化 DeviceMesh 时,distribute_tensor返回 XLAShardedTensor。有关更多详细信息,请参阅此问题。XLA 集成是实验性的,并可能会发生变化。
与distribute_tensor()一起,DTensor 还提供了distribute_module() API,以便在nn.Module级别上更轻松地进行分片
- torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)¶
此功能提供了三个用于控制模块的参数/输入/输出的函数:
1. 在运行时执行之前通过指定
partition_fn(即允许用户根据指定的partition_fn将模块参数转换为DTensor参数)。 2. 在运行时执行期间通过指定input_fn和output_fn来控制模块的输入或输出。(即将输入转换为DTensor,将输出转换回torch.Tensor)- Parameters
模块 (
nn.Module) – 用户模块,将要分区。设备网格 (
DeviceMesh) – 放置模块的设备网格。partition_fn (Callable) – 划分参数的函数(即在
device_mesh中划分某些参数)。如果未指定partition_fn,默认情况下会在网格中复制module的所有模块参数。input_fn (Callable) – 指定输入分布,即可以控制模块的输入如何分片。
input_fn将被安装为模块forward_pre_hook(预前向钩子)。output_fn (可调用对象) – 指定输出分布,即可以控制输出如何分片,或将其转换回torch.Tensor。
output_fn将被安装为模块forward_hook(后向钩子之后)。
- Returns
一个包含所有参数/缓冲区为
DTensor的模块。- Return type
注意
当使用
xla设备类型初始化 DeviceMesh 时,此问题 。XLA 集成是实验性的,可能会发生变化。
张量工厂函数¶
DTensor 还提供了专用的张量工厂函数,允许通过使用类似 torch.Tensor 的工厂函数 API(例如 torch.ones、torch.empty 等),直接创建 DTensor,并通过额外指定 DeviceMesh 和 Placement 来为创建的 DTensor 设置:
- torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)¶
返回一个用标量值 0 填充的
DTensor。- Parameters
size (int...) – 定义输出形状的整数序列。 可以是任意数量的参数,也可以是列表或元组等集合。 例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))
- Keyword Arguments
requires_grad (bool, 可选) – 如果自动求梯度应记录返回的
DTensor上的操作。默认值:False。数据类型 (
torch.dtype,可选) – 返回的数据类型。 默认值:如果为None,使用全局默认设置(参见torch.set_default_dtype())。布局 (
torch.layout,可选) – 返回的DTensor的期望布局。 默认值:torch.strided。设备网格 –
DeviceMesh类型,包含排名的网格信息位置 – 一个由
Placement类型组成的序列:Shard,Replicate
- Returns
一个
DTensor对象在每个排名上- Return type
- torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)¶
返回一个用标量值 1 填充的
DTensor,其形状由变量参数size定义。- Parameters
size (int...) – 定义输出形状的整数序列。 可以是任意数量的参数,也可以是列表或元组等集合。 例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))
- Keyword Arguments
数据类型 (
torch.dtype,可选) – 返回的数据类型。 默认值:如果为None,使用全局默认设置(参见torch.set_default_dtype())。布局 (
torch.layout,可选) – 返回的DTensor所需的布局。 默认:torch.strided。requires_grad (bool, 可选) – 如果自动求梯度应记录返回的
DTensor上的操作。默认值:False。设备网格 –
DeviceMesh类型,包含排名的网格信息位置 – 一个由
Placement类型组成的序列:Shard,Replicate
- Returns
一个
DTensor对象在每个排名上- Return type
- torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)¶
返回一个用未初始化数据填充的
DTensor。其形状由变量参数size定义。- Parameters
size (int...) – 定义输出形状的整数序列。 可以是任意数量的参数,也可以是列表或元组等集合。 例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))
- Keyword Arguments
数据类型 (
torch.dtype,可选) – 返回的DTensor的数据类型。 默认值:如果为None,使用全局默认值(参见torch.set_default_dtype())。 布局 (torch.layout,可选):返回的DTensor的布局。 默认值:torch.strided。requires_grad (bool, 可选) – 如果自动求梯度应记录返回的
DTensor上的操作。默认值:False。设备网格 –
DeviceMesh类型,包含排名的网格信息位置 – 一个由
Placement类型组成的序列:Shard,Replicate
- Returns
一个
DTensor对象在每个排名上- Return type
- torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)¶
返回一个用
DTensor填充的张量,根据device_mesh和placements,其形状由参数size定义。- Parameters
size (int...) – 定义输出形状的整数序列。 可以是任意数量的参数,也可以是列表或元组等集合。 例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))
fill_value (标量) – 用于填充输出张量的值。
- Keyword Arguments
数据类型 (
torch.dtype,可选) – 返回的数据类型。 默认值:如果为None,使用全局默认设置(参见torch.set_default_dtype())。布局 (
torch.layout,可选) – 返回的DTensor所需的布局。 默认:torch.strided。requires_grad (bool, 可选) – 如果自动求梯度应记录返回的
DTensor上的操作。默认值:False。设备网格 –
DeviceMesh类型,包含排名的网格信息。位置 – 一个由
Placement类型组成的序列:Shard,Replicate
- Returns
一个
DTensor对象在每个排名上- Return type
- torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)¶
返回一个由均匀分布在区间
[0, 1)上的随机数填充的DTensor。张量的形状由可变参数size定义。- Parameters
size (int...) – 定义输出形状的整数序列。 可以是任意数量的参数,也可以是列表或元组等集合。 例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))
- Keyword Arguments
数据类型 (
torch.dtype,可选) – 返回的数据类型。 默认值:如果为None,使用全局默认设置(参见torch.set_default_dtype())。布局 (
torch.layout,可选) – 返回的DTensor所需的布局。 默认:torch.strided。requires_grad (bool, 可选) – 如果自动求梯度应记录返回的
DTensor上的操作。默认值:False。设备网格 –
DeviceMesh类型,包含排名的网格信息。位置 – 一个由
Placement类型组成的序列:Shard,Replicate
- Returns
一个
DTensor对象在每个排名上- Return type
- torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)¶
返回一个由均值为0、方差为1的正态分布随机数填充的
DTensor。张量的形状由变量 参数size定义。- Parameters
size (int...) – 定义输出形状的整数序列。 可以是任意数量的参数,也可以是列表或元组等集合。 例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))
- Keyword Arguments
数据类型 (
torch.dtype,可选) – 返回的数据类型。 默认值:如果为None,使用全局默认设置(参见torch.set_default_dtype())。布局 (
torch.layout,可选) – 返回的DTensor所需的布局。 默认:torch.strided。requires_grad (bool, 可选) – 如果自动求梯度应记录返回的
DTensor上的操作。默认值:False。设备网格 –
DeviceMesh类型,包含排名的网格信息。位置 – 一个由
Placement类型组成的序列:Shard,Replicate
- Returns
一个
DTensor对象在每个排名上- Return type
调试¶
日志记录¶
启动程序时,您可以使用TORCH_LOGS环境变量从torch._logging启用额外的日志记录:
TORCH_LOGS=+dtensor 将显示 logging.DEBUG 条消息及其以上的所有级别。
TORCH_LOGS=dtensor 将显示 logging.INFO 条及以上消息。
TORCH_LOGS=-dtensor 将显示 logging.WARNING 条及以上消息。
调试工具¶
为了调试应用了DTensor的程序,并更详细地了解底层发生了哪些集体操作,DTensor提供了一个CommDebugMode:
- class torch.distributed.tensor.debug.CommDebugMode¶
CommDebugMode是一个上下文管理器,用于计算其上下文中功能集体的数量。它使用一个TorchDispatchMode来实现。示例用法
mod = ... comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)[source]¶
生成详细表格,显示模块级别的操作和聚合跟踪信息。信息量取决于噪声级别。
打印模块级汇总计数
打印非简单操作的 dTensor 操作以及模块信息
打印操作不包括在简单操作中
打印所有操作
为了可视化维度小于3的DTensor的分片,DTensor提供了visualize_sharding():
- torch.distributed.tensor.debug.visualize_sharding(dtensor, header='')¶
在终端中可视化
DTensor的分片,这些分片是一维或二维的。注意
这需要
tabulate包。对于空张量不会打印分片信息。
实验功能¶
DTensor 还提供了一组实验性功能。这些功能要么处于原型阶段,要么基本功能已完成但正在寻求用户反馈。如果您对这些功能有任何意见,请提交给 PyTorch。
- torch.distributed.tensor.experimental.local_map(func, out_placements, in_placements=None, device_mesh=None, *, redistribute_inputs=False)¶
local_map()是一个实验性 API,允许用户将DTensor传递给一个函数,该函数被编写为应用于torch.Tensor。这是通过提取DTensor的本地组件,调用函数,并根据out_placements将输出包装成DTensor来实现的。- Parameters
函数 (可调用对象) – 要应用于每个本地分片的函数。
DTensors.out_placements (PlacementType 或 PlacementType 元组)– 指定在展平后的
DTensor中的位置。 如果展平后的output是单个值,那么out_placements应该是类型 PlacementType。 否则,如果展平后的output有多个值,那么out_placements应该是一个与展平后的output一一对应的 PlacementType 值元组。 此外,对于Tensor输出,我们使用 PlacementType 作为其位置(一个 Tuple[Placement] 值)。对于非张量输出,PlacementType 应该是 None。 需要注意的是,唯一的例外是在没有传递DTensor参数的情况下。在这种情况下,即使 out_placements 不是 None,结果函数也应该忽略指定的位置,因为函数不是在DTensor上运行的。in_placements (Tuple[PlacementType, …], 可选) – 指定输入张量中
DTensor的位置。 如果指定了in_placements,local_map()会检查每个DTensor参数的位置是否与所需位置一致。 如果不一致且redistribute_inputs为False,则会抛出异常。否则,如果redistribute_inputs为True, 参数将在传递其本地张量到func之前重新分布到所需的位置。 唯一例外是当所需位置不是None且参数是一个torch.Tensor时,在这种情况下,将跳过位置检查,并直接将参数传递给func。 如果in_placements为None,则不会进行位置检查。 默认值:无设备网格 (
DeviceMesh,可选) – 所有DTensor放置在其上的设备网格。如果没有指定,则根据输入的DTensor的设备网格进行推断。local_map要求每个DTensor都放置在相同的设备网格上。默认值:None。重新分配输入 (bool, 可选) – 表示是否在输入放置与所需输入放置不同时重新划分输入
DTensor的布尔值。如果该值为False,且某些DTensor输入的放置不同,则会引发异常。默认值:False。
- Returns
一个
Callable,它将func应用于输入DTensor的每个本地分片, 并返回由func的返回值构建的DTensor。- Raises
AssertionError – 如果输入
DTensor没有放置在相同的设备网格上,或者它们被放置在与传入的device_mesh参数不同的设备网格上。AssertionError – 对于任何非DTensor输出,我们要求其对应的输出位置在
out_placements为None。如果这不是这种情况,则会引发AssertionError。ValueError – 如果
redistribute_inputs=False,但输入DTensor需要根据in_placements进行重新分配。
示例
>>> def mm_allreduce_forward(device_mesh, W, X): >>> partial_sum_tensor = torch.mm(W, X) >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) >>> return reduced_tensor >>> >>> W = torch.randn(12, 8, requires_grad=False) >>> X = torch.randn(8, 16, requires_grad=False) >>> Y = torch.mm(W, X) >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], >>> in_placements=[col_wise, row_wise], >>> device_mesh=device_mesh, >>> ) >>> >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
注意
该 API 目前处于实验阶段,可能会发生变化。
- torch.distributed.tensor.experimental.register_sharding(op)¶
register_sharding()是一个实验性的 API,允许用户在张量输入和输出为 DTensor 时为操作注册分片策略。 它在以下情况下非常有用:(1) 对于op没有默认的分片策略,例如当op是一个不受DTensor支持的自定义操作符;(2) 当用户希望覆盖现有操作符的默认分片策略时。- Parameters
操作 (Union[OpOverload, List[OpOverload]]) – 指定要注册自定义分片函数的一个操作或操作列表。
- Returns
一个函数装饰器,可用于包装定义算子分片策略的函数,该算子在
op中指定。所定义的分片策略将被注册到DTensor,并且如果DTensor已经实现了该算子,则会覆盖默认的分片策略。自定义的分片函数接受与原始操作相同的输入(除了如果某个参数是一个torch.Tensor,它将被DTensor内部使用的张量对象替换)。该函数应返回一个2元组序列,每个2元组指定了可接受的输出位置及其对应的输入位置。
示例
>>> @register_sharding(aten._softmax.default) >>> def custom_softmax_sharding(x, dim, half_to_float): >>> softmax_dim = dim if dim >= 0 else dim + x.ndim >>> acceptable_shardings = [] >>> >>> all_replicate = ([Replicate()], [Replicate(), None, None]) >>> acceptable_shardings.append(all_replicate) >>> >>> for sharding_dim in range(x.ndim): >>> if sharding_dim != softmax_dim: >>> all_sharded = ( >>> [Shard(sharding_dim)], >>> [Shard(sharding_dim), None, None], >>> ) >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings
注意
该 API 目前处于实验阶段,可能会发生变化。