目录

torch.masked

介绍

动机

警告

PyTorch 中掩码张量的 API 处于原型阶段,未来可能会或不会发生变化。

MaskedTensor 作为 torch.Tensor 的扩展,为用户提供了以下能力:

  • 使用任何掩码语义(例如,可变长度张量、nan* 运算符等)

  • 区分 0 和 NaN 梯度

  • 各种稀疏应用(请参见下面的教程)

“指定”和“未指定”在PyTorch中有着悠久的历史,但从未有正式的语义,更不用说一致性了;实际上,MaskedTensor 的诞生源于普通 torch.Tensor 类无法妥善解决的一系列问题。因此,MaskedTensor 的主要目标是成为 PyTorch 中所述“指定”和“未指定”值的权威来源,使其成为首要考虑因素而非事后补救。 反过来,这应该进一步释放 稀疏性 的潜力, 实现更安全、更一致的操作,并为用户和开发者提供更顺畅、更直观的体验。

什么是 MaskedTensor?

一个 MaskedTensor 是一个张量的子类,它由以下两部分组成:1) 输入(数据),以及 2) 一个掩码。该掩码告诉我们输入中的哪些条目应该被包含或忽略。

举个例子,假设我们想要屏蔽所有等于 0(用灰色表示)的值,并取最大值:

_images/tensor_comparison.jpg

上方是普通张量示例,而下方是 MaskedTensor,其中所有的 0 都被屏蔽掉了。 这显然会根据我们是否拥有掩码产生不同的结果,但这种灵活的结构允许用户在计算过程中系统地忽略任何他们想要忽略的元素。

我们已经编写了若干教程,以帮助用户快速上手,例如:

支持的操作符

一元运算符

一元运算符是仅包含单个输入的运算符。 将其应用于 MaskedTensors 相对简单:如果在给定索引处数据被屏蔽,则应用该运算符;否则,我们继续屏蔽这些数据。

可用的单目运算符有:

abs

计算 input 中每个元素的绝对值。

absolute

别名为 torch.abs()

acos

计算 input 中每个元素的反余弦值。

arccos

别名为 torch.acos()

acosh

返回一个新的张量,其中包含input元素的反双曲余弦值。

arccosh

别名为 torch.acosh()

angle

计算给定input张量的元素角度(以弧度为单位)。

asin

返回一个新的张量,其中包含input元素的反正弦值。

arcsin

别名为 torch.asin()

asinh

返回一个新的张量,其中包含input元素的反双曲正弦值。

arcsinh

别名为 torch.asinh()

atan

返回一个新的张量,其中包含input元素的反正切值。

arctan

别名为 torch.atan()

atanh

返回一个新的张量,其中包含input元素的反双曲正切值。

arctanh

别名为 torch.atanh()

bitwise_not

计算给定输入张量的按位NOT。

ceil

返回一个新的张量,其中包含input元素的ceil值,即每个元素的大于或等于其的最小整数。

clamp

将所有元素在input限制在范围[ minmax ]之间。

clip

别名为 torch.clamp()

conj_physical

计算给定input张量的逐元素共轭。

cos

返回一个新的张量,其中包含input元素的余弦值。

cosh

返回一个新的张量,其中包含input元素的双曲余弦值。

deg2rad

返回一个新的张量,其中input的每个元素都从角度转换为弧度。

digamma

别名为 torch.special.digamma()

erf

别名为 torch.special.erf()

erfc

别名为 torch.special.erfc()

erfinv

别名为 torch.special.erfinv()

exp

返回一个新的张量,其中包含输入张量 input 元素的指数。

exp2

别名为 torch.special.exp2()

expm1

别名为 torch.special.expm1()

fix

别名为 torch.trunc()

floor

返回一个新的张量,其中包含input元素的下取整值,即小于或等于每个元素的最大整数。

frac

计算input中每个元素的小数部分。

lgamma

计算伽玛函数绝对值的自然对数在 input 上。

log

返回一个新的张量,其中包含input元素的自然对数。

log10

返回一个新的张量,其中包含input元素的以10为底的对数。

log1p

返回一个新的张量,其中包含 (1 + input) 的自然对数。

