模块¶
标准的 TorchRec 模块表示嵌入表集合:
EmbeddingBagCollection是一个包含torch.nn.EmbeddingBagEmbeddingCollection是一个包含torch.nn.Embedding
这些模块通过标准化配置类构建而成:
EmbeddingBagConfigforEmbeddingBagCollectionEmbeddingConfigforEmbeddingCollection
- class torchrec.modules.embedding_configs.EmbeddingBagConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False, pooling: ~torchrec.modules.embedding_configs.PoolingType = PoolingType.SUM)¶
Bases:
BaseEmbeddingConfigEmbeddingBagConfig 是一个表示单个嵌入表的数据类,输出是用于池化的。
- Parameters:
池化 (PoolingType) – 池化类型。
- class torchrec.modules.embedding_configs.EmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)¶
Bases:
BaseEmbeddingConfigEmbeddingConfig 是一个表示单个嵌入表的数据类。
- class torchrec.modules.embedding_configs.BaseEmbeddingConfig(num_embeddings: int, embedding_dim: int, name: str = '', data_type: ~torchrec.types.DataType = DataType.FP32, feature_names: ~typing.List[str] = <factory>, weight_init_max: ~typing.Optional[float] = None, weight_init_min: ~typing.Optional[float] = None, num_embeddings_post_pruning: ~typing.Optional[int] = None, init_fn: ~typing.Optional[~typing.Callable[[~torch.Tensor], ~typing.Optional[~torch.Tensor]]] = None, need_pos: bool = False)¶
基础类用于嵌入配置。
- Parameters:
num_embeddings (int) – 数量的嵌入。
embedding维度 (整数) – embedding维度。
名称 (字符串) – 嵌入表的名称。
数据类型 (Data类型) – 嵌入表的数据类型。
feature_names (List[str]) – 列表中的特征名称。
weight_init_max (Optional[float]) – 最大值用于权重初始化。
weight_init_min (Optional[float]) – 初始化权重的最小值。
num_embeddings_post_pruning (Optional[int]) – 排序后推理时的嵌入数量。 如果为None,则不进行排序。
init_fn (Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]]) – 初始化函数,用于初始化嵌入权重。
需要位置 (bool) – 是否为表格进行位置加权。
- class torchrec.modules.embedding_modules.EmbeddingBagCollection(tables: List[EmbeddingBagConfig], is_weighted: bool = False, device: Optional[device] = None)¶
EmbeddingBagCollection代表一个池化嵌入集合(EmbeddingBags)。
注意
嵌入袋集合是一个未分块的模块,且没有进行性能优化。对于对性能敏感的场景,请考虑使用分块版本 ShardedEmbeddingBagCollection。
它是可以调用的,参数代表稀疏数据的形式为KeyedJaggedTensor,值的形状为(F, B, L[f][i]),其中:
F: 特征数量(键)
B: 批次大小
L[f][i]: 空间特征的长度(每个特征 f 和批次索引 i 可能不同,即锯齿状)
并输出一个 KeyedTensor,其中值的形状为 (B, D):
B: 批次大小
D: 所有嵌入表的嵌入维度之和,即 sum([config.embedding_dim for config in tables])
假设参数是一个 KeyedJaggedTensor 或 J,具有 F 个特征,批量大小 B 和 L[f][i] 个稀疏长度,使得 J[f][i] 是特征 f 的包和批次索引 i,输出 KeyedTensor 和 KT 定义如下: KT[i] = torch.cat([emb[f](J[f][i]) for f in J.keys()]),其中 emb[f] 是对应于特征 f 的 EmbeddingBag。
注意,J[f][i]是一个可变长度的整数值列表(一个袋子),而emb[f](J[f][i])是通过将J[f][i]中每个值的嵌入进行聚合得到的池化嵌入(默认模式为平均)。
- Parameters:
表格 (List[EmbeddingBagConfig]) – 列表中的嵌套表格。
是加权 (bool) – 是否输入 KeyedJaggedTensor 加权。
设备 (可选[torch.device]) – 默认计算设备。
Example:
table_0 = EmbeddingBagConfig( name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] ) table_1 = EmbeddingBagConfig( name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] ) ebc = EmbeddingBagCollection(tables=[table_0, table_1]) # i = 0 i = 1 i = 2 <-- batch indices # "f1" [0,1] None [2] # "f2" [3] [4] [5,6,7] # ^ # features features = KeyedJaggedTensor( keys=["f1", "f2"], values=torch.tensor([0, 1, 2, # feature 'f1' 3, 4, 5, 6, 7]), # feature 'f2' # i = 1 i = 2 i = 3 <--- batch indices offsets=torch.tensor([ 0, 2, 2, # 'f1' bags are values[0:2], values[2:2], and values[2:3] 3, 4, 5, 8]), # 'f2' bags are values[3:4], values[4:5], and values[5:8] ) pooled_embeddings = ebc(features) print(pooled_embeddings.values()) tensor([ # f1 pooled embeddings f2 pooled embeddings # from bags (dim. 3) from bags (dim. 4) [-0.8899, -0.1342, -1.9060, -0.0905, -0.2814, -0.9369, -0.7783], # i = 0 [ 0.0000, 0.0000, 0.0000, 0.1598, 0.0695, 1.3265, -0.1011], # i = 1 [-0.4256, -1.1846, -2.1648, -1.0893, 0.3590, -1.9784, -0.7681]], # i = 2 grad_fn=<CatBackward0>) print(pooled_embeddings.keys()) ['f1', 'f2'] print(pooled_embeddings.offset_per_key()) tensor([0, 3, 7]) # embeddings have dimensions 3 and 4, so embeddings are at [0, 3) and [3, 7).
- property device: device¶
返回: torch.device:计算设备。
- embedding_bag_configs() List[EmbeddingBagConfig]¶
- Returns:
嵌入袋配置。
- Return type:
List[EmbeddingBagConfig]
- forward(features: KeyedJaggedTensor) KeyedTensor¶
运行嵌入袋集合的前向传播。此方法接受一个 KeyedJaggedTensor 并返回一个 KeyedTensor,这是为每个特征池化嵌入的结果。
- Parameters:
功能 (KeyedJaggedTensor) – 输入 KJT
- Returns:
KeyedTensor
- is_weighted() bool¶
- Returns:
是否使用权重计算EmbeddingBagCollection。
- Return type:
布尔
- reset_parameters() None¶
重置 EmbeddingBagCollection 的参数。参数值基于每个 EmbeddingBagConfig 的 init_fn 初始化。
- class torchrec.modules.embedding_modules.EmbeddingCollection(tables: List[EmbeddingConfig], device: Optional[device] = None, need_indices: bool = False)¶
嵌入集合表示非池化嵌入的集合。
注意
嵌入集合是一个未分块的模块,且没有进行性能优化。对于对性能敏感的场景,请考虑使用分块版本 ShardedEmbeddingCollection。
它是可以调用的,参数代表稀疏数据的形式为KeyedJaggedTensor,值的形状为(F, B, L[f][i]),其中:
F: 特征数量(键)
B: 批次大小
L[f][i]: 空间特征的长度(每个特征 f 和批次索引 i 可能不同,即锯齿状)
并输出一个类型为 Dict[Feature, JaggedTensor] 的 result, 其中 result[f] 是一个形状为 (EB[f], D[f]) 的 JaggedTensor:
EB[f]: 一个“扩展的批量大小”用于特征 f 等于其袋值长度之和, 也就是说,sum([len(J[f][i]) for i in range(B)])。
D[f]: 是特征 f 的嵌入维度。
- Parameters:
表格 (List[EmbeddingConfig]) – 列表中的嵌入表。
设备 (可选[torch.device]) – 默认计算设备。
需要索引 (bool) – 如果我们需要将索引传递给最终查找字典。
Example:
e1_config = EmbeddingConfig( name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] ) e2_config = EmbeddingConfig( name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] ) ec = EmbeddingCollection(tables=[e1_config, e2_config]) # 0 1 2 <-- batch # 0 [0,1] None [2] # 1 [3] [4] [5,6,7] # ^ # feature features = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2"], values=torch.tensor([0, 1, 2, # feature 'f1' 3, 4, 5, 6, 7]), # feature 'f2' # i = 1 i = 2 i = 3 <--- batch indices offsets=torch.tensor([ 0, 2, 2, # 'f1' bags are values[0:2], values[2:2], and values[2:3] 3, 4, 5, 8]), # 'f2' bags are values[3:4], values[4:5], and values[5:8] ) feature_embeddings = ec(features) print(feature_embeddings['f2'].values()) tensor([ # embedding for value 3 in f2 bag values[3:4]: [-0.2050, 0.5478, 0.6054], # embedding for value 4 in f2 bag values[4:5]: [ 0.7352, 0.3210, -3.0399], # embedding for values 5, 6, 7 in f2 bag values[5:8]: [ 0.1279, -0.1756, -0.4130], [ 0.7519, -0.4341, -0.0499], [ 0.9329, -1.0697, -0.8095], ], grad_fn=<EmbeddingBackward>)
- property device: device¶
返回: torch.device:计算设备。
- embedding_configs() List[EmbeddingConfig]¶
- Returns:
嵌入配置。
- Return type:
List[EmbeddingConfig]
- embedding_dim() int¶
- Returns:
嵌入维度。
- Return type:
整数
- embedding_names_by_table() List[List[str]]¶
- Returns:
表格中的嵌入名称。
- Return type:
List[List[str]]
- forward(features: KeyedJaggedTensor) Dict[str, JaggedTensor]¶
运行嵌入袋集合的前向传播。此方法接受一个 KeyedJaggedTensor 并返回一个 Dict[str, JaggedTensor],这是每个特征的个体嵌入结果。
- Parameters:
特性 (KeyedJaggedTensor) – 形式为 [F X B X L] 的 KJT。
- Returns:
字典[str, JaggedTensor]
- need_indices() bool¶
- Returns:
是否需要 EmbeddingCollection 的索引。
- Return type:
布尔
- reset_parameters() None¶
重置EmbeddingCollection的参数。参数值基于每个EmbeddingConfig的init_fn初始化,如果存在。