目录

TorchRec 简介

创建时间: Oct 02, 2024 |上次更新时间:2024 年 10 月 10 日 |上次验证: Oct 02, 2024

TorchRec 是一个 PyTorch 库,专为使用嵌入构建可扩展且高效的推荐系统而量身定制。 本教程将指导您完成安装过程,介绍嵌入的概念,并强调它们在 推荐系统。它提供了有关使用 PyTorch 实现嵌入的实际演示 以及 TorchRec,专注于通过分布式训练和高级优化来处理大型嵌入表。

您将学到什么
  • 嵌入的基础知识及其在推荐系统中的作用

  • 如何设置 TorchRec 以在 PyTorch 环境中管理和实施嵌入

  • 探索在多个 GPU 之间分配大型嵌入表的高级技术

先决条件
  • PyTorch v2.5 或更高版本,带有 CUDA 11.8 或更高版本

  • Python 3.9 或更高版本

  • FBGEMM的

安装依赖项

在 Google Colab 或其他环境中运行本教程之前,请先安装 以下依赖项:

!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121

注意

如果您在 Google Colab 中运行此操作,请确保切换到 GPU 运行时类型。 有关更多信息, 请参阅启用 CUDA

嵌入

在构建推荐系统时,分类特征通常 具有大量的基数、帖子、用户、广告等。

为了表示这些实体并对这些关系进行建模,使用了嵌入。在机器学习中,嵌入是一个向量 高维空间中的实数,用于表示 复杂数据,如文字、图像或用户

RecSys 中的嵌入

现在你可能想知道,这些 embedding 是如何在第一个 地方?嵌入在 Embedding Table 中表示为单独的行,也称为嵌入权重。原因 因为 embeddings 或 embeddings table weights 只是被训练的 就像模型的所有其他权重一样,通过 Gradient Descent!

嵌入表只是一个用于存储嵌入的大型矩阵,其中 两个维度 (B, N),其中:

  • B 是 table 存储的嵌入数

  • N 是每个嵌入的维度数(N 维嵌入)。

嵌入表的输入表示要检索的嵌入查找 特定索引或行的嵌入向量。在推荐系统中,例如 与许多大型系统中使用的 ID 一样,唯一 ID 不仅用于 特定用户,但也跨帖子和广告等实体投放 查找索引到相应的嵌入表!

嵌入通过以下过程在 RecSys 中训练:

  • 输入/查找索引作为唯一 ID 馈送到模型中。ID 为 hashed 到嵌入表的总大小,以防止在以下情况下出现问题 ID >行数

  • 然后检索并汇集嵌入,例如将 sum 或 嵌入的平均值。这是必需的,因为 embeddings,而模型需要一致的形状。

  • 嵌入与模型的其余部分结合使用,以 生成预测,例如 Click-Through Rate (点击率)对于广告。

  • 损失是使用预测和标签计算的 例如,模型的所有权重都通过 梯度下降和反向传播,包括与示例关联的嵌入权重。

这些嵌入对于表示分类特征至关重要,例如 作为用户、帖子和广告,为了捕捉关系并使 很好的推荐。深度学习建议 模型 (DLRM) 论文 有关在 RecSsys 中使用嵌入表的技术详细信息。

本教程介绍了 embeddings、showcase 的概念 TorchRec 特定的模块和数据类型,并描述了分布式训练 与 TorchRec 配合使用。

import torch

PyTorch 中的嵌入

在 PyTorch 中,我们有以下类型的嵌入:

  • :一个嵌入表,其中 forward pass 返回 按原样嵌入自身。

  • :前向传递返回的嵌入表 然后进行池化的嵌入,例如 sum 或 mean,否则为 作为 Pooled Embeddings

在本节中,我们将对性能进行非常简短的介绍 通过将索引传入表中来嵌入查找。

num_embeddings, embedding_dim = 10, 4

# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)

# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
    num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
    num_embeddings, embedding_dim, _weight=weights
)

# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)

embeddings = embedding_collection(ids)

# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)

# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)

print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)

# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
Weights: tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936],
        [0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294],
        [0.8854, 0.5739, 0.2666, 0.6274],
        [0.2696, 0.4414, 0.2969, 0.8317],
        [0.1053, 0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516, 0.0753],
        [0.8860, 0.5832, 0.3376, 0.8090],
        [0.5779, 0.9040, 0.5547, 0.3423]])
