目录

数据类型

TorchRec 包含用于表示嵌入的数据类型,也称为稀疏特征。 稀疏特征通常是要馈送到嵌入表中的索引。对于给定的 batch,则嵌入查找索引的数量是可变的。因此,需要交维度来表示批次的可变嵌入查找索引量。

本节介绍用于表示稀疏特征的 3 种 TorchRec 数据类型的类: JaggedTensorKeyedJaggedTensorKeyedTensor

torchrec.sparse.jagged_tensor 类JaggedTensor*args**kwargs)

表示 (可选加权的) 交错张量。

JaggedTensor 是具有交错维度的张量,该维度是其 切片的长度可能不同。有关完整示例,请参阅 KeyedJaggedTensor

实现是 torch.jit.script-able 的。

注意

我们不会进行输入验证,因为它很昂贵,您应该始终传入 有效长度、偏移量等

参数
  • torch.Tensor) – 以密集表示形式值 tensor。

  • weights可选[torch.Tensor]) – 如果值具有权重。具有相同形状的 Tensor 作为值。

  • lengths可选[torch.Tensor]) – 锯齿状切片,以长度表示。

  • offsets可选[torch.Tensor]) – 锯齿状切片,表示为累积 补偿。

device 设备

获取 JaggedTensor 设备。

结果

值 Tensor 的装置。

返回类型

torch.device

static emptyis_weighted bool = Falsedevice 可选[device] = values_dtype:可选[dtype] = weights_dtype:可选[dtype] = lengths_dtypedtype = torch.int32 JaggedTensor

构造一个空的 JaggedTensor。

参数
  • is_weightedbool) – JaggedTensor 是否具有权重。

  • deviceOptional[torch.device]) – JaggedTensor 的设备。

  • values_dtypeOptional[torch.dtype]) – 值的 dtype。

  • weights_dtypeOptional[torch.dtype]) – 权重的 dtype。

  • lengths_dtypetorch.dtype) – 长度的 dtype。

结果

空 JaggedTensor 的 JaggedTensor 中。

返回类型

JaggedTensor 的

static from_dense List[Tensor]weights Optional[List[Tensor]] = JaggedTensor

从张量列表中构造 JaggedTensor 作为值,具有可选的权重。长度,形状为 (B,),其中 B 是 len(值),其中 表示批量大小。

参数
  • List[torch.Tensor]) – 用于密集表示的张量列表

  • weights可选[List[torch.Tensor]]) – 如果值具有权重,则 与 values 的形状相同。

结果

JaggedTensor 从 2D 密集张量创建。

返回类型

JaggedTensor 的

例:

values = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
weights = [
    torch.Tensor([1.0]),
    torch.Tensor(),
    torch.Tensor([7.0, 8.0]),
    torch.Tensor([10.0, 11.0, 12.0]),
]
j1 = JaggedTensor.from_dense(
    values=values,
    weights=weights,
)

# j1 = [[1.0], [], [7.0, 8.0], [10.0, 11.0, 12.0]]
static from_dense_lengths Tensorlengths Tensorweights 可选[Tensor] = JaggedTensor

从值和长度张量构造 JaggedTensor,具有可选的权重。 请注意,lengths 的形状仍然是 (B,),其中 B 是批量大小。

参数
  • torch.Tensor) —— 值的密集表示。

  • 长度割torch。Tensor) – 锯齿状切片,以长度表示。

  • weights可选[torch.Tensor]) – 如果值具有权重,则 与 values 的形状相同。

结果

JaggedTensor 从 2D 密集张量创建。

返回类型

JaggedTensor 的

lengths 张量

获取 JaggedTensor 长度。如果未计算,则根据偏移量计算。

结果

lengths 张量。

返回类型

torch.Tensor

lengths_or_none 可选[Tensor]

获取 JaggedTensor 长度。如果未计算,则返回 None。

结果

lengths 张量。

返回类型

可选[torch.张量]

