目录

命名张量算子覆盖率

请先阅读命名张量,了解命名张量的介绍。

本文档是名称推理的参考,该过程定义了如何 命名张量:

  1. 使用名称提供额外的自动运行时正确性检查

  2. 将名称从输入张量传播到输出张量

以下是命名张量支持的所有作的列表 及其关联的名称推理规则。

如果您在此处没有看到作,但它对您的使用案例有所帮助,请搜索是否已提交问题,如果没有,请提交一个问题。

警告

命名的张量 API 是实验性的,可能会发生变化。

支持的作

应用程序接口

名称推理规则

Tensor.abs(),torch.abs()

保留输入名称

Tensor.abs_()

保留输入名称

Tensor.acos(),torch.acos()

保留输入名称

Tensor.acos_()

保留输入名称

Tensor.add(),torch.add()

统一输入中的名称

Tensor.add_()

统一输入中的名称

Tensor.addmm(),torch.addmm()

Contracts away 暗淡

Tensor.addmm_()

Contracts away 暗淡

Tensor.addmv(),torch.addmv()

Contracts away 暗淡

Tensor.addmv_()

Contracts away 暗淡

Tensor.align_as()

查看文档

Tensor.align_to()

查看文档

Tensor.all(),torch.all()

没有

Tensor.any(),torch.any()

没有

Tensor.asin(),torch.asin()

保留输入名称

Tensor.asin_()

保留输入名称

Tensor.atan(),torch.atan()

保留输入名称

Tensor.atan2(),torch.atan2()

统一输入中的名称

Tensor.atan2_()

统一输入中的名称

Tensor.atan_()

保留输入名称

Tensor.bernoulli(),torch.bernoulli()

保留输入名称

Tensor.bernoulli_()

没有

Tensor.bfloat16()

保留输入名称

Tensor.bitwise_not(),torch.bitwise_not()

保留输入名称

Tensor.bitwise_not_()

没有

Tensor.bmm(),torch.bmm()

Contracts away 暗淡

Tensor.bool()

保留输入名称

Tensor.byte()

保留输入名称

torch.cat()

统一输入中的名称

Tensor.cauchy_()

没有

Tensor.ceil(),torch.ceil()

保留输入名称

Tensor.ceil_()

没有

Tensor.char()

保留输入名称

Tensor.chunk(),torch.chunk()

保留输入名称

Tensor.clamp(),torch.clamp()

保留输入名称

Tensor.clamp_()

没有

Tensor.copy_()

out 函数和就地变体

Tensor.cos(),torch.cos()

保留输入名称

Tensor.cos_()

没有

Tensor.cosh(),torch.cosh()

保留输入名称

Tensor.cosh_()

没有

Tensor.acosh(),torch.acosh()

保留输入名称

Tensor.acosh_()

没有

Tensor.cpu()

保留输入名称

Tensor.cuda()

保留输入名称

Tensor.cumprod(),torch.cumprod()

保留输入名称

Tensor.cumsum(),torch.cumsum()

保留输入名称

Tensor.data_ptr()

没有

Tensor.deg2rad(),torch.deg2rad()

保留输入名称

Tensor.deg2rad_()

没有

Tensor.detach(),torch.detach()

保留输入名称

Tensor.detach_()

没有

Tensor.device,torch.device()

没有

Tensor.digamma(),torch.digamma()

保留输入名称

Tensor.digamma_()

没有

Tensor.dim()

没有

Tensor.div(),torch.div()

统一输入中的名称

Tensor.div_()

统一输入中的名称

Tensor.dot(),torch.dot()

没有

Tensor.double()

保留输入名称

Tensor.element_size()

没有

torch.empty()

工厂功能

torch.empty_like()

工厂功能

Tensor.eq(),torch.eq()

统一输入中的名称

Tensor.erf(),torch.erf()

保留输入名称

Tensor.erf_()

没有

Tensor.erfc(),torch.erfc()

保留输入名称

Tensor.erfc_()

没有

Tensor.erfinv(),torch.erfinv()

保留输入名称

Tensor.erfinv_()

没有

Tensor.exp(),torch.exp()

保留输入名称

Tensor.exp_()

没有

Tensor.expand()

保留输入名称

Tensor.expm1(),torch.expm1()

保留输入名称

Tensor.expm1_()

没有

Tensor.exponential_()

没有

Tensor.fill_()

没有

Tensor.flatten(),torch.flatten()

查看文档

Tensor.float()

保留输入名称

Tensor.floor(),torch.floor()

保留输入名称

Tensor.floor_()

没有

Tensor.frac(),torch.frac()

保留输入名称

Tensor.frac_()

没有

Tensor.ge(),torch.ge()

统一输入中的名称

Tensor.get_device(),torch.get_device()

没有

Tensor.grad

没有

Tensor.gt(),torch.gt()

统一输入中的名称

Tensor.half()

保留输入名称

Tensor.has_names()

查看文档

Tensor.index_fill(),torch.index_fill()

保留输入名称

Tensor.index_fill_()

没有

Tensor.int()

保留输入名称

Tensor.is_contiguous()

没有