Embedding Collection Table:  Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936],
        [0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294],
        [0.8854, 0.5739, 0.2666, 0.6274],
        [0.2696, 0.4414, 0.2969, 0.8317],
        [0.1053, 0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516, 0.0753],
        [0.8860, 0.5832, 0.3376, 0.8090],
        [0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Embedding Bag Collection Table:  Parameter containing:
tensor([[0.8823, 0.9150, 0.3829, 0.9593],
        [0.3904, 0.6009, 0.2566, 0.7936],
        [0.9408, 0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411, 0.4294],
        [0.8854, 0.5739, 0.2666, 0.6274],
        [0.2696, 0.4414, 0.2969, 0.8317],
        [0.1053, 0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516, 0.0753],
        [0.8860, 0.5832, 0.3376, 0.8090],
        [0.5779, 0.9040, 0.5547, 0.3423]], requires_grad=True)
Input row IDS:  tensor([[1, 3]])
Embedding Collection Results:
tensor([[[0.3904, 0.6009, 0.2566, 0.7936],
         [0.8694, 0.5677, 0.7411, 0.4294]]], grad_fn=<EmbeddingBackward0>)
Shape:  torch.Size([1, 2, 4])
Embedding Bag Collection Results:
tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<EmbeddingBagBackward0>)
Shape:  torch.Size([1, 4])
Mean:  tensor([[0.6299, 0.5843, 0.4988, 0.6115]], grad_fn=<MeanBackward1>)

祝贺!现在,您对如何使用 嵌入表 — 现代推荐的基础之一 系统!这些表表示实体及其关系。为 示例,给定用户与 Pages 和 Posts 之间的关系 他们喜欢。

TorchRec 功能概述

在上面的部分中,我们学习了如何使用嵌入表,这是 现代推荐系统!这些表表示实体和 关系,例如用户、页面、帖子等。鉴于这些 实体总是在增加,通常会应用哈希函数 以确保 ID 位于特定嵌入表的边界内。 但是,为了表示大量的实体并减少哈希 冲突,这些表格可能会变得非常庞大(想想广告的数量 例如)。事实上,这些表可能会变得如此巨大,以至于它们 即使有 80G 内存,也无法容纳 1 个 GPU。

为了训练具有大量嵌入表的模型,将这些 表是必需的,然后引入一组全新的 并行和优化中的问题和机遇。幸运的是,我们有 遇到、整合和解决的 TorchRec 库 其中许多担忧。TorchRec 作为一个库,它提供 用于大规模分布式嵌入的基元

接下来,我们将探讨 TorchRec 的主要功能 库。我们将从开始,并将其扩展到 自定义 TorchRec 模块,探索分布式训练环境 生成 embedding 的分片计划,看看固有的 TorchRec 优化,并扩展模型以准备好在 C++ 中进行推理。 以下是本节内容的简要概述:torch.nn.Embedding

  • TorchRec 模块和数据类型

  • 分布式训练、分片和优化

  • 推理

让我们从导入 TorchRec 开始:

import torchrec

本节介绍 TorchRec 模块和数据类型,包括 实体为 AND 、 、 等。EmbeddingCollectionEmbeddingBagCollectionJaggedTensorKeyedJaggedTensorKeyedTensor

由 至EmbeddingBagEmbeddingBagCollection

我们已经探索和 . TorchRec 通过在 其他 words 模块中,这些模块可以有多个嵌入表,其中 和 我们将用来表示一组 嵌入袋。EmbeddingCollectionEmbeddingBagCollectionEmbeddingBagCollection

在下面的示例代码中,我们创建了一个 (EBC) 带有两个嵌入袋,1 个代表产品,1 个代表用户。 每个表 和 都由一个 64 维度表示 嵌入大小为 4096。EmbeddingBagCollectionproduct_tableuser_table

ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        )
    ]
)
print(ebc.embedding_bags)
ModuleDict(
  (product_table): EmbeddingBag(4096, 64, mode='sum')
  (user_table): EmbeddingBag(4096, 64, mode='sum')
)

让我们检查 forward 方法和 模块的输入和输出:EmbeddingBagCollection

import inspect

# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
    """
    Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
    and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.

    Args:
        features (KeyedJaggedTensor): Input KJT
    Returns:
        KeyedTensor
    """
    flat_feature_names: List[str] = []
    for names in self._feature_names:
        flat_feature_names.extend(names)
    inverse_indices = reorder_inverse_indices(
        inverse_indices=features.inverse_indices_or_none(),
        feature_names=flat_feature_names,
    )
    pooled_embeddings: List[torch.Tensor] = []
    feature_dict = features.to_dict()
    for i, embedding_bag in enumerate(self.embedding_bags.values()):
        for feature_name in self._feature_names[i]:
            f = feature_dict[feature_name]
            res = embedding_bag(
                input=f.values(),
                offsets=f.offsets(),
                per_sample_weights=f.weights() if self._is_weighted else None,
            ).float()
            pooled_embeddings.append(res)
    return KeyedTensor(
        keys=self._embedding_names,
        values=process_pooled_embeddings(
            pooled_embeddings=pooled_embeddings,
            inverse_indices=inverse_indices,
        ),
        length_per_key=self._lengths_per_embedding,
    )

TorchRec 输入/输出数据类型

TorchRec 的模块的输入和输出具有不同的数据类型:、 和 。现在你 可能会问,为什么要创建新的数据类型来表示稀疏特征?自 回答这个问题,我们必须了解稀疏特征是怎样的 在代码中表示。JaggedTensorKeyedJaggedTensorKeyedTensor

