目录

torch.nested

介绍

警告

嵌套张量的 PyTorch API 处于原型阶段,将在不久的将来发生变化。

警告

torch。NestedTensor 目前不支持 autograd。它需要在上下文中使用 的 torch.inference_mode() 中。

NestedTensor 允许用户将 Tensor 列表打包到一个高效的数据结构中。

对输入 Tensor 的唯一约束是它们的维度必须匹配。

这可实现更高效的元数据表示和运算符覆盖率。

构造很简单,涉及将 Tensor 列表传递给构造函数。

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested_tensor([a, b])
>>> nt
nested_tensor([
  tensor([0, 1, 2]),
    tensor([3, 4, 5, 6, 7])
    ])

数据类型和设备可以通过通常的关键字参数进行选择

>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
  tensor([0., 1., 2.], device='cuda:0'),
  tensor([3., 4., 5., 6., 7.], device='cuda:0')
])

操作员覆盖范围

我们目前正在以特定 ML 使用案例为指导,大规模扩展运维覆盖范围。

因此,运算符覆盖范围目前非常有限,仅支持 unbind。

>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
  tensor([0., 1., 2.], device='cuda:0'),
  tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
>>> nt.unbind()
[tensor([0., 1., 2.], device='cuda:0'), tensor([3., 4., 5., 6., 7.], device='cuda:0')]

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源