目录

PyTorch在XLA设备上的使用

PyTorch可以在XLA设备(如TPU)上运行,使用 torch_xla包。本文档描述了如何在这些设备上运行您的模型。

创建XLA张量

PyTorch/XLA添加了一个新的xla设备类型到PyTorch。这个设备类型就像其他PyTorch设备类型一样工作。例如,这里是如何创建和打印一个XLA张量的方式:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

这段代码应该很熟悉。PyTorch/XLA 使用与普通 PyTorch 相同的接口,但有一些添加。导入 torch_xla 初始化 PyTorch/XLA,而 xm.xla_device() 返回当前的 XLA 设备。这可能是一个 CPU 或 TPU,取决于您的环境。

XLA张量是PyTorch张量

PyTorch操作可以在XLA张量上执行,就像在CPU或CUDA张量上一样。

例如,XLA张量可以相加:

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

或者矩阵相乘:

print(t0.mm(t1))

或者与神经网络模块一起使用:

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

和其他设备类型一样,XLA张量只能在同一设备上的其他XLA张量上工作。因此,代码如

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

将抛出错误,因为torch.nn.Linear模块在CPU上。

在XLA设备上运行模型

构建一个新的PyTorch网络或将其转换为在XLA设备上运行,只需要几行特定于XLA的代码。以下示例代码展示了这些代码在单个设备和多个设备上的XLA多进程运行时的具体内容。

在单个XLA设备上运行

以下代码片段展示了在一个单个XLA设备上进行网络训练:

import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  optimizer.step()
  xm.mark_step()

此代码片段展示了如何轻松地将模型切换到在XLA上运行。模型定义、数据加载器、优化器和训练循环可以在任何设备上工作。唯一的XLA特定代码是几行获取XLA设备并标记步骤的代码。在每个训练迭代结束时调用 xm.mark_step() 会导致XLA执行当前的图并更新模型的参数。有关XLA如何创建图并运行操作的更多信息,请参阅XLA张量深度解析

在多个XLA设备上运行并支持多进程

PyTorch/XLA 使训练加速变得简单,通过在多个 XLA 设备上运行。以下示例展示了如何实现:

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())

这个多设备片段与之前的单设备片段有三个不同之处。让我们逐一进行说明。

  • torch_xla.launch()

    • 创建每个运行XLA设备的进程。

    • 这个函数是一个多线程spawn的包装器,允许用户使用torchrun命令行运行脚本。每个进程只能访问当前进程分配的设备。例如,在一个TPU v4-8上,将启动4个进程,每个进程将拥有一个TPU设备。

    • 请注意,如果你在每个进程中打印 xm.xla_device(),你将在所有设备上看到 xla:0。这是因为每个进程只能看到一个设备。这并不意味着多进程没有正常工作。只有在 TPU v2 和 TPU v3 上使用 PJRT 运行时进行执行,因为会有 #devices/2 个进程,每个进程将有 2 个线程(更多详情请查看此 文档)。

  • MpDeviceLoader

    • 将训练数据加载到每个设备上。

    • MpDeviceLoader 可以在 PyTorch 数据加载器上进行包装。它可以预加载数据到设备并重叠数据加载与设备执行以提高性能。

    • MpDeviceLoader 也会在每 batches_per_execution 批(默认为 1)生成时调用一次 xm.mark_step

  • xm.optimizer_step(optimizer)

    • 将梯度在设备间汇总,并发出XLA设备步骤计算。

    • 这是差不多一个 all_reduce_gradients + optimizer.step() + mark_step 并返回损失减少。

模型定义、优化器定义和训练循环保持不变。

NOTE: It is important to note that, when using multi-processing, the user can start retrieving and accessing XLA devices only from within the target function of torch_xla.launch() (or any function which has torch_xla.launch() as parent in the call stack).

查看 完整的多进程示例 以了解更多关于如何在多个XLA设备上使用多进程训练网络的信息。

在TPU Pods上运行

多主机设置对于不同的加速器来说可能会有很大的不同。本文将讨论多主机训练的设备无关部分,并将以TPU + PJRT运行时(目前在1.13和2.x版本中可用)为例进行说明。

在开始之前,请查看我们的用户指南 这里,其中将解释一些Google Cloud的基础知识,例如如何使用 gcloud 命令以及如何设置您的项目。您还可以查看 这里 了解所有Cloud TPU的使用方法。本文档将专注于PyTorch/XLA的设置视角。

