广播语义¶
许多 PyTorch 操作都支持 NumPy 的广播语义。 有关详细信息,请参阅 https://numpy.org/doc/stable/user/basics.broadcasting.html。
简而言之,如果 PyTorch 操作支持 broadcast,那么它的 Tensor 参数可以是 自动扩展为相等的大小(不制作数据副本)。
一般语义¶
如果满足以下规则,则两个张量是“可广播的”:
每个张量至少有一个维度。
迭代维度大小时,从尾随维度开始, 维度大小必须相等,其中一个为 1,或其中一个 不存在。
例如:
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)
>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension
# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3
如果两个张量 , 是 “可广播的” ,则生成的张量大小
计算如下:x
y
如果 的维数 和 不相等,则在 1 前面加上 添加到维度较少的张量的维度中,使它们的长度相等。
x
y
然后,对于每个维度大小,生成的维度大小是该维度的大小和沿该维度的大小的最大值。
x
y
例如:
# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty( 3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])
# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
就地语义¶
一个复杂性是就地操作不允许就地张量更改形状 作为广播的结果。
例如:
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])
# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.
向后兼容性¶
早期版本的 PyTorch 允许在具有不同形状的张量上执行某些逐点函数, 只要每个张量中的元素数量相等。然后进行逐点操作 out 通过将每个张量视为 1 维来查看。PyTorch 现在支持广播和“一维” pointwise 行为被视为已弃用,并且在张量 不可广播,但具有相同数量的元素。
请注意,在以下情况下,广播的引入可能会导致向后不兼容的更改 两个张量的形状不同,但可广播且元素数相同。 例如:
>>> torch.add(torch.ones(4,1), torch.randn(4))
之前会生成一个大小为: torch 的 Tensor。size([4,1]),但现在会生成一个 size为 torch 的 Tensor。Size([4,4]) 的 URL 中。 为了帮助识别代码中可能存在广播引入的向后不兼容的情况, 您可以将 torch.utils.backcompat.broadcast_warning.enabled 设置为 True,这将生成 Python 警告 在这种情况下。
例如:
>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.