torch.nested¶
介绍¶
警告
嵌套张量的 PyTorch API 处于原型阶段,将在不久的将来发生变化。
警告
torch。NestedTensor 目前不支持 autograd。它需要在上下文中使用 的 torch.inference_mode() 中。
NestedTensor 允许用户将 Tensor 列表打包到一个高效的数据结构中。
对输入 Tensor 的唯一约束是它们的维度必须匹配。
这可实现更高效的元数据表示和运算符覆盖率。
构造很简单,涉及将 Tensor 列表传递给构造函数。
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested_tensor([a, b])
>>> nt
nested_tensor([
tensor([0, 1, 2]),
tensor([3, 4, 5, 6, 7])
])
数据类型和设备可以通过通常的关键字参数进行选择
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
tensor([0., 1., 2.], device='cuda:0'),
tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
操作员覆盖范围¶
我们目前正在以特定 ML 使用案例为指导,大规模扩展运维覆盖范围。
因此,运算符覆盖范围目前非常有限,仅支持 unbind。
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
tensor([0., 1., 2.], device='cuda:0'),
tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
>>> nt.unbind()
[tensor([0., 1., 2.], device='cuda:0'), tensor([3., 4., 5., 6., 7.], device='cuda:0')]