Tensor.is_cuda

没有

Tensor.is_floating_point(),torch.is_floating_point()

没有

Tensor.is_leaf

没有

Tensor.is_pinned()

没有

Tensor.is_shared()

没有

Tensor.is_signed(),torch.is_signed()

没有

Tensor.is_sparse

没有

Tensor.is_sparse_csr

没有

torch.is_tensor()

没有

Tensor.item()

没有

Tensor.kthvalue(),torch.kthvalue()

删除维度

Tensor.le(),torch.le()

统一输入中的名称

Tensor.log(),torch.log()

保留输入名称

Tensor.log10(),torch.log10()

保留输入名称

Tensor.log10_()

没有

Tensor.log1p(),torch.log1p()

保留输入名称

Tensor.log1p_()

没有

Tensor.log2(),torch.log2()

保留输入名称

Tensor.log2_()

没有

Tensor.log_()

没有

Tensor.log_normal_()

没有

Tensor.logical_not(),torch.logical_not()

保留输入名称

Tensor.logical_not_()

没有

Tensor.logsumexp(),torch.logsumexp()

删除维度

Tensor.long()

保留输入名称

Tensor.lt(),torch.lt()

统一输入中的名称

torch.manual_seed()

没有

Tensor.masked_fill(),torch.masked_fill()

保留输入名称

Tensor.masked_fill_()

没有

Tensor.masked_select(),torch.masked_select()

将蒙版与输入对齐,然后unifies_names_from_input_tensors

Tensor.matmul(),torch.matmul()

Contracts away 暗淡

Tensor.mean(),torch.mean()

删除维度

Tensor.median(),torch.median()

删除维度

Tensor.nanmedian(),torch.nanmedian()

删除维度

Tensor.mm(),torch.mm()

Contracts away 暗淡

Tensor.mode(),torch.mode()

删除维度

Tensor.mul(),torch.mul()

统一输入中的名称

Tensor.mul_()

统一输入中的名称

Tensor.mv(),torch.mv()

Contracts away 暗淡

Tensor.names

查看文档

Tensor.narrow(),torch.narrow()

保留输入名称

Tensor.ndim

没有

Tensor.ndimension()

没有

Tensor.ne(),torch.ne()

统一输入中的名称

Tensor.neg(),torch.neg()

保留输入名称

Tensor.neg_()

没有

torch.normal()

保留输入名称

Tensor.normal_()

没有

Tensor.numel(),torch.numel()

没有

torch.ones()

工厂功能

Tensor.pow(),torch.pow()

统一输入中的名称

Tensor.pow_()

没有

Tensor.prod(),torch.prod()

删除维度

Tensor.rad2deg(),torch.rad2deg()

保留输入名称

Tensor.rad2deg_()

没有

torch.rand()

工厂功能

torch.rand()

工厂功能

torch.randn()

工厂功能

torch.randn()

工厂功能

Tensor.random_()

没有

Tensor.reciprocal(),torch.reciprocal()

保留输入名称

Tensor.reciprocal_()

没有

Tensor.refine_names()

查看文档

Tensor.register_hook()

没有

Tensor.rename()

查看文档

Tensor.rename_()

查看文档

Tensor.requires_grad

没有

Tensor.requires_grad_()

没有

Tensor.resize_()

仅允许不改变形状的调整大小

Tensor.resize_as_()

仅允许不改变形状的调整大小

Tensor.round(),torch.round()

保留输入名称

Tensor.round_()

没有

Tensor.rsqrt(),torch.rsqrt()

保留输入名称

Tensor.rsqrt_()

没有

Tensor.select(),torch.select()

删除维度

Tensor.short()

保留输入名称

Tensor.sigmoid(),torch.sigmoid()

保留输入名称

Tensor.sigmoid_()

没有

Tensor.sign(),torch.sign()

保留输入名称

Tensor.sign_()

没有

Tensor.sgn(),torch.sgn()

保留输入名称

Tensor.sgn_()

没有

Tensor.sin(),torch.sin()

保留输入名称

Tensor.sin_()

没有

Tensor.sinh(),torch.sinh()

保留输入名称

Tensor.sinh_()

没有

Tensor.asinh(),torch.asinh()

保留输入名称

Tensor.asinh_()

没有

Tensor.size()

没有

Tensor.split(),torch.split()

保留输入名称

Tensor.sqrt(),torch.sqrt()

保留输入名称

Tensor.sqrt_()

没有

Tensor.squeeze(),torch.squeeze()

删除维度

Tensor.std(),torch.std()

删除维度

torch.std_mean()

删除维度

Tensor.stride()

没有

Tensor.sub(),torch.sub()

统一输入中的名称

Tensor.sub_()

统一输入中的名称

Tensor.sum(),torch.sum()

删除维度

Tensor.tan(),torch.tan()

保留输入名称

Tensor.tan_()

没有

Tensor.tanh(),torch.tanh()

保留输入名称

Tensor.tanh_()

没有

Tensor.atanh(),torch.atanh()

保留输入名称

Tensor.atanh_()

没有

torch.tensor()

工厂功能

Tensor.to()

