torch.distributed.tensor¶
注意
torch.distributed.tensor
当前处于 Alpha 状态及以下
开发,我们将为列出的大多数 API 提供向后兼容性
,但如有必要,可能会有 API 更改。
PyTorch DTensor(分布式张量)¶
PyTorch DTensor 提供简单灵活的张量分片基元,以透明方式处理分布式
逻辑,包括分片存储、算子计算和跨设备/主机的集体通信。 可用于构建不同的 paralleism 解决方案并支持分片 state_dict 表示
使用多维分片时。DTensor
请参阅基于 PyTorch 原生并行解决方案的示例:DTensor
遵循 SPMD(单程序多数据)编程模型,使用户能够
编写分布式程序,就好像它是具有相同 convergence 属性的单设备程序一样。它
通过指定 和 来提供统一的张量分片布局 (DTensor Layout):
DeviceMesh
Placement
DeviceMesh
表示设备拓扑和群集的通信器,使用 一个 n 维数组。Placement
描述 逻辑张量在 . DTensor 支持三种类型的放置:和 .DeviceMesh
Shard
Replicate
Partial
DTensor 类 API¶
是一个子类。这意味着一旦创建了 a
,它就可以
的使用方式与 非常相似,包括运行不同类型的 PyTorch 运算符,就像
在单个设备中运行它们,从而允许对 PyTorch 运算符进行适当的分布式计算。
torch.Tensor
torch.Tensor
除了现有的方法外,它还提供了一组额外的方法,用于交互,将 DTensor 布局到新的 DTensor,获取完整的张量内容
在所有设备上,等等。torch.Tensor
torch.Tensor
redistribute
- 类 torch.distributed.tensor 中。DTensor(local_tensor、spec、*、requires_grad)¶
DTensor
(Distributed Tensor) 是 的一个子类,它提供单设备类,如 抽象为具有多设备 的程序。它描述了分布式张量分片 布局 (DTensor Layout) 通过以下类型:torch.Tensor
torch.Tensor
DeviceMesh
Placement
Shard
:在该维度的设备上的 tensor 维度上分片的 Tensordim
DeviceMesh
Replicate
:在 dimension 的设备上复制的 TensorDeviceMesh
Partial
:Tensor 正在维度的设备上等待缩减DeviceMesh
调用 PyTorch 算子时,覆盖 PyTorch 算子进行分片计算,并发出 必要时的通信。除了运算符计算,还将转换或传播 placements (DTensor Layout) 正确(基于运算符语义本身)并生成新的输出。
DTensor
DTensor
DTensor
为了在调用 PyTorch 算子时确保分片计算的数值正确性,要求算子的每个 Tensor 参数都是 DTensor。
DTensor
DTensor
- 返回类型
- 属性device_mesh:DeviceMesh¶
与此 DTensor 对象关联的属性。
DeviceMesh
注意
device_mesh
是只读属性,则无法设置。
- static from_local(local_tensor, device_mesh=无, placements=无, *, run_check=False, shape=None, stride=None)[来源]¶
从本地 Torch 创建 A
。每个等级的 Tensor 根据 和 指定。
device_mesh
placements
- 参数
local_tensor(Torch。Tensor) – 本地Torch。每个 rank 上的 Tensor。
device_mesh ( 可选 ) – 用于放置 tensor(如果未指定),则必须在 DeviceMesh 下调用 上下文管理器,默认值:无
DeviceMesh
placements (List[], optional) – 放置 介绍如何放置本地Torch。DeviceMesh 上的 Tensor 必须 的元素数与 相同。
Placement
device_mesh.ndim
- 关键字参数
run_check (bool, optional) – 以额外的通信为代价,执行 跨 ranks 的健全性检查,以检查每个本地 Tensor 的元信息 以确保正确性。如果 在 中,则 将广播设备网格维度的第一个排名的数据 到其他级别。默认值:False
Replicate
placements
形状(Torch。Size, optional) – 一个 int 列表,它指定 DTensor 构建在 local_tensor 之上。请注意,这需要 如果 的形状在各个等级中不同,则提供。 如果未提供,则将假设给定的分布式 Tensor 在各个等级之间均匀分片。默认值:无
local_tensor
shape
stride (tuple, optional) - 一个 int 列表,用于指定 DTensor 的步幅。 如果未提供,则将假设给定的分布式 Tensor 在各个等级之间均匀分片。默认值:无
stride
- 返回
- 返回类型
注意
当 时,用户有责任确保 传入的 local tensor 在各个 rank 中是正确的(即 tensor 被分片为 放置 或 replicad)。 否则,创建的 DTensor 的行为是未定义的。
run_check=False
Shard(dim)
Replicate()
注意
from_local
是可微分的,则创建的 DTensor 对象的requires_grad将取决于local_tensor是否requires_grad。
- full_tensor(*, grad_placements=无)[来源]¶
返回此 DTensor 的完整张量。它将执行必要的集合 从其 DeviceMesh 中的其他 ranks 收集本地张量并连接 他们在一起。它是以下代码的合成糖:
dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()
- 关键字参数
grad_placements (List[], optional) – placements 描述 从此返回的完整 Tensor 的任何梯度布局的未来布局 功能。full_tensor 将 DTensor 转换为完整的 Torch。Tensor 和返回的 torch.tensor 稍后在代码中可能不会用作原始复制的 DTensor 布局。这 argument 是用户可以向 autograd 提供的提示,以防 gradient 返回的 Tensor 的布局与原始复制的 DTensor 布局不匹配。 如果未指定,我们将假设复制完整张量的梯度布局。
Placement
- 返回
- 返回类型
注意
full_tensor
是可微分的。
- 属性放置: Tuple[Placement, ...]¶
此 DTensor 的 placements 属性,用于描述此 DTensor 的布局 DTensor 在其 DeviceMesh 上。
注意
placements
是只读属性,则无法设置。
- redistribute(device_mesh=无, placements=无, *, async_op=False)[来源]¶
redistribute
执行必要的集合操作,以重新分配当前的 DTensor 从其当前位置移动到新的位置,或者 from 是当前 DeviceMesh 添加到新的 DeviceMesh 中。即,我们可以通过以下方式将分片的 DTensor 转换为复制的 DTensor 为 DeviceMesh 的每个维度指定 Replicate placement (复制位置)。在一个设备网格维度上从当前位置重新分布到新位置时,我们 将执行以下操作,包括通信集体或本地操作:
Shard(dim)
->Replicate()
:all_gather
Shard(src_dim)
->Shard(dst_dim)
:all_to_all
Replicate()
->Shard(dim)
:局部分块(即torch.chunk
)Partial()
->Replicate()
:all_reduce
Partial()
->Shard(dim)
:reduce_scatter
redistribute
将正确地找出 DTensor 的必要重新分发步骤 ,它们是在 1-D 或 N-D DeviceMesh 上创建的。- 参数
device_mesh ( 可选 ) – 用于放置 DTensor 的如果未指定,它将使用当前 DTensor 的 DeviceMesh。 默认值:无
DeviceMesh
placements (List[], optional) – 新的位置 介绍如何将 DTensor 放入 DeviceMesh 中,必须 的元素数与 相同。 default:在所有网格维度上复制
Placement
device_mesh.ndim
- 关键字参数
async_op (bool, optional) – 是否执行 DTensor redistribute 操作 异步或非异步。默认值:False
- 返回
- 返回类型
注意
redistribute
是可微的,这意味着用户无需担心 redistribute 操作的反向公式。注意
redistribute
目前只支持在同一个 DeviceMesh 上重分发 DTensor, 如果您需要将 DTensor 重新分发到不同的 DeviceMesh,请提交 issue。
- to_local(*, grad_placements=无)[来源]¶
获取此 DTensor 的当前秩的局部张量。对于分片,它返回 逻辑张量视图的本地分片,对于复制,它返回 它现在的等级。
- 关键字参数
grad_placements (List[], optional) – placements 描述 从此返回的 Tensor 的任何梯度布局的未来布局 功能。to_local将 DTensor 转换为本地张量,并返回本地张量 稍后在代码中可能不会用作原始 DTensor 布局。这 argument 是用户可以向 autograd 提供的提示,以防 gradient 返回的张量的布局与原始 DTensor 布局不匹配。 如果未指定,我们将假设渐变布局保持不变 作为原始 DTensor 进行分配,并将其用于梯度计算。
Placement
- 返回
A
或 object。它表示 local tensor 的 local 张量。当返回对象时, 这意味着本地张量尚未准备好(即通信尚未完成)。在这个 case,用户需要调用等待本地 Tensor 准备就绪。
AsyncCollectiveTensor
AsyncCollectiveTensor
wait
- 返回类型
注意
to_local
是微分的,则返回的局部张量 将取决于 DTensor 是否requires_grad。requires_grad
DeviceMesh 作为分布式通信器¶
是从 DTensor 构建的,作为描述集群设备拓扑的抽象,并表示
多维通信器(在 之上)。要查看如何创建/使用 DeviceMesh 的详细信息,
请参考 DeviceMesh 配方。
ProcessGroup
DTensor 放置类型¶
DTensor 在每个维度上支持以下类型:
DeviceMesh
- torch.distributed.tensor.placement_types 类。分片(dim)[来源]¶
位置描述了 DTensor 在相应维度上的张量维度上的分片,其中 DeviceMesh 维度仅包含全局 Tensor 的一个分片/块。放置遵循语义,其中 当 Tensor 维度 在 DeviceMesh 维度上不是均匀划分的。放置可以是 由所有 DTensor API(即 distribute_tensor、from_local 等)使用
Shard(dim)
dim
DeviceMesh
Shard(dim)
torch.chunk(dim)
Shard
- 参数
dim (int) – 描述 DTensor 的张量维度在其 相应的 DeviceMesh 维度。
警告
在 Tensor 维度大小不为 在 DeviceMesh 维度上均匀整除目前是试验性的,可能会发生变化。
- torch.distributed.tensor.placement_types 类。复制[源]¶
放置描述在相应维度上复制的 DTensor,其中 DeviceMesh 维度上的每个秩都包含一个 全局 Tensor 的副本。该位置可供所有人使用 DTensor API(即 、 等)
Replicate()
DeviceMesh
Replicate
distribute_tensor
DTensor.from_local
- torch.distributed.tensor.placement_types 类。partial(reduce_op='sum')[来源]¶
放置描述待处理的 DTensor reduction 的 Reduction 在指定维度上,其中 DeviceMesh 维度保存全局 Tensor 的部分值。用户可以 使用 将 DTensor 重新分配到指定维度上的 或 位置 。 这将在后台触发必要的通信操作(即 , )。
Partial(reduce_op)
DeviceMesh
Partial
Replicate
Shard(dim)
DeviceMesh
redistribute
allreduce
reduce_scatter
- 参数
reduce_op (str, optional) – 要用于部分 DTensor 的缩减运算 生成 Replicated/Sharded DTensor。仅元素级缩减操作 支持,包括:“sum”、“avg”、“product”、“max”、“min”,默认值:“sum”。
注意
放置可以作为 DTensor 运算符 并且只能由 API 使用。
Partial
DTensor.from_local
创建 DTensor 的不同方法¶
从逻辑 torch 创建 DTensor。张肌¶
启动多个流程中的 SPMD(单个程序,多个数据)编程模型
(即 via )来执行相同的程序,这意味着程序内部的模型将是
首先在不同的进程上初始化(即模型可能在 CPU 或 Meta 设备上初始化,或者直接在
如果内存足够,则在 GPU 上)。torch.distributed
torchrun
DTensor
提供了一个 API,可以将模型权重或 Tensor 分片为 s,
其中,它将从每个进程上的“逻辑”Tensor 创建一个 DTensor。这将使创建的 s 能够符合单个设备语义,这对于数字正确性至关重要。
DTensor
DTensor
- torch.distributed.tensor Tensor 中。distribute_tensor(tensor, device_mesh=无, placements=无)¶
分发一个叶子(即 nn.Parameter/buffers) 添加到相应的 到指定的。的秩 和 必须是 相同。to distribute 是逻辑张量或“全局”张量,API 将使用 从 DeviceMesh 维度的 first rank 开始作为要保留的真实来源 单设备语义。如果要在 Autograd 的中间构造一个 DTensor 计算,请改用
。
torch.Tensor
device_mesh
placements
device_mesh
placements
tensor
tensor
- 参数
张量 (Torch.Tensor) – Torch。要分布的 Tensor。请注意,如果你 想要在不能被 该网格维度中的设备数量,我们使用 Semantic 对 Tensor 进行分片并分散 Shard。分片不均匀 行为是实验性的,可能会发生变化。
torch.chunk
device_mesh ( 可选 ) – 用于分发 tensor(如果未指定),则必须在 DeviceMesh 上下文下调用 manager,默认值:None
DeviceMesh
placements (List[], optional) – 放置 介绍如何将 Tensor 放置在 DeviceMesh 上,必须具有相同的 元素数为 。如果未指定,我们将 默认情况下,从 device_mesh 的每个维度的 first rank。
Placement
device_mesh.ndim
device_mesh
- 返回
- 返回类型
注意
使用 device_type 初始化 DeviceMesh 时,请改为返回 XLAShardedTensor。有关更多详细信息,请参阅此问题。XLA 集成是实验性的,可能会发生变化。
xla
distribute_tensor
除了 ,DTensor 还提供了一个
API 来允许更轻松地
级别上的分片
nn.Module
- torch.distributed.tensor Tensor 中。distribute_module(module, device_mesh=无, partition_fn=无, input_fn=无, output_fn=无)¶
此函数公开了三个函数来控制模块的参数/输入/输出:
1. 在运行时执行之前对 Module 进行分片,通过指定 (即允许用户根据指定的partition_fn将 Module 参数转换为
参数)。 2. 在运行时执行期间控制模块的输入或输出 指定 和 .(即 将 input 转换为
,将 output 转换回
partition_fn
input_fn
output_fn
torch.Tensor
)- 参数
module () – 要分区的用户模块。
nn.Module
device_mesh () – 用于放置模块的设备网格。
DeviceMesh
partition_fn (Callable) – 用于对参数进行分区的函数(即分片确定) 参数跨 )。如果未指定,则 默认情况下,我们复制整个网格的所有模块参数。
device_mesh
partition_fn
module
input_fn (Callable) – 指定输入分布,即可以控制 input 的 input 被分片。 将作为一个模块安装(pre forward hook)。
input_fn
forward_pre_hook
output_fn (Callable) – 指定输出分布,即可以控制 output 被分片,或将其转换回 Torch。张肌。 将是 作为模块安装(POST FORWARD 钩子)。
output_fn
forward_hook
- 返回
一个模块,其中包含所有 s.
DTensor
- 返回类型
注意
使用 device_type 初始化 DeviceMesh 时,返回 nn.具有 PyTorch/XLA SPMD 注释参数的模块。有关更多详细信息,请参阅此问题。XLA 集成是实验性的,可能会发生变化。
xla
distribute_module
DTensor 工厂函数¶
DTensor 还提供了专用的 Tensor Factory 函数,允许直接创建
使用 Torch。类似工厂函数 API 的张量(即 torch.ones、torch.empty 等),另外由
为创建的指定 and
:
DeviceMesh
Placement
- torch.distributed.tensor Tensor 中。零(*大小, requires_grad=False, dtype=无, layout=torch.strided, device_mesh=无,placements=None)¶
返回标量值为 0 的 fill。
- torch.distributed.tensor Tensor 中。一(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=无,placements=None)¶
返回标量值为 1 的 fill,其形状已定义
通过变量参数 .
size
- torch.distributed.tensor Tensor 中。empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=无,placements=None)¶
返回
一个 filled 的未初始化数据。的形状
由变量 argument 定义。
size
- 参数
size (int...) – 定义输出
形状的整数序列。 可以是可变数量的参数,也可以是列表或元组等集合。 例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))
- 关键字参数
dtype (
, optional) – 返回
的所需数据类型。 默认值:如果 ,则使用全局默认值(请参阅
)。layout (
可选 ):返回
的所需布局。 违约:。
None
torch.strided
requires_grad (bool, optional) – 如果 autograd 应该记录对 返回
。违约:。
False
device_mesh – 类型,包含等级的网格信息
DeviceMesh
placements – 类型为 ,
Placement
Shard
Replicate
- 返回
- 返回类型
- torch.distributed.tensor Tensor 中。full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=无,placements=无)¶
返回一个
填充的 根据 和 ,其形状由参数 定义。
fill_value
device_mesh
placements
size
- 参数
size (int...) – 定义输出
形状的整数序列。 可以是可变数量的参数,也可以是列表或元组等集合。 例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))
fill_value (Scalar) – 用于填充输出张量的值。
- 关键字参数
requires_grad (bool, optional) – 如果 autograd 应该记录对 返回
。违约:。
False
device_mesh – 类型,包含秩的网格信息。
DeviceMesh
placements – 类型为 ,
Placement
Shard
Replicate
- 返回
- 返回类型
- torch.distributed.tensor Tensor 中。rand(*size, requires_grad=False, dtype=无, layout=torch.strided, device_mesh=无,placements=None)¶
返回来自均匀分布的填充随机数
在区间 上。张量的形状由变量 论点。
[0, 1)
size
- torch.distributed.tensor Tensor 中。randn(*size, requires_grad=False, dtype=无, layout=torch.strided, device_mesh=无,placements=None)¶
返回一个由正态分布中的随机数填充的
均值为 0,方差为 1。张量的形状由变量 论点。
size
调试¶
伐木¶
启动程序时,您可以使用 torch._logging 中的 TORCH_LOGS 环境变量打开其他日志记录:
TORCH_LOGS=+dtensor 将显示 logging。DEBUG 消息及其上面的所有级别。
TORCH_LOGS=dtensor 将显示 logging.INFO 及以上的消息。
TORCH_LOGS=-dtensor 将显示日志记录。WARNING 消息及以上。
调试工具¶
调试应用 DTensor 的程序,并了解有关在
hood 中,DTensor 提供了一个 :
- 类 torch.distributed.tensor.debug 中。CommDebugMode (共调试模式)¶
是一个上下文管理器,它计算 函数式集合。它使用 .
TorchDispatchMode
用法示例
mod = ... comm_mode = CommDebugMode() with comm_mode: mod.sum().backward() print(comm_mode.get_comm_counts())
- generate_comm_debug_tracing_table(noise_level=3)[来源]¶
生成显示操作和集合跟踪信息的详细表 在模块级别。信息量取决于noise_level
打印模块级集体计数
打印 dTensor 操作 intrivial operations 中不包含的 dTensor 操作、模块信息
打印 Importas Operations 中未包含的操作
打印所有操作
为了可视化小于 3 个维度的 DTensor 的分片,DTensor 提供了:
- torch.distributed.tensor.debug 的 Debug。visualize_sharding(dtensor, header='')¶
在终端中可视化 1D 或 2D 的分片。
DTensor
注意
这需要 package。不会为空张量打印分片信息
tabulate
实验性功能¶
DTensor
还提供了一组实验性功能。这些功能要么处于原型设计阶段,要么处于基本
功能已完成,但正在寻找用户反馈。如果您有反馈,请将问题提交到 PyTorch
这些功能。
- torch.distributed.tensor.experimental 的local_map(func, out_placements, in_placements=无, device_mesh=无, *, redistribute_inputs=False)¶
是一个实验性 API,允许用户将 s 添加到写入以应用于 S 的函数。它是通过提取 的局部组件 调用 函数,并根据 将输出包装为 。
DTensor
torch.Tensor
DTensor
DTensor
out_placements
- 参数
func (Callable) – 要应用于 s 的每个本地分片的函数。
DTensor
out_placements (Union[PlacementType, Tuple[PlacementType, ...]]) – 拼合输出中 s 的所需位置。 如果 flattened 是单个值,则 应为 的类型 PlacementType。否则,如果平展的 API 具有多个 值,则应为 PlacementType 值 1:1 的元组 映射到展平的 . 此外,对于输出,我们使用 PlacementType 作为其 placements(一个 Tuple[Placement] 值)。对于非 Tensor 输出,PlacementType 应为 None。 请注意,唯一的例外是未传递参数时 在。在这种情况下,即使 out_placements 不是 None,结果函数 应忽略所需的放置位置,因为该函数未与 S 一起运行。
DTensor
func
output
out_placements
output
out_placements
output
Tensor
DTensor
DTensor
in_placements (Tuple[PlacementType, ...], optional) – s 在 的拼合输入中的所需位置。 如果指定,
将检查 每个参数的 placements 与所需的 placements 与否。如果位置不同,并且 是 ,则会引发异常。否则 if is ,则参数将首先重新分配给 在将其本地张量传递给 之前所需的分片放置位置。 唯一的例外是,当 required placements 不是并且 参数是
.在这种情况下,安置考试 将跳过,参数将直接传递给 . 如果是,则不会进行安置检查。 默认值:无
DTensor
func
in_placements
DTensor
redistribute_inputs
False
redistribute_inputs
True
func
None
func
in_placements
None
device_mesh (,可选) – 放置所有 S 的设备网格。如果不是 specified,这将从输入 S 的设备推断出来 网孔。local_map 要求将每个 s 都放在同一个 device mesh 的默认值:None。
DeviceMesh
DTensor
DTensor
DTensor
redistribute_inputs (bool, optional) – 布尔值,指示何时重新分片输入 s 它们的放置与必需的 Input 放置不同。如果此 value 是,并且某些输入具有不同的位置, 将引发异常。默认值:False。
DTensor
False
DTensor
- 返回
A 应用于输入的每个本地分片,并返回根据 的返回值构建的 。
Callable
func
DTensor
DTensor
func
- 提高
AssertionError – 如果输入未放置在同一设备上 mesh 的 net 协议,或者如果它们被放置在与传入的参数不同的设备 mesh 上。
DTensor
device_mesh
AssertionError – 对于任何非 DTensor 输出,我们需要其相应的 output placement 在 be None 中。将引发 AssertionError 如果不是这种情况。
out_placements
ValueError – 如果输入需要 根据 进行重新分配。
redistribute_inputs=False
DTensor
in_placements
例
>>> def mm_allreduce_forward(device_mesh, W, X): >>> partial_sum_tensor = torch.mm(W, X) >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) >>> return reduced_tensor >>> >>> W = torch.randn(12, 8, requires_grad=False) >>> X = torch.randn(8, 16, requires_grad=False) >>> Y = torch.mm(W, X) >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], >>> in_placements=[col_wise, row_wise], >>> device_mesh=device_mesh, >>> ) >>> >>> W_dt = distribute_tensor(W, device_mesh, (col_wise)) # col-wisely sharded W tensor >>> X_dt = distribute_tensor(X, device_mesh, (row_wise)) # row-wisely sharded X tensor >>> Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) # apply local_mm_allreduce_forward to DTensors
注意
此 API 目前处于试验阶段,可能会发生更改
- torch.distributed.tensor.experimental 的register_sharding(op)¶
是一个实验性 API,允许用户注册分片 当张量输入和输出为 DTensor 时运算符的策略。 它在以下情况下可能很有用:(1) 不存在 的默认分片策略 , 例如,when 是 不支持的自定义运算符;(2) 当用户想要覆盖现有算子的默认分片策略时。
op
op
DTensor
- 参数
op (Union[OpOverload, List[OpOverload]]) – 用于注册自定义分片函数的操作或操作列表。
- 返回
一个函数装饰器,可用于包装定义分片的函数 策略。定义的分片策略将为 注册到 DTensor,如果 DTensor 具有 已经实现了 Operator。自定义分片函数采用相同的输入 作为原始 op (不同之处在于,如果 arg 是 ,
它将是 替换为 DTensor 内部使用的类似 Tensor 的对象)。该函数应 返回一个 2 元组序列,每个元组指定可接受的输出位置及其 相应的 intput placements。
op
例
>>> @register_sharding(aten._softmax.default) >>> def custom_softmax_sharding(x, dim, half_to_float): >>> softmax_dim = dim if dim >= 0 else dim + x.ndim >>> acceptable_shardings = [] >>> >>> all_replicate = ([Replicate()], [Replicate(), None, None]) >>> acceptable_shardings.append(all_replicate) >>> >>> for sharding_dim in range(x.ndim): >>> if sharding_dim != softmax_dim: >>> all_sharded = ( >>> [Shard(sharding_dim)], >>> [Shard(sharding_dim), None, None], >>> ) >>> acceptable_shardings.append(all_sharded) >>> >>> return acceptable_shardings
注意
此 API 目前处于试验阶段,可能会发生更改