torch.nested¶
介绍¶
警告
PyTorch 的嵌套张量 API 目前处于原型阶段,将在不久的将来发生变化。
警告
torch.NestedTensor 目前不支持 autograd。它需要在 torch.inference_mode() 的上下文中使用。
NestedTensor 允许用户将张量列表打包成一个高效的数据结构。
输入张量的唯一约束是它们的维度必须匹配。
这使得元数据表示和操作符覆盖更加高效。
构建过程非常直接,只需将一个张量列表传递给构造函数即可。
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested_tensor([a, b])
>>> nt
nested_tensor([
tensor([0, 1, 2]),
tensor([3, 4, 5, 6, 7])
])
数据类型和设备可以通过通常的关键词参数进行选择。
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
tensor([0., 1., 2.], device='cuda:0'),
tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
操作符覆盖¶
我们目前正在根据特定的机器学习使用案例,推进扩大操作符覆盖范围的工作。
因此,操作符的覆盖范围目前非常有限,仅支持 unbind。
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
tensor([0., 1., 2.], device='cuda:0'),
tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
>>> nt.unbind()
[tensor([0., 1., 2.], device='cuda:0'), tensor([3., 4., 5., 6., 7.], device='cuda:0')]