offsets Tensor

获取 JaggedTensor 偏移量。如果未计算,则根据 lengths 计算。

结果

offsets 张量。

返回类型

torch.Tensor

offsets_or_none 可选[Tensor]

获取 JaggedTensor 偏移量。如果未计算,则返回 None。

结果

offsets 张量。

返回类型

可选[torch.张量]

record_streamstream Stream

查看 https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

todevicedevice non_blocking: bool = False JaggedTensor

将 JaggedTensor 移动到指定的设备。

参数
  • devicetorch.device) – 要移动到的设备。

  • non_blockingbool) – 是否异步执行复制。

结果

移动的 JaggedTensor 中。

返回类型

JaggedTensor 的

to_dense List[张量]

构造 JT 值的密集表示。

结果

张量列表。

返回类型

列表[torch.张量]

例:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

values_list = jt.to_dense()

# values_list = [
#     torch.tensor([1.0, 2.0]),
#     torch.tensor([]),
#     torch.tensor([3.0]),
#     torch.tensor([4.0]),
#     torch.tensor([5.0]),
#     torch.tensor([6.0, 7.0, 8.0]),
# ]
to_dense_weights 可选[List[Tensor]]

构造 JT 权重的密集表示。

结果

张量列表,如果没有权重,则为 None

返回类型

可选[List[torch.张量]]

例:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

weights_list = jt.to_dense_weights()

# weights_list = [
#     torch.tensor([0.1, 0.2]),
#     torch.tensor([]),
#     torch.tensor([0.3]),
#     torch.tensor([0.4]),
#     torch.tensor([0.5]),
#     torch.tensor([0.6, 0.7, 0.8]),
# ]
to_padded_densedesired_length 可选[int] = padding_value: 浮点= 0.0 张量

根据 JT 的形状值 (B, N,) 构造一个 2D 密集张量。

请注意,B 是 self.lengths() 的长度,N 是最长的特征长度或 desired_length

如果长度>desired_length,我们将用 padding_value 填充,否则我们将填充 将选择 desired_length 处的最后一个值。

参数
  • desired_lengthint) – 张量的长度。

  • padding_valuefloat) – 如果我们需要填充,则填充值。

结果

2D 密集张量。

返回类型

torch.Tensor

例:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, offsets=offsets)

dt = jt.to_padded_dense(
    desired_length=2,
    padding_value=10.0,
)

# dt = [
#     [1.0, 2.0],
#     [10.0, 10.0],
#     [3.0, 10.0],
#     [4.0, 10.0],
#     [5.0, 10.0],
#     [6.0, 7.0],
# ]
to_padded_dense_weightsdesired_length 可选[int] = padding_value: 浮点= 0,0 可选[Tensor]

从 JT 的形状权重 (B, N,) 构造一个 2D 密集张量。

请注意,B (batch size) 是 self.lengths() 的长度,N 是最长的特征长度或 desired_length

如果长度>desired_length,我们将用 padding_value 填充,否则我们将填充 将选择 desired_length 处的最后一个值。

to_padded_dense 类似,但用于 JT 的权重而不是值。

参数
  • desired_lengthint) – 张量的长度。

  • padding_valuefloat) – 如果我们需要填充,则填充值。

结果

2d 密集张量,如果没有权重,则为 None

返回类型

可选[torch.张量]

例:

values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
weights = torch.Tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8])
jt = JaggedTensor(values=values, weights=weights, offsets=offsets)

d_wt = jt.to_padded_dense_weights(
    desired_length=2,
    padding_value=1.0,
)

# d_wt = [
#     [0.1, 0.2],
#     [1.0, 1.0],
#     [0.3, 1.0],
#     [0.4, 1.0],
#     [0.5, 1.0],
#     [0.6, 0.7],
# ]
values Tensor

获取 JaggedTensor 值。

结果

值 Tensor。

返回类型

torch.Tensor

weights 张量

获取 JaggedTensor 权重。如果为 None,则引发错误。