log2

返回一个新的张量,其中包含input元素的以2为底的对数。

logit

别名为 torch.special.logit()

i0

别名为 torch.special.i0()

isnan

返回一个新的张量,其中的布尔元素表示input中的每个元素是否为NaN。

nan_to_num

NaN、正无穷和负无穷值在 input 中替换为由 nanposinfneginf 分别指定的值。

neg

返回一个新的张量,其中包含input元素的负值。

negative

别名为 torch.neg()

positive

返回 input

pow

input 中的每个元素使用 exponent 的幂,并返回一个包含结果的张量。

rad2deg

返回一个新的张量,其中input的每个元素都从弧度转换为角度。

reciprocal

返回一个新的张量,其元素是input中元素的倒数

round

input的元素四舍五入到最接近的整数。

rsqrt

返回一个新的张量,其中每个元素是input中对应元素平方根的倒数。

sigmoid

别名为 torch.special.expit()

sign

返回一个新的张量,其中元素的符号与 input 中元素的符号相同。

sgn

此函数是torch.sign()对复数张量的扩展。

signbit

测试input的每个元素是否设置了符号位。

sin

返回一个新的张量,其中包含input元素的正弦值。

sinc

别名为 torch.special.sinc()

sinh

返回一个新的张量,其中包含input元素的双曲正弦值。

sqrt

返回一个新的张量,其中包含input元素的平方根。

square

返回一个新的张量,其中包含input元素的平方。

tan

返回一个新的张量,其中包含input元素的正切值。

tanh

返回一个新的张量,其中包含input元素的双曲正切值。

trunc

返回一个新的张量,其中包含元素 input 的截断整数值。

可用的原地一元运算符包括以上所有除外

angle

计算给定input张量的元素角度(以弧度为单位)。

positive

返回 input

signbit

测试input的每个元素是否设置了符号位。

isnan

返回一个新的张量,其中的布尔元素表示input中的每个元素是否为NaN。

二元运算符

如您在教程中所见,MaskedTensor 也实现了二进制操作,但需要注意的是两个 MaskedTensors 中的掩码必须匹配,否则将引发错误。如错误信息所述,如果您需要特定运算符的支持或提出了它们应该如何行为的语义,请在 GitHub 上提交问题。目前,我们决定采用最保守的实现方式,以确保用户清楚了解发生了什么,并且在使用掩码语义时是经过深思熟虑的。

可用的二元运算符有:

add

other 乘以 alpha 后加到 input 上。

atan2

逐元素反正切值计算inputi/otheri\text{input}_{i} / \text{other}_{i},并考虑象限。

arctan2

别名为 torch.atan2()

bitwise_and

计算 inputother 的按位与。

bitwise_or

计算 inputother 的按位或。

bitwise_xor

计算 inputother 的按位异或。

bitwise_left_shift

计算 input 向左算术移位 other 位。

bitwise_right_shift

计算 input 向右算术移位 other 位。

div

将输入 input 的每个元素除以 other 中对应的元素。

divide

别名为 torch.div()

floor_divide

fmod

逐元素应用C++的 std::fmod 函数。

logaddexp

输入的指数和的对数。

logaddexp2

以2为底的输入指数和的对数。

mul

input 乘以 other

multiply

别名为 torch.mul()

nextafter

返回input之后的下一个浮点值,逐元素地朝向other

remainder

计算 Python的取模运算 逐元素。

sub

input 中减去 other,并乘以 alpha

subtract

别名为 torch.sub()

true_divide

别名为 torch.div(),具有 rounding_mode=None

eq

计算元素级别的相等性

ne

计算每个inputother\text{input} \neq \text{other}元素。

le

计算每个inputother\text{input} \leq \text{other}元素。

ge

计算每个inputother\text{input} \geq \text{other}元素。

greater

别名为 torch.gt()

greater_equal

别名为 torch.ge()

gt

计算每个input>other\text{input} > \text{other}元素。

less_equal

别名为 torch.le()

lt

计算每个input<other\text{input} < \text{other}元素。

less

别名为 torch.lt()

maximum