假设你有一个来自上述部分的mnist示例在train_mnist_xla.py中。如果是单主机多设备训练,你会通过SSH连接到TPUVM并运行类似以下命令

PJRT_DEVICE=TPU python3 train_mnist_xla.py

现在为了在TPU v4-16(每个主机有4个TPU设备)上运行相同的模型,您将需要

  • 确保每个主机都能访问训练脚本和训练数据。这通常通过使用gcloud scp命令或gcloud ssh命令将训练脚本复制到所有主机来完成。

  • 在同一时间在所有主机上运行相同的训练命令。

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"

gcloud ssh 命令之上,将 SSH 到 TPUVM Pod 中的所有主机,并同时运行相同的命令。

NOTE: You need to run run above gcloud command outside of the TPUVM vm.

模型代码和训练脚本对于多进程训练和多主机训练是相同的。PyTorch/XLA 和底层基础设施将确保每个设备了解全局拓扑以及每个设备的本地和全局顺序号。跨设备通信将在所有设备之间发生,而不是在本地设备之间。

有关PJRT运行时的更多详细信息以及如何在pod上运行它,请参阅此 文档。有关PyTorch/XLA和TPU pod的更多信息,以及在TPU pod上使用假数据运行resnet50的完整指南,请参阅此 指南

XLA张量深度探索

使用XLA张量和设备只需要更改几行代码。但是,尽管XLA张量的行为与CPU和CUDA张量非常相似,它们的内部却有所不同。本节描述了XLA张量的独特之处。

XLA张量是惰性的

CPU 和 CUDA 张量立即执行操作或急切地执行。而 XLA 张量, 则是惰性的。它们会记录操作到一个图中,直到需要结果时才执行。像这样延迟执行让 XLA 能够优化它。多个独立操作的图可能会融合成一个优化的操作, 例如。

懒执行对调用者通常是不可见的。PyTorch/XLA 会自动构建图,将其发送到 XLA 设备,并在将数据从 XLA 设备复制到 CPU 时进行同步。在采取优化器步骤时插入一个屏障可以显式地同步 CPU 和 XLA 设备。有关我们的懒张量设计的更多信息,请阅读 这篇论文

内存布局

XLA张量的内部数据表示对用户来说是透明的。它们不暴露其存储方式,总是看起来是连续的,不像CPU和CUDA张量。这使得XLA能够调整张量的内存布局以获得更好的性能。

将XLA张量移动到和从CPU之间

XLA张量可以从CPU移动到XLA设备,也可以从XLA设备移动到CPU。如果视图被移动,则其查看的数据也会被复制到另一个设备上,并且视图关系不会被保留。换句话说,一旦数据被复制到另一个设备上,它就不再与之前的设备或设备上的任何张量有任何关系。再次强调,根据你的代码如何操作,理解和适应这种转换可能很重要。

保存和加载XLA张量

XLA张量在保存之前应移动到CPU,如下所示的示例代码:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

这让你可以把加载的张量放在任何可用的设备上,而不仅仅是它们初始化的那个设备。

根据上述关于将XLA张量移动到CPU的说明,当处理视图时必须小心。建议在加载并移动张量到目标设备后重新创建它们,而不是保存视图。

提供了一个实用的API,可以保存数据并处理之前将其移动到CPU的操作。

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xm.save(model.state_dict(), path)

在多设备的情况下,上述API只会保存主设备的序号(0)的数据。

如果内存比模型参数的大小有限,提供了一个API来减少主机上的内存占用:

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

这个API逐个将XLA张量流式传输到CPU,减少主机内存的使用,但它需要一个匹配的加载API来恢复:

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)

直接保存XLA张量是可能的,但不推荐这样做。XLA张量总是从它们被保存的设备加载回来,如果那个设备不可用,加载将会失败。PyTorch/XLA,就像所有PyTorch一样,正在积极开发中,这种行为在未来可能会改变。

编译缓存

XLA 编译器将追踪的 HLO 转换为可以在设备上运行的可执行文件。编译过程可能耗时,而在执行过程中 HLO 不发生变化的情况下,编译结果可以持久保存到磁盘以供重复使用,从而显著减少开发迭代时间。