保留输入名称

Tensor.topk(),torch.topk()

删除维度

Tensor.transpose(),torch.transpose()

排列维度

Tensor.trunc(),torch.trunc()

保留输入名称

Tensor.trunc_()

没有

Tensor.type()

没有

Tensor.type_as()

保留输入名称

Tensor.unbind(),torch.unbind()

删除维度

Tensor.unflatten()

查看文档

Tensor.uniform_()

没有

Tensor.var(),torch.var()

删除维度

torch.var_mean()

删除维度

Tensor.zero_()

没有

torch.zeros()

工厂功能

保留输入名称

所有逐点一元函数以及其他一些一元函数都遵循此规则。

  • 检查名称:无

  • Propagate names:输入张量的名称将传播到输出。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')

删除维度

所有 reduction作,如sum()通过减少来删除尺寸 超过所需的尺寸。其他作,如select()squeeze()删除维度。

只要可以将整数维度索引传递给运算符,也可以将 维度名称。采用维度索引列表的函数也可以采用 维度名称列表。

  • 检查名称:如果 或 作为名称列表传入, 检查这些名称是否存在于 中。dimdimsself

  • Propagate names:如果输入张量的维度由输出张量指定或不存在,则相应的名称 的 中未显示在 中。dimdimsoutput.names

>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')

>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')

# Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')

统一输入中的名称

所有二进制算术运算都遵循此规则。仍然广播的作 从右侧进行位置广播,以保持与 unnamed 的兼容性 张。要按名称执行显式广播,请使用Tensor.align_as().

  • 检查名称:所有名称必须从右侧开始位置匹配。即,对于 IN 中的所有内容,都必须为 true。tensor + othermatch(tensor.names[i], other.names[i])i(-min(tensor.dim(), other.dim()) + 1, -1]

  • 检查名称:此外,所有命名尺寸必须从右侧对齐。 在匹配过程中,如果我们将命名维度与 unnamed dimension 匹配,则不得出现在具有 unnamed 维度的张量中。ANoneA

  • Propagate names:将两个张量从右侧到的名称对统一到 生成输出名称。

例如

# tensor: Tensor[   N, None]
# other:  Tensor[None,    C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')

检查姓名:

  • match(tensor.names[-1], other.names[-1])True

  • match(tensor.names[-2], tensor.names[-2])True

  • 因为我们在Nonetensor跟 检查以确保 中不存在'C''C'tensor(它没有)。

  • 检查以确保 中不存在 (it does not)。'N'other

最后,使用[unify('N', None), unify(None, 'C')] = ['N', 'C']

更多示例:

# Dimensions don't match from the right:
# tensor: Tensor[N, C]
# other:  Tensor[   N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.

# Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
# tensor: Tensor[N, None]
# other:  Tensor[      N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.

注意

在最后两个示例中,都可以按名称对齐张量 ,然后执行加法。用Tensor.align_as()对齐 Tensors by name 或Tensor.align_to()将张量与自定义 维度排序。

排列维度

某些作(如Tensor.t(),则排列维度的顺序。维度名称 附加到各个维度,因此它们也会被置换。

如果运算符采用 positional index ,它也能够采用维度 name 设置为 .dimdim

  • 检查名称:如果作为名称传递,请检查它是否存在于张量中。dim

  • 传播名称:以与维度相同的方式排列维度名称 正在排列。

>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')

Contracts away 暗淡

矩阵乘法函数遵循 this 的一些变体。让我们来看看torch.mm()首先,然后通用化批量矩阵乘法的规则。

为:torch.mm(tensor, other)

  • 检查名称:无

  • 传播名称:结果名称为 。(tensor.names[-2], other.names[-1])

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')

从本质上讲,矩阵乘法在两个维度上执行点积, 折叠它们。当两个张量进行矩阵相乘时,收缩维度 disappear 的 intent 值,并且不会显示在 output tensor 中。

torch.mv(),torch.dot()以类似的方式工作:名称推理不会 检查 Input names 并删除 .product 中涉及的维度:

>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)

现在,我们来看看 。假设 和 。torch.matmul(tensor, other)tensor.dim() >= 2other.dim() >= 2

  • 检查名称:检查输入的批量维度是否对齐且可广播。 请参阅 统一输入中的名称 ,了解输入对齐的含义。

  • Propagate names:通过统一批处理维度并删除 合同规定的尺寸: 。unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])

例子:

# Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
# 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')

最后,还有许多 matmul 函数的融合版本。即addaddmm()addmv().这些被视为 i.i.mm()和 的名称推理add().

工厂功能

Factory 函数现在采用一个关联名称的新参数 与每个维度。names

>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
        [0., 0., 0.]], names=('N', 'C'))

out 函数和就地变体

指定为张量的张量具有以下行为:out=

  • 如果它没有命名维度,则从作中计算的名称 传播到它。

  • 如果它有任何命名维度,则从作中计算的名称 必须与现有名称完全相等。否则,作会出错。

所有就地方法都会修改输入,使其名称等于计算的名称 从名称推断。例如:

>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)

>>> x += y
>>> x.names
('N', 'C')

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源