计算 inputother 的逐元素最大值。

minimum

计算 inputother 的逐元素最小值。

fmax

计算 inputother 的逐元素最大值。

fmin

计算 inputother 的逐元素最小值。

not_equal

别名为 torch.ne()

可用的原地二元运算符包括以上所有除了

logaddexp

输入的指数和的对数。

logaddexp2

以2为底的输入指数和的对数。

equal

True 如果两个张量具有相同的大小和元素,则为 False 否则。

fmin

计算 inputother 的逐元素最小值。

minimum

计算 inputother 的逐元素最小值。

fmax

计算 inputother 的逐元素最大值。

归约

以下归约操作可用(支持自动求导)。更多信息,请参阅 概述教程 其中包含一些归约操作的示例,而 高级语义教程 则深入讨论了我们如何决定某些归约语义。

sum

返回input张量中所有元素的总和。

mean

amin

返回给定维度的input张量每个切片的最小值dim

amax

返回给定维度 dim 中每个切片的 input 张量的最大值。

argmin

返回展平张量或沿某个维度的最小值的索引。

argmax

返回所有元素在input张量中的最大值的索引。

prod

返回input张量中所有元素的乘积。

all

测试input中的所有元素是否评估为True

norm

返回给定张量的矩阵范数或向量范数。

var

计算在由 dim 指定的维度上的方差。

std

在指定的维度dim上计算标准差。

查看并选择函数

我们还包含了许多查看和选择函数;直观地说,这些操作符将同时应用于数据和掩码,然后将结果包装在MaskedTensor中。举个快速的例子,请考虑select()

>>> data = torch.arange(12, dtype=torch.float).reshape(3, 4)
>>> data
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
>>> mask = torch.tensor([[True, False, False, True], [False, True, False, False], [True, True, True, True]])
>>> mt = masked_tensor(data, mask)
>>> data.select(0, 1)
tensor([4., 5., 6., 7.])
>>> mask.select(0, 1)
tensor([False,  True, False, False])
>>> mt.select(0, 1)
MaskedTensor(
  [      --,   5.0000,       --,       --]
)

目前支持以下操作:

atleast_1d

返回每个输入张量的零维视图的一维表示。

broadcast_tensors

根据广播语义广播给定的张量。

broadcast_to

input 广播到形状 shape

cat

在给定的维度中,将给定的张量序列连接起来tensors

chunk

尝试将张量分割成指定数量的块。

column_stack

通过水平堆叠tensors中的张量来创建一个新的张量。

dsplit

input,一个具有三个或更多维度的张量,根据 indices_or_sections 按深度分割成多个张量。

flatten

input 展平,将其重塑为一维张量。

hsplit

input,一个具有一个或多个维度的张量,根据 indices_or_sections 水平分割成多个张量。

hstack

按序列水平(列方向)堆叠张量。

kron

计算克罗内克积,记作\otimesinputother的。

meshgrid

创建由1D输入在attr:张量指定的坐标网格。

narrow

返回一个新的张量,它是input张量的缩小版本。

nn.functional.unfold

从批量输入张量中提取滑动局部块。

ravel

返回一个连续的展平张量。

select

沿选定维度在给定索引处切片input张量。

split

将张量分割成块。

stack

沿着新维度连接一系列张量。

t

期望 input 为 <= 2-D 张量,并转置维度 0 和 1。

transpose

返回一个张量,它是input的转置版本。

vsplit

input, 一个具有两个或更多维度的张量,根据 indices_or_sections 垂直分割成多个张量。

vstack

按垂直顺序(逐行)堆叠张量。

Tensor.expand

返回一个新视图的self张量,将单例维度扩展到更大的尺寸。

Tensor.expand_as

将此张量扩展为与 other 相同的大小。

Tensor.reshape

返回一个与self具有相同数据和元素数量但具有指定形状的张量。

Tensor.reshape_as

返回与此张量形状相同的other

Tensor.unfold

返回原始张量的视图,该视图包含维度 dimension 中从 self 张量中获取的所有大小为 size 的切片。

Tensor.view

返回一个与 self 张量具有相同数据但不同 shape 的新张量。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源