结果

weights 张量。

返回类型

torch.Tensor

weights_or_none 可选[Tensor]

获取 JaggedTensor 权重。如果为 None,则返回 None。

结果

weights 张量。

返回类型

可选[torch.张量]

torchrec.sparse.jagged_tensor 类KeyedJaggedTensor*args**kwargs)

表示一个 (可选加权的) 键控交错张量。

KeyedJaggedTensor 是具有交错维度的张量,该维度是其 切片的长度可能不同。在第一个维度上键入,在最后一个维度上呈锯齿状 尺寸。

实现是 torch.jit.script-able 的。

参数
  • keysList[str]) - 交错 Tensor 的键。

  • torch.Tensor) – 以密集表示形式值 tensor。

  • weights可选[torch.Tensor]) – 如果值具有权重。Tensor 替换为 形状与 Values 相同。

  • lengths可选[torch.Tensor]) – 锯齿状切片,以长度表示。

  • offsets可选[torch.Tensor]) – 锯齿状切片,表示为累积 补偿。

  • strideOptional[int]) – 每批样本数。

  • stride_per_key_per_rankOptional[List[List[int]]] – 批量大小 (样本数),外部列表表示 keys 和表示值的内部列表。 内部列表中的每个值都表示批处理中的示例数 从其索引在分布式上下文中的排名。

  • length_per_keyOptional[List[int]]) – 每个键的开始长度。

  • offset_per_keyOptional[List[int]]) – 每个键和最终的起始偏移量 抵消。

  • index_per_keyOptional[Dict[strint]]) – 每个键的索引。

  • jt_dictOptional[Dict[strJaggedTensor]]) – JaggedTensors 的键字典。 允许使 to_dict() 惰性/可缓存的能力。

  • inverse_indices可选[Tuple[List[str]torch.Tensor]]) – 反向索引 展开 Deduplicated Detiled Embedding Output for Variable stride per key(每个键的可变步幅的重复数据删除嵌入输出)。

例:

#              0       1        2  <-- dim_1
# "Feature0"   [V0,V1] None    [V2]
# "Feature1"   [V3]    [V4]    [V5,V6,V7]
#   ^
#  dim_0

dim_0: keyed dimension (ie. `Feature0`, `Feature1`)
dim_1: optional second dimension (ie. batch size)
dim_2: The jagged dimension which has slice lengths between 0-3 in the above example

# We represent this data with following inputs:

values: torch.Tensor = [V0, V1, V2, V3, V4, V5, V6, V7]  # V == any tensor datatype
weights: torch.Tensor = [W0, W1, W2, W3, W4, W5, W6, W7]  # W == any tensor datatype
lengths: torch.Tensor = [2, 0, 1, 1, 1, 3]  # representing the jagged slice
offsets: torch.Tensor = [0, 2, 2, 3, 4, 5, 8]  # offsets from 0 for each jagged slice
keys: List[str] = ["Feature0", "Feature1"]  # correspond to each value of dim_0
index_per_key: Dict[str, int] = {"Feature0": 0, "Feature1": 1}  # index for each key
offset_per_key: List[int] = [0, 3, 8]  # start offset for each key and final offset
static concatkjt_list: List[KeyedJaggedTensor] KeyedJaggedTensor

将 KeyedJaggedTensor 列表连接成一个 KeyedJaggedTensor。

参数

kjt_listList[KeyedJaggedTensor]) - 要连接的 KeyedJaggedTensor 列表。

结果

串联的 KeyedJaggedTensor 中。

返回类型

KeyedJaggedTensor

device 设备

返回 KeyedJaggedTensor 的装置。

结果

keyedJaggedTensor 的装置。

返回类型

torch.device

static emptyis_weighted bool = Falsedevice 可选[device] = values_dtype:可选[dtype] = weights_dtype:可选[dtype] = lengths_dtype:dtype = torch.int32 KeyedJaggedTensor

