张量视图¶
PyTorch 允许一个张量成为现有张量的 View。视图张量与其基础张量共享相同的底层数据。支持 View 避免了显式的数据复制,从而允许我们进行快速且内存高效的重塑、切片和元素级操作。
例如,要查看现有的张量t,你可以调用t.view(...)。
>>> t = torch.rand(4, 4)
>>> b = t.view(2, 8)
>>> t.storage().data_ptr() == b.storage().data_ptr() # `t` and `b` share the same underlying data.
True
# Modifying view tensor changes base tensor as well.
>>> b[0][0] = 3.14
>>> t[0][0]
tensor(3.14)
由于视图与其基础张量共享底层数据,因此如果你编辑视图中的数据,基础张量中的数据也会相应反映出来。
通常,PyTorch 操作会返回一个新的张量作为输出,例如 add().
但在视图操作的情况下,输出是输入张量的视图,以避免不必要的数据复制。
创建视图时不会发生数据移动,视图张量只是改变了对相同数据的解释方式。
对连续张量进行视图操作可能会生成非连续张量。
用户应注意连续性可能对性能产生隐式影响。
transpose() 是一个常见示例。
>>> base = torch.tensor([[0, 1],[2, 3]])
>>> base.is_contiguous()
True
>>> t = base.transpose(0, 1) # `t` is a view of `base`. No data movement happened here.
# View tensors might be non-contiguous.
>>> t.is_contiguous()
False
# To get a contiguous tensor, call `.contiguous()` to enforce
# copying data when `t` is not contiguous.
>>> c = t.contiguous()
作为参考,以下是 PyTorch 中完整的视图操作列表:
基本切片和索引操作,例如
tensor[0, 2:, 1:7:2]返回基tensor的视图,参见下面的说明。view_as_real()split_with_sizes()indices()(仅限稀疏张量)values()(仅限稀疏张量)
注意
当通过索引访问张量的内容时,PyTorch 遵循 Numpy 的行为,即基本索引返回视图,而高级索引返回副本。 无论是基本索引还是高级索引的赋值都是原地进行的。更多示例请参见 Numpy 索引文档。
值得一提的还有一些具有特殊行为的操作:
reshape(),reshape_as()和flatten()可以返回视图或新张量,用户代码不应依赖于它是视图还是新张量。contiguous()返回其自身如果输入张量已经是连续的,否则它通过复制数据返回一个新的连续张量。
有关PyTorch内部实现的更详细说明,请参阅ezyang关于PyTorch内部的文章。