注意,如果HLO在执行之间发生变化,仍然会发生重新编译。

这是一个目前的实验性自选API,必须在执行任何计算之前激活。初始化通过initialize_cache API完成:

import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)

这将初始化指定路径下的持久编译缓存。参数readonly可以用来控制工作线程是否能够写入缓存,这对于使用共享缓存挂载的SPMD工作负载非常有用。

如果您想在多进程训练中使用持久编译缓存(带torch_xla.launchxmp.spawn),您应该为不同的进程使用不同的路径。

def _mp_fn(index):
  # cache init needs to happens inside the mp_fn.
  xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
  ....

if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())

如果您没有访问权限 index,您可以使用 xr.global_ordinal()。查看 这里 的可运行示例。

进一步阅读

有关PyTorch/XLA仓库的更多文档,请访问 PyTorch/XLA仓库。在TPU上运行网络的更多示例可在 此处找到。

PyTorch/XLA API

torch_xla

torch_xla.device(index: Optional[int] = None) device[source]

返回给定的XLA设备实例。

如果 SPMD 启用,返回一个虚拟设备,该设备包裹了此进程可用的所有设备。

Parameters

index – 返回的XLA设备索引。对应于 torch_xla.devices()

Returns

一个XLA torch.device

torch_xla.devices() List[device][source]

返回当前进程中的所有可用设备。

Returns

XLA torch.devices 的列表。

torch_xla.device_count() int[source]

返回当前进程中的可访问设备数量。

torch_xla.sync(wait: bool = False)[source]

启动所有待处理的图操作。

Parameters

等待 (布尔值) – 是否阻塞当前进程,直到执行完成。

torch_xla.compile(f: Optional[Callable] = None, full_graph: Optional[bool] = False, name: Optional[str] = None, num_different_graphs_allowed: Optional[int] = None)[source]

使用torch_xla的LazyTensor追踪模式优化给定的模型/函数。 PyTorch/XLA将根据给定的输入对给定的功能进行追踪,并生成表示该功能中执行的PyTorch操作的图。此图将由XLA编译并在加速器上执行(取决于张量的设备)。对于函数中编译的部分,Eager模式将被禁用。

Parameters
  • 模型 (Callable) – 模块/函数用于优化,如果不传递此函数将作为上下文管理器。

  • full_graph (Optional[bool]) – 是否生成单个图。如果设置为True 并且将生成多个图,torch_xla 将抛出带有调试信息的错误并退出。

  • 名称 (可选[名称]) – 编译程序的名称。如果没有指定,则将使用函数 f 的名称。此名称也将用于消息以及 HLO/IR dump 文件。

  • num不同图允许 (Optional[python:int]) – 数目给定模型/函数的可追踪图的不同数 量。如果超过这个限制,将抛出错误。

Example:

# usage 1
@torch_xla.compile()
def foo(x):
  return torch.sin(x) + torch.cos(x)

def foo2(x):
  return torch.sin(x) + torch.cos(x)
# usage 2
compiled_foo2 = torch_xla.compile(foo2)

# usage 3
with torch_xla.compile():
  res = foo2(x)
torch_xla.manual_seed(seed, device=None)[source]

为当前的XLA设备生成随机数设置种子。

Parameters
  • 种子 (python:整数) – 要设置的状态。

  • 设备 (torch.device, 可选) – 需要设置随机数生成器状态的设备。 如果未指定,则将默认设置为设备种子。

运行时

torch_xla.runtime.device_type() Optional[str][source]

返回当前的PyTorch设备类型。

如果未配置默认设备,则选择一个默认设备。

Returns

设备的字符串表示。

torch_xla.runtime.local_process_count() int[source]

返回此主机上运行的进程数量。

torch_xla.runtime.local_device_count() int[source]

返回该主机上的设备总数。

假设每个过程都有相同数量的可访问设备。

torch_xla.runtime.addressable_device_count() int[source]

返回该进程可见的设备数量。

torch_xla.runtime.global_device_count() int[source]

返回所有进程/主机中的设备总数。

torch_xla.runtime.global_runtime_device_count() int[source]

返回所有进程/主机中运行设备的总数,特别是在SPMD中特别有用。

torch_xla.runtime.world_size() int[source]

返回参与工作的进程总数。

torch_xla.runtime.global_ordinal() int[source]