稀疏要素也称为 和 ,是将用作 ID 的 ID indices 添加到嵌入表中检索该 ID 的嵌入。自 举一个非常简单的例子,想象一个稀疏特征是 Ads 用户与之交互的 Paypal S T输入本身将是一组广告 ID 用户与之交互,并且检索到的嵌入向量将是 这些广告的语义表示形式。表示 代码中的这些功能是,在每个输入示例中, ID 是可变的。某一天,用户可能只与一个广告互动 而第二天他们与 Three 互动。id_list_featureid_score_list_feature

一个简单的表示形式如下所示,其中我们有一个张量表示一个批次的示例中有多少个索引,还有一个包含索引本身的张量。lengthsvalues

# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])

# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])

接下来,让我们看看偏移量以及每个批次中包含的内容

# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)

print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
    "Second Batch: ",
    id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)

from torchrec import JaggedTensor

# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)

# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())

# Convert to list of values
print("List of Values: ", jt.to_dense())

# ``__str__`` representation
print(jt)

from torchrec import KeyedJaggedTensor

# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())

# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())

# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())

# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())

# ``KeyedJaggedTensor`` string representation
print(kjt)

# Q2: What are the offsets for the ``KeyedJaggedTensor``?

# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result

# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())

# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)

# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)
Offsets:  tensor([1, 3])
First Batch:  tensor([5])
Second Batch:  tensor([7, 1])
Offsets:  tensor([0, 1, 3])
List of Values:  [tensor([5]), tensor([7, 1])]
JaggedTensor({
    [[5], [7, 1]]
})

Keys:  ['product', 'user']
Lengths:  tensor([3, 1, 2, 2])
Values:  tensor([1, 2, 1, 5, 2, 3, 4, 1])
to_dict:  {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f4edb57c940>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7f4edb57c9a0>}
KeyedJaggedTensor({
    "product": [[1, 2, 1], [5]],
    "user": [[2, 3], [4, 1]]
})

['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])

恭喜!您现在了解了 TorchRec 模块和数据类型。 为能走到这一步而拍拍自己的后背。接下来,我们将 了解分布式训练和分片。

分布式训练和分片

现在我们已经掌握了 TorchRec 模块和数据类型,是时候了 将其提升到一个新的水平。

请记住,TorchRec 的主要目的是为 分布式嵌入。到目前为止,我们只使用嵌入表 在单个设备上。考虑到嵌入表有多小,这是可能的 已经存在,但在生产环境中,情况通常并非如此。 嵌入表通常会变得很大,一个表无法容纳在单个表中 GPU,从而对多个设备和分布式 环境。

在本节中,我们将探讨如何设置分布式环境 实际生产训练的确切方式,并探索分片 embedding table,全部使用 TorchRec。

此部分也将仅使用 1 个 GPU,尽管它将在 分布式时尚。这只是 training 的限制,因为 training 每个 GPU 都有一个进程。推理不会遇到此要求

在下面的示例代码中,我们设置了 PyTorch 分布式环境。

警告

如果你在 Google Colab 中运行它,你只能调用这个 cell 一次, 再次调用它会导致错误,因为您只能初始化进程 group 一次。

import os

import torch.distributed as dist

# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")

print(f"Distributed environment initialized: {dist}")
Distributed environment initialized: <module 'torch.distributed' from '/usr/local/lib/python3.10/dist-packages/torch/distributed/__init__.py'>

分布式嵌入

我们已经使用了主 TorchRec 模块:.我们已经研究了它的工作原理以及 data 以 TorchRec 表示。但是,我们还没有探索过 TorchRec 的主要部分,即分布式 embeddingsEmbeddingBagCollection

GPU 是当今 ML 工作负载最受欢迎的选择,因为它们 能够执行数量级以上的浮点运算 (FLOPs) 比 CPU 多。然而 GPU 具有稀缺快速内存(HBM,即 类似于 CPU 的 RAM),通常为 ~10 GB 的 GB。

RecSys 模型可以包含远远超过内存的嵌入表 limit 的 1 个 GPU,因此需要分布嵌入表 跨多个 GPU,也称为模型并行。在 另一方面,数据并行是复制整个模型的地方 每个 GPU,每个 GPU 为其获取不同的数据批次 training,在 backwards pass 上同步梯度。

需要较少计算但更多内存的模型部分 (嵌入)与模型并行一起分布,而需要更多计算和更少内存的部分(密集层、MLP 等)则使用 使用 Data Parallel 进行分布式

分片

为了分发 embedding 表,我们将 embedding 拆分 table 放入 parts 中,并将这些 part 放置在不同的设备上 称为 “分片”。

有多种方法可以对嵌入表进行分片。最常见的方法是:

  • Table-Wise:表格完全放置在一台设备上

  • Column-Wise:对 embedding table 的列进行分片

  • Row-Wise:对 embedding table 的行进行分片

分片模块

虽然所有这些似乎都需要处理和实施,但您已经陷入了 运气。TorchRec 提供了所有原语,便于分发 训练和推理!实际上,TorchRec 模块有两个对应的 类,用于在分布式 环境:

  • 模块 sharder:此类公开一个 API 处理 TorchRec 模块的分片,从而生成分片模块。 * 对于 ,分片是 EmbeddingBagCollectionShardershardEmbeddingBagCollection

  • 分片模块:此类是 TorchRec 模块的分片变体。 它具有与常规 TorchRec 模块相同的输入/输出,但要多得多 更优化,可在分布式环境中工作。 * 对于 ,分片变体为 ShardedEmbeddingBagCollectionEmbeddingBagCollection

每个 TorchRec 模块都有一个 unsharded 和 sharded 变体。

  • 未分片版本旨在进行原型设计和试验。

  • 分片版本旨在用于分布式环境 分布式训练和推理。

例如, TorchRec 模块的分片版本将处理 Model 所需的一切 并行性,例如 GPU 之间的通信以进行分发 嵌入到正确的 GPU 中。EmbeddingBagCollection

模块复习EmbeddingBagCollection

ebc

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()

# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"

print(f"Process Group: {pg}")
Process Group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f50a8c5cf30>

计划

在展示分片的工作原理之前,我们必须了解 Planner,这有助于我们确定最佳分片配置。

给定许多嵌入表和许多秩,则有许多 可能的不同分片配置。例如,给定 2 个嵌入表和 2 个 GPU,您可以:

  • 在每个 GPU 上放置 1 张表

  • 将两个表放在一个 GPU 上,而不将表放在另一个 GPU 上

  • 在每个 GPU 上放置特定的行和列

考虑到所有这些可能性,我们通常需要一个分片 配置。

这就是规划师的用武之地。规划师能够确定 给定嵌入表的数量和 GPU 的数量,什么是最佳 配置。事实证明,手动完成这非常困难, 工程师必须考虑大量因素来确保 最优分片计划。幸运的是,TorchRec 在 使用 Planner。

TorchRec 规划器:

  • 评估硬件的内存限制

  • 估计基于内存获取作为嵌入查找的计算

  • 解决数据特定因素

  • 考虑其他硬件细节(如带宽)以生成最佳分片计划

为了考虑所有这些变量,TorchRec Planner 可以接收各种数量的数据来嵌入表, constraints、hardware information 和 topology 来帮助为模型生成最佳分片计划,即 通常跨堆栈提供。

要了解有关分片的更多信息,请参阅我们的分片 教程

# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=1,
        compute_device="cuda",
    )
)

# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)

print(f"Sharding Plan generated: {plan}")
Sharding Plan generated: module:

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
product_table | table_wise    | fused          | [0]
user_table    | table_wise    | fused          | [0]

    param     | shard offsets | shard sizes |   placement
------------- | ------------- | ----------- | -------------
product_table | [0, 0]        | [4096, 64]  | rank:0/cuda:0
user_table    | [0, 0]        | [4096, 64]  | rank:0/cuda:0

Planner 结果

正如你在上面看到的,当运行 planner 时,有相当多的输出。 我们可以看到很多统计数据正在计算,以及我们的 桌子最终会被放置。

运行 planner 的结果是一个静态计划,可以重用 用于分片!这允许分片对于生产模型是静态的 而不是每次都确定新的分片计划。下面,我们使用 分片计划最终生成我们的 .ShardedEmbeddingBagCollection

# The static plan that was generated
plan

env = ShardingEnv.from_process_group(pg)

# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

print(f"Sharded EBC Module: {sharded_ebc}")
Sharded EBC Module: ShardedEmbeddingBagCollection(
  (lookups):
   GroupedPooledEmbeddingsLookup(
      (_emb_modules): ModuleList(
        (0): BatchedFusedEmbeddingBag(
          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
        )
      )
    )
   (_output_dists):
   TwPooledEmbeddingDist()
  (embedding_bags): ModuleDict(
    (product_table): Module()
    (user_table): Module()
  )
)

GPU 训练LazyAwaitable

请记住,TorchRec 是一个高度优化的分布式库 嵌入。TorchRec 引入的一个概念是为了实现更高的 在 GPU 上训练的性能是 LazyAwaitable。 您将看到类型作为各种分片的输出 TorchRec 模块。类型所做的只是延迟计算一些 result,并且它通过像异步类型一样来实现。LazyAwaitableLazyAwaitable

from typing import List

from torchrec.distributed.types import LazyAwaitable


# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        return torch.ones(self._size)


awaitable = ExampleAwaitable([3, 2])
awaitable.wait()

kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)

kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))

print(kt.keys())

print(kt.values().shape)

# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)
<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x7f4ed824a9e0>
<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])

分片 TorchRec 模块剖析