构造一个空的 KeyedJaggedTensor。

参数
  • is_weightedbool) – KeyedJaggedTensor 是否加权。

  • deviceOptional[torch.device]) – 将放置 KeyedJaggedTensor 的设备。

  • values_dtypeOptional[torch.dtype]) - 值张量的 dtype。

  • weights_dtypeOptional[torch.dtype]) - 权重张量的 dtype。

  • lengths_dtypetorch.dtype) - 长度张量的 dtype。

结果

空 KeyedJaggedTensor 的 KeyedJaggedTensor 中。

返回类型

KeyedJaggedTensor

static empty_likekjt KeyedJaggedTensor KeyedJaggedTensor

构造一个空的 KeyedJaggedTensor,其设备和 dtype 与输入 KeyedJaggedTensor 相同。

参数

kjtKeyedJaggedTensor) – 输入 KeyedJaggedTensor。

结果

空 KeyedJaggedTensor 的 KeyedJaggedTensor 中。

返回类型

KeyedJaggedTensor

static from_jt_dictjt_dict Dict[str JaggedTensor] KeyedJaggedTensor

从 JaggedTensors 的字典构造 KeyedJaggedTensor。 在新创建的 KJT 上自动调用 kjt.sync()。

注意

仅当 JaggedTensors 全部 具有相同的 “隐式” batch_size 维度。

基本上,我们可以将 JaggedTensors 可视化为 2-D 张量 的格式为 [batch_size x variable_feature_dim]。 在这种情况下,我们有一些没有 feature 值的 batch, 输入 JaggedTensor 可以不包含任何值。

但 KeyedJaggedTensor(默认情况下)通常会填充 “None” 因此,所有 JaggedTensor 都存储在 KeyedJaggedTensor 中 具有相同的 batch_size 维度。也就是说,在这种情况下, JaggedTensor 输入没有自动填充 对于空批处理,此函数将错误/不起作用。

考虑以下 KeyedJaggedTensor 的可视化: # 0 1 2 <– dim_1 # “feature0” [V0,V1] 无 [V2] # “特性 1” [V3] [V4] [V5,V6,V7] # ^ # dim_0

现在,如果输入 jt_dict = {

# “特性 0” [V0,V1] [V2] # “特性 1” [V3] [V4] [V5,V6,V7]

} 并且每个 JaggedTensor 中省略了 “None”, 那么这个函数会失败,因为我们不能正确地 能够填充 “None”,因为它在技术上不知道 在 JaggedTensor 中填充的正确批处理 / 位置。

本质上,该函数推断的 Tensor 长度 将为 [2, 1, 1, 1, 3] 表示变量 batch_size dim_1 违反了现有的假设/前提条件 KeyedJaggedTensor 的维度应该是固定的batch_size。

参数

jt_dictDict[strJaggedTensor]) - JaggedTensor 的字典。

结果

构造了 KeyedJaggedTensor。

返回类型

KeyedJaggedTensor

static from_lengths_sync List[str] Tensor长度 Tensorweights 可选[Tensor] = None步幅 Optional[int] = stride_per_key_per_rank:可选[List[List[int]]] = inverse_indices 可选[Tuple[List[str] Tensor]] = None KeyedJaggedTensor

从 key、length 和 offset 的列表构造 KeyedJaggedTensor。 与 from_offsets_sync 相同,但使用的是 length 而不是 offset。

参数
  • keysList[str]) - 键列表。

  • torch.Tensor) – 以密集表示形式值 tensor。

  • 长度割torch。Tensor) – 锯齿状切片,以长度表示。

  • weights可选[torch.Tensor]) – 如果值具有权重。Tensor 替换为 形状与 Values 相同。

  • strideOptional[int]) – 每批样本数。

  • stride_per_key_per_rankOptional[List[List[int]]] – 批量大小 (样本数),外部列表表示 keys 和表示值的内部列表。

  • inverse_indices可选[Tuple[List[str]torch.Tensor]]) – 反向索引 展开 Deduplicated Detiled Embedding Output for Variable stride per key(每个键的可变步幅的重复数据删除嵌入输出)。