返回此线程在所有进程中的全局顺序号。

全局序号在范围 [0, 全球设备计数) 内。全局序号不保证与 TPU 工作器 ID 有任何可预测的关系,也不保证在每个主机上是连续的。

torch_xla.runtime.local_ordinal() int[source]

返回此线程在本主机中的本地序号。

本地顺序在范围 [0, 当前设备数量)。

torch_xla.runtime.get_master_ip() str[source]

获取运行时的主工作器IP。这会调用后端特定的发现API。

Returns

主工作器的IP地址作为字符串。

torch_xla.runtime.use_spmd(auto: Optional[bool] = False)[source]

启用SPMD模式的API。这是推荐的启用SPMD的方式。

这会强制SPMD模式,如果某些张量已经在非SPMD设备上初始化。这意味着这些张量会在设备之间进行复制。

Parameters

自动 (bool) – 是否启用自动分片。请参阅 https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding 以获取更多详细信息

torch_xla.runtime.is_spmd()[source]

返回是否设置了SPMD执行。

torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)[source]

初始化持久编译缓存。此API必须在执行任何计算之前调用。

Parameters
  • 路径 (字符串) – 存储持久缓存的路径。

  • readonly (bool) – 是否允许这个工人对缓存有写入权限。

xla_model

torch_xla.core.xla_model.xla_device(n: Optional[int] = None, devkind: Optional[str] = None) device[source]

返回给定的XLA设备实例。

Parameters
  • n (python:int, optional) – 返回的特定实例(序号)。如果指定,则返回特定的XLA设备实例。否则,将返回 devkind 的第一个设备。

  • devkind (string..., optional) – 如果指定,设备类型如 TPU, CUDA, CPU, 或自定义 PJRT 设备。已废弃。

Returns

一个具有所需实例的torch.device

torch_xla.core.xla_model.xla_device_hw(device: Union[str, device]) str[source]

返回给定设备的硬件类型。

Parameters

设备 (字符串torch.device) – 将映射到的真实设备的xla设备。

Returns

给定设备的硬件类型的字符串表示。

torch_xla.core.xla_model.is_master_ordinal(local: bool = True) bool[source]

检查当前进程是否为主序号(0)。

Parameters

本地 (bool) – 是否检查本地或全局主序号。 在多主机复制的情况下,只有一个全局主序号(主机0,设备0),而有NUM_HOSTS个本地主序号。 默认值:True

Returns

一个布尔值,表示当前进程是否为主序号。