我们现在已经成功地将给定的 我们生成的分片计划!分片模块具有来自 TorchRec 抽象出分布式通信/计算 多个 GPU。事实上,这些 API 针对性能进行了高度优化 在训练和推理中。以下是 TorchRec 提供的分布式训练/推理EmbeddingBagCollection

  • input_dist:处理从 GPU 到 GPU 的输入分配。

  • lookups:实际嵌入是否在优化的 使用 FBGEMM TBE 的批处理方式(稍后会详细介绍)。

  • output_dist:处理从 GPU 到 GPU 的分配输出。

输入和输出的分配是通过 NCCL 完成的 集合, 即 All-to-Alls, 这是所有 GPU 相互发送和接收数据的地方。 TorchRec 与 PyTorch 的接口,为 Collective 和 为最终用户提供干净的抽象,消除了对 较低级别的详细信息。

向后传递执行所有这些集合,但反之则 order 的梯度分布。、 和 都依赖于分片方案。由于我们在 按表方式,这些 API 是由 TwPooledEmbeddingSharding 构建的模块。input_distlookupoutput_dist

sharded_ebc

# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists

# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
[TwPooledEmbeddingDist(
  (_dist): PooledEmbeddingsAllToAll()
)]

优化嵌入查找

在对嵌入表的集合执行查找时,一个微不足道的 解决方案是遍历所有 and do 每个表一次查找。这正是标准的 unsharded 所做的。但是,虽然此解决方案 很简单,它非常慢。nn.EmbeddingBagsEmbeddingBagCollection

FBGEMM 是一个 库,该库提供 GPU 运算符(也称为内核),该运算符 都非常优化。这些运算符之一称为 Table Batched 嵌入 (TBE) 提供两个主要优化:

  • Table batching,允许您使用 一次内核调用。

  • Optimizer Fusion,它允许模块在给定 规范的 PyTorch 优化器和参数。

使用 FBGEMM TBE 作为查找 而不是传统,以优化嵌入 查找。ShardedEmbeddingBagCollectionnn.EmbeddingBags

sharded_ebc._lookups
[GroupedPooledEmbeddingsLookup(
  (_emb_modules): ModuleList(
    (0): BatchedFusedEmbeddingBag(
      (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
    )
  )
)]

DistributedModelParallel

我们现在探索了单个 !我们是 能够获取 并使用 unsharded 生成模块。此工作流很好,但 通常在实现模型并行时,DistributedModelParallel (DMP) 用作标准接口。当将模型包装(在 我们的情况),使用 DMP 时,将发生以下情况:EmbeddingBagCollectionEmbeddingBagCollectionSharderEmbeddingBagCollectionShardedEmbeddingBagCollectionebc

  1. 决定如何对模型进行分片。DMP 将收集可用的 分片程序中,并提出一个最佳分片方式的计划 嵌入表(例如EmbeddingBagCollection)

  2. 实际上对模型进行分片。这包括为每个 embedding table 的 embedding table 中。

DMP 接收我们刚刚试验过的所有内容,例如静态 分片计划、分片列表等。然而,它也有一些不错的 默认以无缝分片 TorchRec 模型。在这个玩具示例中, 由于我们有两个嵌入表和一个 GPU,因此 TorchRec 会将两者 在单个 GPU 上。

ebc

model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))

out = model(kjt)
out.wait()

model
WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.

