目录

TorchRec 概念

在本节中,我们将了解 TorchRec 的关键概念, 旨在使用 PyTorch 优化大规模推荐系统。 我们将详细了解每个概念的工作原理,并回顾其使用方法 与 TorchRec 的其余部分一起使用。

TorchRec 具有其模块的特定输入/输出数据类型,以 有效地表示稀疏要素,包括:

  • JaggedTensor:长度 / 偏移量和值的包装器 张量。

  • KeyedJaggedTensor:高效表示多个稀疏 features,可以把它看作是多个 S。JaggedTensor

  • KeyedTensor:允许访问 到 Tensor 值。torch.Tensor

为了实现高性能和高效率,规范在表示稀疏数据时效率非常低。 TorchRec 引入了这些新的数据类型,因为它们提供了高效的 稀疏输入数据的存储和表示。稍后您将看到 on,则 make 输入数据的通信 分布式环境 非常高效 导致关键之一 TorchRec 提供的性能优势。torch.TensorKeyedJaggedTensor

在端到端训练循环中,TorchRec 包括以下内容 主要组件:

  • 计划:接受嵌入表的配置, 环境设置,并为 型。

  • 分享者:根据分片计划划分的分片模型具有不同的 分片策略包括 data-parallel、table-wise、row-wise、 table-wise-row-wise、column-wise 和 table-wise-column-wise 分片。

  • DistributedModelParallel 的 DistributedModelParallel 函数:结合了 sharder、optimizer 和 提供在分布式 方式。

JaggedTensor 的

A 通过长度、值、 和 offsets 的 Offset 进行验证。它被称为“锯齿状”,因为它有效地表示 具有可变长度序列的数据。相反,规范假设每个序列具有相同的长度,即 而实际数据通常并非如此。A 有助于表示此类数据,而无需填充 高效。JaggedTensortorch.TensorJaggedTensor

关键组件:

  • Lengths:表示元素数量的整数列表 对于每个实体。

  • Offsets:表示 展平值张量中的每个序列。这些提供了 长度的替代方案。

  • Values:包含每个实体的实际值的 1D 张量, 连续存储。

这是一个简单的示例,演示了每个组件如何 肖:

# User interactions:
# - User 1 interacted with 2 items
# - User 2 interacted with 3 items
# - User 3 interacted with 1 item
lengths = [2, 3, 1]
offsets = [0, 2, 5]  # Starting index of each user's interactions
values = torch.Tensor([101, 102, 201, 202, 203, 301])  # Item IDs interacted with
jt = JaggedTensor(lengths=lengths, values=values)
# OR
jt = JaggedTensor(offsets=offsets, values=values)

KeyedJaggedTensor

A 扩展了 的功能 引入键(通常是特征名称)来标记不同的 特征组,例如 User Features 和 Item 特征。这 是 OF 中使用的数据类型,因为它们用于表示多个要素 在表中。KeyedJaggedTensorJaggedTensorforwardEmbeddingBagCollectionEmbeddingCollection

A 具有隐含的批量大小,即数字 的特征除以张量的长度。示例 Below 的批处理大小为 2。与 a 类似,则 and 函数相同。您还可以 通过以下方式访问特征的 、 和 从 访问密钥。KeyedJaggedTensorlengthsJaggedTensoroffsetslengthslengthsoffsetsvaluesKeyedJaggedTensor

keys = ["user_features", "item_features"]
# Lengths of interactions:
# - User features: 2 users, with 2 and 3 interactions respectively
# - Item features: 2 items, with 1 and 2 interactions respectively
lengths = [2, 3, 1, 2]
values = torch.Tensor([11, 12, 21, 22, 23, 101, 102, 201])
# Create a KeyedJaggedTensor
kjt = KeyedJaggedTensor(keys=keys, lengths=lengths, values=values)
# Access the features by key
print(kjt["user_features"])
# Outputs user features
print(kjt["item_features"])

计划

TorchRec 计划器有助于确定最佳分片配置 一个模型。它评估了分片嵌入的多种可能性 表并针对性能进行优化。规划器执行 以后:

  • 评估硬件的内存约束。

  • 根据内存提取估计计算需求,例如 嵌入查找。

  • 解决特定于数据的因素。

  • 考虑其他硬件细节(如带宽)以生成 最优分片计划。

为了确保准确考虑这些因素,Planner 可以 合并有关嵌入表、约束、硬件的数据 信息和拓扑来帮助生成最佳计划。

EmbeddingTable 的分片

TorchRec shard 提供多种分片策略,用途多样 案例中,我们概述了一些分片策略以及它们的工作原理 以及它们的好处和局限性。通常,我们建议使用 TorchRec 规划器为你生成一个分片计划,因为它会找到 模型中每个嵌入表的最佳分片策略。