torch_xla.core.xla_model.all_reduce(reduce_type: str, inputs: Union[Tensor, List[Tensor]], scale: float = 1.0, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Union[Tensor, List[Tensor]][source]

对输入张量(s)执行 inplace reduce 操作。

Parameters
  • reduce_type (字符串) – 一个在xm.REDUCE_SUM, xm.REDUCE_MUL, xm.REDUCE_AND, xm.REDUCE_OR, xm.REDUCE_MINxm.REDUCE_MAX中的值。

  • 输入 – 或者一个单个 torch.Tensor 或者一个列表中的 torch.Tensor 来执行所有减少操作。

  • 缩放 (python:float) – 一个默认的缩放值,在减少后应用。 默认:1.0

  • 群组 (列表, 可选) –

    一个列表的列表,表示操作 all_reduce() 的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – 是否将布局固定为该通信操作。 布局固定可以防止当每个参与通信的进程具有略有不同的程序时,潜在的数据损坏。但是,它可能会导致某些xla编译失败。在看到类似“HloModule有混合布局约束”的错误消息时,解绑布局。

Returns

如果传递一个单个 torch.Tensor,返回值是一个 torch.Tensor,其中包含减少后的值(跨副本)。如果传递一个列表/元组,此函数将在输入张量上执行 inplace all-reduce 操作,并返回列表/元组本身。

torch_xla.core.xla_model.all_gather(value: Tensor, dim: int = 0, groups: Optional[List[List[int]]] = None, output: Optional[Tensor] = None, pin_layout: bool = True) Tensor[source]

在一个给定的维度上执行一个全聚集操作。

Parameters
  • (torch.Tensor) – 输入张量。

  • 维度 (python:int) – 聚集维度。 默认值:0

  • 群组 (列表, 可选) –

    一个列表的列表,表示操作 all_gather() 的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • 输出 (torch.Tensor) – 可选输出张量。

  • pin_layout (bool, optional) – 是否将布局固定为该通信操作。 布局固定可以防止当每个参与通信的进程具有略有不同的程序时,潜在的数据损坏。但是,它可能会导致某些xla编译失败。在看到类似“HloModule有混合布局约束”的错误消息时,解绑布局。

Returns

一个在dim维度上包含所有参与副本的值的张量。

torch_xla.core.xla_model.all_to_all(value: Tensor, split_dimension: int, concat_dimension: int, split_count: int, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Tensor[source]

执行一个XLA AllToAll()操作在输入张量上。

See: https://www.tensorflow.org/xla/operation_semantics#alltoall

Parameters
  • (torch.Tensor) – 输入张量。

  • split_dimension (python:int) – 在分割操作中应进行的维度。

  • 拼接维度 (python:int) – 拼接应发生在此维度上。

  • split_count (python:int) – 分割计数。

  • 群组 (列表, 可选) –

    一个列表的列表,表示操作 all_reduce() 的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – 是否将布局固定为该通信操作。 布局固定可以防止当每个参与通信的进程具有略有不同的程序时,潜在的数据损坏。但是,它可能会导致某些xla编译失败。在看到类似“HloModule有混合布局约束”的错误消息时,解绑布局。

Returns

结果 torch.Tensorall_to_all() 操作。

torch_xla.core.xla_model.add_step_closure(closure: Callable[[...], Any], args: Tuple[Any] = (), run_async: bool = False)[source]

在步骤结束时添加一个闭包到要运行的列表中。

许多时候在模型训练期间,需要打印/报告(打印到控制台、发布到tensorboard等)信息,这些信息需要检查中间层张量的内容。 在模型代码的不同点检查不同张量的内容需要多次执行,并且通常会导致性能问题。 添加一个步骤关闭将确保它将在屏障之后运行,在所有活跃张量已经转换为设备数据时。 包括由闭包参数捕获的活跃张量。因此,使用add_step_closure()将确保即使有多个闭包队列,也会进行一次执行,只需检查多个张量。 步骤关闭将按它们被排队的顺序依次运行。 请注意,尽管使用此API可以优化执行,但建议每N步后停止打印/报告事件。

Parameters
  • 闭包 (可调用对象) – 要被调用的函数。

  • 参数 (元组) – 传递给闭包的参数。

  • run_async – 如果为真,则异步运行闭包。

torch_xla.core.xla_model.wait_device_ops(devices: List[str] = [])[source]

等待给定设备上的所有异步操作完成。

Parameters

设备 (字符串..., 可选) – 需要等待的异步操作的设备。如果为空,则将等待所有本地设备。

torch_xla.core.xla_model.optimizer_step(optimizer: Optimizer, barrier: bool = False, optimizer_args: Dict = {}, groups: Optional[List[List[int]]] = None, pin_layout: bool = True)[source]

运行提供的优化器步骤,并在所有设备上同步梯度。

Parameters
  • 优化器 (torch.Optimizer) – 该 torch.Optimizer 实例的函数需要被调用。该 step() 函数将被调用 并带有 optimizer_args 命名参数。

  • 屏障 (bool, 可选) – 是否在本API中发出XLA张量屏障。如果使用PyTorch XLA ParallelLoaderDataParallel 支持,则不需要,因为屏障将在XLA数据加载器迭代器 next() 调用中发出。 默认值:False

  • optimizer_args (字典, 可选) – 用于 optimizer.step() 调用的命名参数字典。

  • 群组 (列表, 可选) –

    一个列表的列表,表示操作 all_reduce() 的副本组。例如:[[0, 1, 2, 3], [4, 5, 6, 7]]

    defines two groups, one with the [0, 1, 2, 3] replicas and one with the [4, 5, 6, 7] replicas. If None there will be only one group with all the replicas in it.

  • pin_layout (bool, optional) – 是否在梯度缩减时锁定布局。 请参阅 xm.all_reduce 以获取详细信息。

Returns

相同的值由optimizer.step()调用返回。

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.optimizer_step(self.optimizer)
torch_xla.core.xla_model.save(data: Any, file_or_path: Union[str, TextIO], master_only: bool = True, global_master: bool = False)[source]

将输入数据保存到文件中。

保存的数据在被保存之前转移到PyTorch CPU设备上,因此以下torch.load()将加载CPU数据。 在处理视图时必须小心。相反,建议您在加载张量并将其移动到目标设备后重新创建它们。

Parameters
  • 数据 – 要保存的输入数据。任何嵌套组合的Python对象(列表、元组、集合、字典等)。

  • file_or_path – 保存数据操作的目的地。可以是文件路径或Python文件对象。如果master_only等于False,则路径或文件对象必须指向不同的目的地,否则来自同一主机的所有写入将相互覆盖。

  • master_only (bool, 可选) – 是否仅主设备保存数据。如果为 False,则 file_or_path 参数应为每个参与复制的序数提供不同的文件或路径,否则同一主机上的所有副本将写入同一位置。 默认值: True

  • 全局主 (bool, 可选) – 当 master_onlyTrue 时,此标志控制是否每个主机的主(如果 global_masterFalse)保存内容,或者仅保存全局主(序号 0)。 默认值:False

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.wait_device_ops() # wait for all pending operations to finish.
>>> xm.save(obj_to_save, path_to_save)
>>> xm.rendezvous('torch_xla.core.xla_model.save') # multi process context only
torch_xla.core.xla_model.rendezvous(tag: str, payload: bytes = b'', replicas: List[int] = []) List[bytes][source]

等待所有网格客户端到达指定的会合地点。

注意:PJRT 不支持 XRT 网格服务器,因此这实际上是一个 xla_rendezvous

Parameters
  • 标签 (字符串) – 加入会面的名称。

  • payload (字节, 可选) – 要发送到会面的负载。

  • 副本 (列表, Python:整数) – 参与会面的副本编号。 空表示网络中的所有副本。 默认值:[]

Returns

所有其他核心交换的payload,其中core 序号 i 的payload在返回元组中的位置为 i

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.rendezvous('example')
torch_xla.core.xla_model.mesh_reduce(tag: str, data, reduce_fn: Callable[[...], Any]) Union[Any, ToXlaTensorArena][source]

执行一个离图客户端网格减少。

Parameters
  • 标签 (字符串) – 加入会面的名称。

  • 数据 – 要减少的数据。reduce_fn 可调用函数将接收一个列表,其中包含来自所有网格客户端进程(每个核心一个)的相同数据的副本。

  • reduce_fn (callable) – 一个接收列表中的 data-like 对象并返回减少结果的函数。

Returns

减小的值。

示例

>>> import torch_xla.core.xla_model as xm
>>> import numpy as np
>>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
torch_xla.core.xla_model.set_rng_state(seed: int, device: Optional[str] = None)[source]

设置随机数生成器的状态。

Parameters
  • 种子 (python:整数) – 要设置的状态。

  • 设备 (字符串, 可选) – 需要设置RNG状态的设备。 如果缺失,则默认使用种子。

torch_xla.core.xla_model.get_rng_state(device: Optional[str] = None) int[source]

获取当前运行的随机数生成器状态。

Parameters

设备 (字符串, 可选) – 需要检索的随机数生成器状态所属的设备。 如果未指定,则默认设置为设备种子。

Returns

随机数生成器的状态,作为整数。

torch_xla.core.xla_model.get_memory_info(device: Optional[device] = None) MemoryInfo[source]

获取设备内存使用情况。

Parameters
  • 设备 – 可选[torch.device] 请求的内存信息所属的设备。

  • 设备。 (如果未传递将使用默认值) –

Returns

包含给定设备内存使用情况的MemoryInfo字典。

示例

>>> xm.get_memory_info()
{'bytes_used': 290816, 'bytes_limit': 34088157184}
torch_xla.core.xla_model.get_stablehlo(tensors: Optional[List[Tensor]] = None) str[source]

获取稳定的HLO格式的计算图。

如果 tensors 不为空,将输出为 tensors 的图将被保存。 如果 tensors 为空,整个计算图将被保存。

对于推理图,建议将模型输出传递给 tensors。 对于训练图,识别“输出”并不直观。使用空的 tensors 是推荐的。

要启用源行信息在StableHLO中,请设置环境变量XLA_HLO_DEBUG=1。

Parameters

张量 (列表[PyTorch张量], 可选) – 代表StableHLO图的输出/根的张量。

Returns

StableHLO 模块以字符串格式。

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors: Optional[Tensor] = None) bytes[source]

获得稳定HLO以获取字节码格式的计算图。

如果 tensors 不为空,将输出为 tensors 的图将被保存。 如果 tensors 为空,整个计算图将被保存。

对于推理图,建议将模型输出传递给 tensors。 对于训练图,识别“输出”并不直观。使用空的 tensors 是推荐的。

Parameters

张量 (列表[PyTorch张量], 可选) – 代表StableHLO图的输出/根的张量。

Returns

StableHLO 模块以字节码格式。

分布式

class torch_xla.distributed.parallel_loader.MpDeviceLoader(loader, device, **kwargs)[source]

将现有的 PyTorch DataLoader 包装在后台数据上传中。

这个类应该只在多进程数据并行主义下使用。它会将传递给ParallelLoader的数据加载器包装起来,并返回当前设备的per_device_loader。

Parameters
  • 加载器 (torch.utils.data.DataLoader) – 要被包装的PyTorch DataLoader。

  • 设备 (torch.device…) – 数据需要发送到的设备。

  • kwargs – 命名参数用于构造函数 ParallelLoader

示例

>>> device = torch_xla.device()
>>> train_device_loader = MpDeviceLoader(train_loader, device)
torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]