DistributedModelParallel(
  (_dmp_wrapped_module): ShardedEmbeddingBagCollection(
    (lookups):
     GroupedPooledEmbeddingsLookup(
        (_emb_modules): ModuleList(
          (0): BatchedFusedEmbeddingBag(
            (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
          )
        )
      )
     (_input_dists):
     TwSparseFeaturesDist(
        (_dist): KJTAllToAll()
      )
     (_output_dists):
     TwPooledEmbeddingDist(
        (_dist): PooledEmbeddingsAllToAll()
      )
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
  )
)

分片最佳实践

目前,我们的配置仅在 1 个 GPU(或等级)上进行分片,这 很简单:只需将所有表放在 1 个 GPU 内存上即可。然而,实际上 生产使用案例中,嵌入表通常分片 数百个 GPU,具有不同的分片方法,例如 table-wise, 逐行和逐列。确定 适当的分片配置(以防止内存不足问题),而 不仅在内存方面保持平衡,而且在计算方面也保持平衡 最佳性能。

在 Optimizer 中添加

请记住,TorchRec 模块针对大规模进行了超优化 分布式训练。一个重要的优化是关于 优化。

TorchRec 模块提供了一个无缝的 API 来融合 向后传递和优化训练步骤,提供显著的 优化性能和减少使用的内存,以及 将不同优化器分配给不同模型的粒度 参数。

优化器类

TorchRec 使用 ,其中包含 .A 有效让一切变得简单 来处理模型中各种子组的多个优化器。A 扩展了 和 是 initialized through a dictionary of parameters 公开参数。 中的每个模块都有自己的模块,该模块组合成一个 .CombinedOptimizerKeyedOptimizersCombinedOptimizerKeyedOptimizertorch.optim.OptimizerTBEEmbeddingBagCollectionKeyedOptimizerCombinedOptimizer

TorchRec 中的融合优化器

使用 ,优化器是融合的,它 表示优化器更新是在 backward 中完成的。这是一个 优化,其中优化器嵌入 渐变不会具体化并直接应用于参数。 这带来了显著的内存节省,因为嵌入梯度 通常是参数本身的大小。DistributedModelParallel

但是,您可以选择使用不 应用此优化,让您检查 embedding gradients 或 根据需要对其应用计算。在这种情况下,密集优化器 将是您的规范 PyTorch 模型训练循环,其中 优化。dense

通过 创建优化器后,您可以 仍然需要管理其他参数的优化器,而不是 与 TorchRec 嵌入模块相关联。查找其他 参数 用。 像使用普通 Torch 一样对这些参数应用优化器 optimizer 并将 this 和 the 合并为一个,您可以在 Training Loop 中使用它来。DistributedModelParallelin_backward_optimizer_filter(model.named_parameters())model.fused_optimizerCombinedOptimizerzero_gradstep

将 Optimizer 添加到EmbeddingBagCollection

我们将以两种方式来实现,这两种方式是等效的,但为您提供选项 根据您的偏好:

  1. 在 sharder 中传递优化器 kwargs。fused_params

  2. 通过 ,它将优化器 参数传递给 或 。apply_optimizer_in_backwardfused_paramsTBEEmbeddingBagCollectionEmbeddingCollection

# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
from fbgemm_gpu.split_embedding_configs import EmbOptimType


# We initialize the sharder with
fused_params = {
    "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
    "learning_rate": 0.02,
    "eps": 0.002,
}

# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))

# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")

print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")

from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
import copy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it

# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}

for name, param in ebc_apply_opt.named_parameters():
    print(f"{name=}")
    apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)

sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))

# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))

# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())

# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")

loss.backward()

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
Original Sharded EBC fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
    lr: 0.01
)
Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
    lr: 0.02
)
Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
/var/lib/workspace/intermediate_source/torchrec_intro_tutorial.py:876: DeprecationWarning:

`TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.

name='embedding_bags.product_table.weight'
name='embedding_bags.user_table.weight'
: EmbeddingFusedOptimizer (
Parameter Group 0
    lr: 0.5
)
<class 'torchrec.optim.keyed.CombinedOptimizer'>
Non Fused Model Parameters:
dict_keys([])
First Iteration Loss: 255.66006469726562
Second Iteration Loss: 245.43795776367188

推理

既然我们已经能够训练分布式嵌入,那么我们该如何获取 训练好的模型并对其进行优化以进行推理?推理通常是 对模型的性能和大小非常敏感。只运行 在 Python 环境中训练的模型效率非常低。 推理和训练之间有两个主要区别 环境:

  • 量化:推理模型通常 量化,其中模型参数会失去精度,以降低 预测和缩小的模型大小。例如,FP32(4 字节)在 对于每个嵌入权重,将模型训练到 INT8(1 字节)。这也是 考虑到嵌入表的规模庞大,这是必要的,因为我们想使用 尽可能少的设备进行推理,以最大限度地减少延迟。

  • C++ 环境:推理延迟非常重要,因此为了确保 性能充足,该模型通常在 C++ 环境中运行, 以及我们没有 Python 运行时的情况,例如 装置。

TorchRec 提供了用于将 TorchRec 模型转换为 Present 的基元 Inference Ready with (推理就绪):

  • 用于量化模型的 API,简介 使用 FBGEMM TBE 自动优化

  • 用于分布式推理的分片嵌入

  • 将模型编译为 TorchScript (兼容 C++)

在本节中,我们将介绍以下整个工作流程:

  • 量化模型

  • 对量化模型进行分片

  • 将分片量化模型编译成 TorchScript

ebc

class InferenceModule(torch.nn.Module):
    def __init__(self, ebc: torchrec.EmbeddingBagCollection):
        super().__init__()
        self.ebc_ = ebc

    def forward(self, kjt: KeyedJaggedTensor):
        return self.ebc_(kjt)

module = InferenceModule(ebc)
for name, param in module.named_parameters():
    # Here, the parameters should still be FP32, as we are using a standard EBC
    # FP32 is default, regularly used for training
    print(name, param.shape, param.dtype)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 64]) torch.float32
ebc_.embedding_bags.user_table.weight torch.Size([4096, 64]) torch.float32

量化

如上所示,普通 EBC 包含嵌入表权重,如 FP32 精度(每个权重 32 位)。在这里,我们将使用 TorchRec 推理库,用于将模型的嵌入权重量化为 INT8

from torch import quantization as quant
from torchrec.modules.embedding_configs import QuantConfig
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)


quant_dtype = torch.int8


qconfig = QuantConfig(
    # dtype of the result of the embedding lookup, post activation
    # torch.float generally for compatibility with rest of the model
    # as rest of the model here usually isn't quantized
    activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
    # quantized type for embedding weights, aka parameters to actually quantize
    weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
    # Map of module type to qconfig
    torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
    # Map of module type to quantized module type
    torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}


module = InferenceModule(ebc)

# Quantize the module
qebc = quant.quantize_dynamic(
    module,
    qconfig_spec=qconfig_spec,
    mapping=mapping,
    inplace=False,
)


print(f"Quantized EBC: {qebc}")

kjt = kjt.to("cpu")

qebc(kjt)

# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
    # The shapes of the tables should be the same but the dtype should be int8 now
    # post quantization
    print(name, buffer.shape, buffer.dtype)
Quantized EBC: InferenceModule(
  (ebc_): QuantizedEmbeddingBagCollection(
    (_kjt_to_jt_dict): ComputeKJTToJTDict()
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
  )
)
ebc_.embedding_bags.product_table.weight torch.Size([4096, 80]) torch.uint8
ebc_.embedding_bags.user_table.weight torch.Size([4096, 80]) torch.uint8

碎片

在这里,我们执行 TorchRec 量化模型的分片。这是为了 确保我们通过 FBGEMM TBE 使用 performant 模块。在这里,我们 正在使用一种设备以与训练 (1 TBE) 保持一致。

from torchrec import distributed as trec_dist
from torchrec.distributed.shard import _shard_modules


sharded_qebc = _shard_modules(
    module=qebc,
    device=torch.device("cpu"),
    env=trec_dist.ShardingEnv.from_local(
        1,
        0,
    ),
)


print(f"Sharded Quantized EBC: {sharded_qebc}")

sharded_qebc(kjt)
WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.
Sharded Quantized EBC: InferenceModule(
  (ebc_): ShardedQuantEmbeddingBagCollection(
    (lookups):
     InferGroupedPooledEmbeddingsLookup()
    (_output_dists): ModuleList()
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
    (_input_dist_module): ShardedQuantEbcInputDist()
  )
)

<torchrec.sparse.jagged_tensor.KeyedTensor object at 0x7f4ed828c310>

汇编

现在我们有了优化的 Eager TorchRec 推理模型。下一步 是为了确保这个模型可以用 C++ 加载,因为目前它只是 在 Python 运行时中运行。

Meta 推荐的编译方法有两种:torch.fx tracing (生成 模型的中间表示形式),并将结果转换为 TorchScript,其中 TorchScript 与 C++ 兼容。

from torchrec.fx import Tracer


tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])

graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)

print("Graph Module Created!")

print(gm.code)

scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")

print(scripted_gm.code)
Graph Module Created!

torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embeddingbag_flatten_feature_lengths")
torch.fx._symbolic_trace.wrap("torchrec_fx_utils__fx_marker")
torch.fx._symbolic_trace.wrap("torchrec_distributed_quant_embedding_kernel__unwrap_kjt")
torch.fx._symbolic_trace.wrap("torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference")

def forward(self, kjt : torchrec_sparse_jagged_tensor_KeyedJaggedTensor):
    flatten_feature_lengths = torchrec_distributed_quant_embeddingbag_flatten_feature_lengths(kjt);  kjt = None
    _fx_marker = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_BEGIN', flatten_feature_lengths);  _fx_marker = None
    split = flatten_feature_lengths.split([2])
    getitem = split[0];  split = None
    to = getitem.to(device(type='cuda', index=0), non_blocking = True);  getitem = None
    _fx_marker_1 = torchrec_fx_utils__fx_marker('KJT_ONE_TO_ALL_FORWARD_END', flatten_feature_lengths);  flatten_feature_lengths = _fx_marker_1 = None
    _unwrap_kjt = torchrec_distributed_quant_embedding_kernel__unwrap_kjt(to);  to = None
    getitem_1 = _unwrap_kjt[0]
    getitem_2 = _unwrap_kjt[1]
    getitem_3 = _unwrap_kjt[2];  _unwrap_kjt = getitem_3 = None
    _tensor_constant0 = self._tensor_constant0
    _tensor_constant1 = self._tensor_constant1
    bounds_check_indices = torch.ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1, None);  _tensor_constant0 = _tensor_constant1 = bounds_check_indices = None
    _tensor_constant2 = self._tensor_constant2
    _tensor_constant3 = self._tensor_constant3
    _tensor_constant4 = self._tensor_constant4
    _tensor_constant5 = self._tensor_constant5
    _tensor_constant6 = self._tensor_constant6
    _tensor_constant7 = self._tensor_constant7
    _tensor_constant8 = self._tensor_constant8
    _tensor_constant9 = self._tensor_constant9
    int_nbit_split_embedding_codegen_lookup_function = torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(dev_weights = _tensor_constant2, uvm_weights = _tensor_constant3, weights_placements = _tensor_constant4, weights_offsets = _tensor_constant5, weights_tys = _tensor_constant6, D_offsets = _tensor_constant7, total_D = 128, max_int2_D = 0, max_int4_D = 0, max_int8_D = 64, max_float16_D = 0, max_float32_D = 0, indices = getitem_1, offsets = getitem_2, pooling_mode = 0, indice_weights = None, output_dtype = 0, lxu_cache_weights = _tensor_constant8, lxu_cache_locations = _tensor_constant9, row_alignment = 16, max_float8_D = 0, fp8_exponent_bits = -1, fp8_exponent_bias = -1);  _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = _tensor_constant7 = getitem_1 = getitem_2 = _tensor_constant8 = _tensor_constant9 = None
    embeddings_cat_empty_rank_handle_inference = torchrec_distributed_embedding_lookup_embeddings_cat_empty_rank_handle_inference([int_nbit_split_embedding_codegen_lookup_function], dim = 1, device = 'cuda:0', dtype = torch.float32);  int_nbit_split_embedding_codegen_lookup_function = None
    to_1 = embeddings_cat_empty_rank_handle_inference.to(device(type='cpu'));  embeddings_cat_empty_rank_handle_inference = None
    keyed_tensor = torchrec_sparse_jagged_tensor_KeyedTensor(keys = ['product', 'user'], length_per_key = [64, 64], values = to_1, key_dim = 1);  to_1 = None
    return keyed_tensor

/usr/local/lib/python3.10/dist-packages/torch/jit/_check.py:178: UserWarning:

The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.

Scripted Graph Module Created!
def forward(self,
    kjt: __torch__.torchrec.sparse.jagged_tensor.KeyedJaggedTensor) -> __torch__.torchrec.sparse.jagged_tensor.KeyedTensor:
  _0 = __torch__.torchrec.distributed.quant_embeddingbag.flatten_feature_lengths
  _1 = __torch__.torchrec.fx.utils._fx_marker
  _2 = __torch__.torchrec.distributed.quant_embedding_kernel._unwrap_kjt
  _3 = __torch__.torchrec.distributed.embedding_lookup.embeddings_cat_empty_rank_handle_inference
  flatten_feature_lengths = _0(kjt, )
  _fx_marker = _1("KJT_ONE_TO_ALL_FORWARD_BEGIN", flatten_feature_lengths, )
  split = (flatten_feature_lengths).split([2], )
  getitem = split[0]
  to = (getitem).to(torch.device("cuda", 0), True, None, )
  _fx_marker_1 = _1("KJT_ONE_TO_ALL_FORWARD_END", flatten_feature_lengths, )
  _unwrap_kjt = _2(to, )
  getitem_1 = (_unwrap_kjt)[0]
  getitem_2 = (_unwrap_kjt)[1]
  _tensor_constant0 = self._tensor_constant0
  _tensor_constant1 = self._tensor_constant1
  ops.fbgemm.bounds_check_indices(_tensor_constant0, getitem_1, getitem_2, 1, _tensor_constant1)
  _tensor_constant2 = self._tensor_constant2
  _tensor_constant3 = self._tensor_constant3
  _tensor_constant4 = self._tensor_constant4
  _tensor_constant5 = self._tensor_constant5
  _tensor_constant6 = self._tensor_constant6
  _tensor_constant7 = self._tensor_constant7
  _tensor_constant8 = self._tensor_constant8
  _tensor_constant9 = self._tensor_constant9
  int_nbit_split_embedding_codegen_lookup_function = ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function(_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, _tensor_constant7, 128, 0, 0, 64, 0, 0, getitem_1, getitem_2, 0, None, 0, _tensor_constant8, _tensor_constant9, 16)
  _4 = [int_nbit_split_embedding_codegen_lookup_function]
  embeddings_cat_empty_rank_handle_inference = _3(_4, 1, "cuda:0", 6, )
  to_1 = torch.to(embeddings_cat_empty_rank_handle_inference, torch.device("cpu"))
  _5 = ["product", "user"]
  _6 = [64, 64]
  keyed_tensor = __torch__.torchrec.sparse.jagged_tensor.KeyedTensor.__new__(__torch__.torchrec.sparse.jagged_tensor.KeyedTensor)
  _7 = (keyed_tensor).__init__(_5, _6, to_1, 1, None, None, )
  return keyed_tensor

结论

在本教程中,您已完成分布式 RecSys 模型的训练 使其准备好进行推理。TorchRec 存储库有一个 如何将 TorchRec TorchScript 模型加载到 C++ 中的完整示例 推理。

有关更多信息,请参阅我们的 dlrm 示例,其中包括 Criteo 1TB 上的多节点训练 数据集,使用深度学习推荐模型中描述的方法 用于个性化和推荐系统

脚本总运行时间:(0 分 0.767 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源