每个分片策略都决定了如何进行分表,是否 表应该被剪掉以及如何,是保留一份还是几份 一些表格,等等。每张表的结果 分片,无论是一个嵌入表还是其中的一部分,都引用 to 作为分片。

可视化 TorchRec 中提供的分片类型的差异

图 1:可视化 TorchRec 中提供的不同分片方案下表分片的放置

以下是 TorchRec 中可用的所有分片类型的列表:

  • 表级 (TW):顾名思义,嵌入表保持为 整块并置于一个等级。

  • 逐列 (CW):表格沿维度拆分, 例如,被拆分为 4 个分片:。emb_dimemb_dim=256[64, 64, 64, 64]

  • Row-wise (RW):表格沿维度分割, 通常在所有等级之间平均分配。hash_size

  • table-wise-row-wise (TWRW):table 被放置在一个主机上,被分割 在该主机上的排名中按行排列。

  • 网格分片 (GS):表是 CW 分片的,每个 CW 分片都放置 TWRW 在主机上。

  • 数据并行 (DP):每个排名都保留一个表的副本。

分片后,模块将转换为 它们本身,在 TorchRec 中称为 and。这些模块处理 输入数据的通信、嵌入查找和梯度。ShardedEmbeddingCollectionShardedEmbeddingBagCollection

使用 TorchRec 分片模块进行分布式训练

有很多分片策略可用,我们如何确定是哪一个 使用?每个分片方案都有一个相关的成本,在 与模型大小和 GPU 数量相结合,确定哪个分片 策略最适合模型。

没有分片,其中每个 GPU 都保留嵌入表的副本 (DP) 时,主要成本是每个 GPU 查找 在前向通道中将向量嵌入到其内存中,并更新 gradients 的 Gradients 进行转换。

使用分片时,会有额外的通信成本:每个 GPU 都需要 向其他 GPU 请求嵌入向量查找,并传达 梯度。这通常称为通信。在 TorchRec 中,对于给定 GPU 上的输入数据,我们确定 其中,数据每个部分的嵌入分片所在的位置,并发送 it 到目标 GPU。然后,该目标 GPU 返回嵌入向量 返回到原始 GPU。在向后传递中,将发送梯度 返回到目标 GPU,分片将使用 优化。all2all

如上所述,分片要求我们传达输入数据 和嵌入查找。TorchRec 分三个主要阶段处理这个问题,我们 将这称为使用的分片嵌入模块 forward 在 TorchRec 模型的训练和推理中:

  • 功能 全部到全部/输入分布 (input_dist)

    • 将输入数据(以 的形式)传送到 包含相关嵌入表分片的适当设备KeyedJaggedTensor

  • Embedding Lookup

    • 在 feature all 之后形成的新输入数据的 Lookup embeddings all to 所有交易所

  • 将 All 嵌入到 All/Output Distribution (output_dist)

    • 将嵌入的查找数据传回相应的设备 要求它(根据设备的输入数据 已接收)

  • 向后传递执行相同的操作,但顺序相反。

下图演示了它的工作原理:

可视化前向传递,包括分片 TorchRec 模块的input_dist、查找和output_dist

图 2:表级分片表的正向传递,包括分片 TorchRec 模块的input_dist、查找和output_dist

DistributedModelParallel

以上所有内容最终都形成了 TorchRec 使用的主要入口点 对计划进行分片和集成。概括地说,执行以下操作:DistributedModelParallel

  • 通过设置进程组和 分配设备类型。

  • 如果未提供着色器,则使用默认着色器,默认着色器包括 。EmbeddingBagCollectionSharder

  • 接受提供的分片计划,如果未提供,则 生成一个。

  • 创建模块的分片版本并替换原始 例如,模块转换为 .EmbeddingCollectionShardedEmbeddingCollection

  • 默认情况下,将 with 包装起来,使模块既是 model 又是 data 平行。DistributedModelParallelDistributedDataParallel

优化

TorchRec 模块提供了一个无缝的 API 来融合向后传递和 optimizer 步骤,在 性能并减少使用的内存,以及 将不同的优化器分配给不同的模型参数。

向后可视化优化器的融合以更新稀疏嵌入表

图 3:将嵌入向后融合与稀疏优化器

推理

推理环境与训练不同,它们非常 对性能和模型大小敏感。有两个关键 差异TorchRec 推理针对以下方面进行了优化:

  • 量化:对推理模型进行量化以降低延迟 并减小了模型大小。这种优化让我们可以使用尽可能少的设备 尽可能地进行推理,以最大限度地减少延迟。

  • C++ 环境:为了进一步减少延迟,该模型是 在 C++ 环境中运行。

TorchRec 提供了以下功能来将 TorchRec 模型转换为 推理就绪:

  • 用于量化模型的 API,包括自动优化 带 FBGEMM TBE

  • 用于分布式推理的分片嵌入

  • 将模型编译为 TorchScript (兼容 C++)

另请参阅

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源