支持基于多进程的复制。

Parameters
  • fn (callable) – 用于为每个参与复制的设备调用的函数,该函数接受部分 复制。该函数将被调用,第一个参数是全局进程在复制中的索引,随后是传递给 args的参数。

  • 参数 (元组) – 用于 fn 的参数。 默认值:空元组

  • nprocs (python:int) – 进程/设备的数量用于复制。目前,如果指定,则可以是1或最大数量的设备。

  • 加入 (布尔值) – 是否调用应该阻塞等待已启动的进程完成。 默认:True

  • daemon (bool) – 是否正在启动的进程应设置 daemon 标志(参见 Python 多进程 API)。 默认值:False

  • start_method (string) – Python multiprocessing 进程创建方法。 默认值:spawn

Returns

相同的对象由torch.multiprocessing.spawn API返回。如果 nprocs为1,则fn函数将直接调用,API将返回None。

spmd

torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: Mesh, partition_spec: Tuple[Optional[Union[Tuple, int, str]]]) XLAShardedTensor[source]

对提供的张量进行XLA分区注释。在内部,它会对相应的XLATensor进行分片注释,以便于XLA SpmdPartitioner通过程。

Parameters
  • t (Union[torch.Tensor, XLAShardedTensor]) – 输入张量,用于标注分区规格。

  • 网格 (网格) – 描述逻辑 XLA 设备拓扑和底层设备 ID。

  • partition_spec (Tuple[Tuple, python:int, str, None]) – 一个包含设备网格维度索引或 None. 每个索引是一个整数,如果网格轴命名,则为字符串,或者是一个包含整数或字符串的元组。 这指定如何将每个输入秩进行分片(索引到网格形状)或复制(None)。 当指定一个元组时,相应的输入张量轴将在元组中的所有逻辑轴上进行分片。注意,在元组中指定网格轴的顺序将影响结果的分片。

  • dynamo_custom_op (bool) – 如果设置为 True,它会调用 dynamo 自定义操作的 mark_sharding 变体,使其可识别并被 dynamo 追踪。