结果

构造了 KeyedJaggedTensor。

返回类型

KeyedJaggedTensor

static from_offsets_sync List[str] Tensor偏移量: Tensorweights 可选[Tensor] = None步幅 Optional[int] = stride_per_key_per_rank:可选[List[List[int]]] = inverse_indices 可选[Tuple[List[str] Tensor]] = None KeyedJaggedTensor

从键、值和偏移量的列表构造 KeyedJaggedTensor。

参数
  • keysList[str]) - 键列表。

  • torch.Tensor) – 以密集表示形式值 tensor。

  • 偏移量割torch.Tensor) – 锯齿状切片,表示为累积偏移量。

  • weights可选[torch.Tensor]) – 如果值具有权重。Tensor 替换为 形状与 Values 相同。

  • strideOptional[int]) – 每批样本数。

  • stride_per_key_per_rankOptional[List[List[int]]] – 批量大小 (样本数),外部列表表示 keys 和表示值的内部列表。

  • inverse_indices可选[Tuple[List[str]torch.Tensor]]) – 反向索引 展开 Deduplicated Detiled Embedding Output for Variable stride per key(每个键的可变步幅的重复数据删除嵌入输出)。

结果

构造了 KeyedJaggedTensor。

返回类型

KeyedJaggedTensor

index_per_key Dict[str int]

返回 KeyedJaggedTensor 的每个键的索引。

结果

keyedJaggedTensor 的每个键的索引。

返回类型

字典 [str, int]

inverse_indices 元组[List[str] Tensor]

返回 KeyedJaggedTensor 的逆索引。 如果 inverse indices 为 None,这将引发错误。

结果

KeyedJaggedTensor 的反向索引。

返回类型

Tuple[List[str], torch 的 .张量]

inverse_indices_or_none 可选[tuple[list[str] tensor]]

返回 KeyedJaggedTensor 或 None 的反向索引(如果它们不存在)。

结果

KeyedJaggedTensor 的反向索引。

返回类型

可选[Tuple[List[str], torch.张量]]

keys List[str]

返回 KeyedJaggedTensor 的键。

结果

keyedJaggedTensor 的 key 中。

返回类型

列表[str]

length_per_key 列表[int]

返回 KeyedJaggedTensor 的每个键的长度。 如果 length per key 为 None,则将计算它。

结果

keyedJaggedTensor 的每个键的长度。

返回类型

列表[int]

length_per_key_or_none 可选[List[int]]

返回 KeyedJaggedTensor 的每个键的长度,如果尚未计算,则返回 None。

结果

keyedJaggedTensor 的每个键的长度。

返回类型

列表[int]

lengths 张量

返回 KeyedJaggedTensor 的长度。 如果尚未计算长度,它将计算它们。

结果

KeyedJaggedTensor 的长度。

返回类型

torch.Tensor

lengths_offset_per_key List[int]

返回 KeyedJaggedTensor 的每个键的长度偏移量。 如果 lengths offset per key 为 None,则将计算它。

结果

length,每个 KeyedJaggedTensor 的 key 偏移量。

返回类型

列表[int]

lengths_or_none 可选[Tensor]

返回 KeyedJaggedTensor 的长度,如果尚未计算,则返回 None。

结果

KeyedJaggedTensor 的长度。

返回类型

torch.Tensor

offset_per_key 列表[int]

返回 KeyedJaggedTensor 的每个键的偏移量。 如果 offset per key 为 None,则将计算它。

结果

keyedJaggedTensor 的每个键的 offset。

返回类型

列表[int]

offset_per_key_or_none 可选[List[int]]

返回 KeyedJaggedTensor 的每个键的偏移量,如果尚未计算,则返回 None。

结果

keyedJaggedTensor 的每个键的 offset。

返回类型