示例

>>> import torch_xla.runtime as xr
>>> import torch_xla.distributed.spmd as xs
>>> mesh_shape = (4, 2)
>>> num_devices = xr.global_runtime_device_count()
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> input = torch.randn(8, 32).to(xm.xla_device())
>>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel
>>> linear = nn.Linear(32, 10).to(xm.xla_device())
>>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor[source]

清除输入张量中的分片注释并返回一个cpu类型转换后的张量。这是一个原地操作,但也会将相同的torch.Tensor返回。

Parameters

t (Union[torch.Tensor, XLAShardedTensor]) – Tensor that we want to clear the sharding

Returns

未分片的张量。

Return type

t (torch.Tensor)

示例

>>> import torch_xla.distributed.spmd as xs
>>> torch_xla.runtime.use_spmd()
>>> t1 = torch.randn(8,8).to(torch_xla.device())
>>> mesh = xs.get_1d_mesh()
>>> xs.mark_sharding(t1, mesh, (0, None))
>>> xs.clear_sharding(t1)
torch_xla.distributed.spmd.set_global_mesh(mesh: Mesh)[source]

设置当前过程可以使用的全局网格。

Parameters

网格 – (网格) 将成为全局网格的对象。

示例

>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> xs.set_global_mesh(mesh)
torch_xla.distributed.spmd.get_global_mesh() Optional[Mesh][source]

获取当前进程的全局网格。

Returns

(可选[Mesh]) 如果全局网格被设置,则返回Mesh对象,否则返回None。

Return type

网格

示例

>>> import torch_xla.distributed.spmd as xs
>>> xs.get_global_mesh()
torch_xla.distributed.spmd.get_1d_mesh(axis_name: Optional[str] = None) Mesh[source]

一个辅助函数,返回所有设备在一个维度上的网格。

Parameters

axis_name – (Optional[str]) 可选字符串,用于表示网格的轴名

Returns

网格对象

Return type

网格

示例

>>> # This example is assuming 1 TPU v4-8
>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> print(mesh.mesh_shape)
(4,)
>>> print(mesh.axis_names)
('data',)
class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)[source]

描述XLA设备拓扑网格的逻辑结构和底层资源。

Parameters
  • device_ids (Union[np.ndarray, List]) – 一个按自定义顺序排列的设备(ID)的展平列表。列表被重塑为一个 mesh_shape 矩阵,使用 C 样式索引顺序填充元素。

  • mesh_shape (Tuple[python:int, ...]) – 一个整数元组,描述设备网格的逻辑拓扑形状 以及每个元素描述对应轴上的设备数量。

  • axis_names (Tuple[str, ...]) – 一个资源轴名称序列,用于将指定的维度与 devices 参数的维度关联。其长度应与 devices 的秩匹配。

示例

>>> mesh_shape = (4, 2)
>>> num_devices = len(xm.get_xla_supported_devices())
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> mesh.get_logical_mesh()
>>> array([[0, 1],
          [2, 3],
          [4, 5],
          [6, 7]])
>>> mesh.shape()
OrderedDict([('x', 4), ('y', 2)])
class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)[source]
Creates a hybrid device mesh of devices connected with ICI and DCN networks.

逻辑网的形状应按网络强度增加顺序排列,例如 [副本、数据、模型],其中mdl具有最多的网络通信需求。

Parameters
  • ici_mesh_shape – 逻辑网状结构的内部连接设备的形状。

  • dcn_mesh_shape – 逻辑网状结构的形状,用于连接外部设备。

示例

>>> # This example is assuming 2 slices of v4-8.
>>> ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
>>> dcn_mesh_shape = (2, 1, 1)
>>> mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
>>> print(mesh.shape())
>>> >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

实验

torch_xla.experimental.eager_mode(enable: bool)[source]

配置 torch_xla 的默认执行模式。

在急切模式下,只有被 `torch_xla.compile`d 的函数才会被追踪和编译。其他 torch 操作将以急切方式执行。

调试

torch_xla.debug.metrics.metrics_report()[source]

检索包含完整指标和计数器报告的字符串。

torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[list] = None, metric_names: Optional[list] = None)[source]

检索包含完整指标和计数器报告的字符串。

Parameters
  • counter_names (list) – 一个需要打印其数据的计数器名称列表。

  • metric_names (list) – 用于打印的数据的指标名称列表。

torch_xla.debug.metrics.counter_names()[source]

检索所有当前活跃的计数器名称。

torch_xla.debug.metrics.counter_value(name)[source]

返回活跃计数器的值。

Parameters

名称 (字符串) – 需要检索值的计数器的名称。

Returns

计数器值作为整数。

torch_xla.debug.metrics.metric_names()[source]

检索所有当前活跃的指标名称。

torch_xla.debug.metrics.metric_data(name)[source]

返回活跃指标的数据。

Parameters

名称 (字符串) – 需要检索其数据的指标的名称。

Returns

指标数据,是一个包含(TOTAL_SAMPLES, ACCUMULATOR, SAMPLES)的元组。 TOTAL_SAMPLES表示已提交到指标的样本总数。一个指标仅保留一定数量的样本(在一个环形缓冲区中)。 ACCUMULATORTOTAL_SAMPLES样本之和。 SAMPLES是一个包含(TIME, VALUE)元组的列表。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源