列表[int]

offsets Tensor

返回 KeyedJaggedTensor 的偏移量。 如果尚未计算偏移量,它将计算它们。

结果

offsets 的 keyedJaggedTensor 的 offsets 中。

返回类型

torch.Tensor

offsets_or_none 可选[Tensor]

返回 KeyedJaggedTensor 或 None 的偏移量(如果尚未计算)。

结果

offsets 的 keyedJaggedTensor 的 offsets 中。

返回类型

torch.Tensor

permuteindices List[int]indices_tensor: 可选[Tensor] = KeyedJaggedTensor

置换 KeyedJaggedTensor。

参数
  • indicesList[int]) – 索引列表。

  • indices_tensor可选[torch.Tensor]) – 索引的张量。

结果

排列 KeyedJaggedTensor。

返回类型

KeyedJaggedTensor

record_streamstream Stream

查看 https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

splitsegments List[int] List[KeyedJaggedTensor]

将 KeyedJaggedTensor 拆分为 KeyedJaggedTensor 列表。

参数

segmentsList[int]) – 段列表。

结果

KeyedJaggedTensor 的列表。

返回类型

列表[KeyedJaggedTensor]

stride() int

返回 KeyedJaggedTensor 的步幅。 如果 stride 为 None,则将计算它。

结果

keyedJaggedTensor 的 stride 中。

返回类型

int

stride_per_key 列表[int]

返回 KeyedJaggedTensor 的每个键的步幅。 如果 stride per key 为 None,则将计算它。

结果

keyedJaggedTensor 的每个键的 stride 。

返回类型

列表[int]

stride_per_key_per_rank List[List[int]]

返回 KeyedJaggedTensor 的每个 rank 的每个 key 的步幅。

结果

stride per key 的 keyedJaggedTensor 的秩。

返回类型

列表[List[int]]

sync KeyedJaggedTensor

通过计算 offset_per_key 和 length_per_key 来同步 KeyedJaggedTensor。

结果

synced KeyedJaggedTensor 的 KeyedJaggedTensor 中。

返回类型

KeyedJaggedTensor

todevice devicenon_blocking bool = Falsedtype 可选[dtype] = KeyedJaggedTensor

返回指定设备和 dtype 中的 KeyedJaggedTensor 副本。

参数
  • devicetorch.device) – 副本的所需设备。

  • non_blockingbool) – 是否以非阻塞方式复制张量。

  • dtypeOptional[torch.dtype]) – 副本所需的数据类型。

结果

复制的 KeyedJaggedTensor。

返回类型

KeyedJaggedTensor

to_dict Dict[str JaggedTensor]

返回每个键的 JaggedTensor 字典。 将缓存结果self._jt_dict。

结果

每个键的 JaggedTensor 字典。

返回类型

Dict[str, JaggedTensor]

unsync() KeyedJaggedTensor

通过清除 KeyedJaggedTensor 来取消同步 offset_per_key 和 length_per_key。

结果

unsynced KeyedJaggedTensor 的 KeyedJaggedTensor 中。

返回类型

KeyedJaggedTensor

values Tensor

返回 KeyedJaggedTensor 的值。

结果

KeyedJaggedTensor 的值。

返回类型

torch.Tensor

variable_stride_per_key bool

返回 KeyedJaggedTensor 是否具有每个键的可变步幅。

结果

KeyedJaggedTensor 是否具有每个键的可变步幅。

返回类型

布尔

weights 张量

返回 KeyedJaggedTensor 的权重。 如果 weights 为 None,则会引发错误。

结果

weights 的 weights。

返回类型

torch.Tensor

weights_or_none 可选[Tensor]

返回 KeyedJaggedTensor 的权重,如果不存在,则返回 None。

结果

weights 的 weights。

返回类型

torch.Tensor

torchrec.sparse.jagged_tensor 类KeyedTensor*args**kwargs)

KeyedTensor 保存密集张量的串联列表,每个张量都可以是 通过密钥访问。

键控维度可以是可变长度 (length_per_key)。 常见用例用途包括存储不同维度的池化嵌入。

实现是 torch.jit.script-able 的。

参数
  • keysList[str]) - 键列表。

  • length_per_keyList[int]) – 沿 key 维度的每个 key 的长度。

  • torch.Tensor) – 密集张量,通常沿键维度连接。

  • key_dimint) – 键维度,零索引 – 默认为 1 (通常 B 为 0 维)。

例:

# kt is KeyedTensor holding

#                         0           1           2
#     "Embedding A"    [1,1]       [1,1]        [1,1]
#     "Embedding B"    [2,1,2]     [2,1,2]      [2,1,2]
#     "Embedding C"    [3,1,2,3]   [3,1,2,3]    [3,1,2,3]

tensor_list = [
    torch.tensor([[1,1]] * 3),
    torch.tensor([[2,1,2]] * 3),
    torch.tensor([[3,1,2,3]] * 3),
]

keys = ["Embedding A", "Embedding B", "Embedding C"]

kt = KeyedTensor.from_tensor_list(keys, tensor_list)

kt.values()
# torch.Tensor(
#     [
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#         [1, 1, 2, 1, 2, 3, 1, 2, 3],
#     ]
# )

kt["Embedding B"]
# torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
device 设备
结果

value tensor 的 device 的 device 的 Tensor 的 Tensor 中

返回类型

torch.device

static from_tensor_list List[str]张量 List[Tensor]key_dim int = 1,cat_dim:int = 1 KeyedTensor

从张量列表创建 KeyedTensor。张量是串联的 沿着cat_dim。这些键用于为张量编制索引。

参数
  • keysList[str]) - 键列表。

  • 张量List[torch.Tensor]) – 张量列表。

  • key_dimint) – 键维度,零索引 – 默认为 1 (通常 B 为 0 维)。

  • cat_dimint) – 连接张量的维度 - 默认值

结果

键控张量。

返回类型

KeyedTensor 的

key_dim int
结果

键维度,零索引 - 通常 B 为 0 维度。

返回类型

int

keys List[str]
结果

键列表。

返回类型

列表[str]

length_per_key 列表[int]
结果

沿 Key 维度的每个 Key 的长度。

返回类型

列表[int]

offset_per_key 列表[int]

获取每个键沿键维度的偏移量。 Compute 和 cache(如果尚未计算)。

结果

每个键沿键维度的偏移量。

返回类型

列表[int]

record_streamstream Stream

查看 https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html

static regroupkeyed_tensors: List[KeyedTensor] List[List[str]] List[Tensor]

将 KeyedTensor 列表重新分组为张量列表。

参数
  • keyed_tensorsList[KeyedTensor]) - KeyedTensor 列表。

  • groupsList[List[str]]) – 键组列表。

结果

张量列表。

返回类型

列表[torch.张量]

static regroup_as_dictkeyed_tensors List[KeyedTensor] List[List[str]]List[str] Dict[str Tensor]

将 KeyedTensor 列表重新组合到张量字典中。

参数
  • keyed_tensorsList[KeyedTensor]) - KeyedTensor 列表。

  • groupsList[List[str]]) – 键组列表。

  • keysList[str]) - 键列表。

结果

张量字典。

返回类型

Dict[str, Torch.张量]

todevicedevice non_blocking bool = False KeyedTensor

将 values tensor 移动到指定的设备。

参数
  • devicetorch.device) - 将值张量移动到的设备。

  • non_blockingbool) – 是否异步执行作 (默认值:False)。

结果

值为 Tensor 的键控张量移动到指定设备。

返回类型

KeyedTensor 的

to_dict Dict[str Tensor]
结果

按键键的张量字典。

返回类型

Dict[str, Torch.张量]

values Tensor

获取 values 张量。

结果

密集张量,通常沿键维度连接。

返回类型

torch.Tensor

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源