目录

torch.testing

警告

该模块处于原型(PROTOTYPE)状态。新功能仍在不断增加,未来 PyTorch 版本中可用的功能可能会发生变化。我们正在积极寻求关于 UI/UX 改进或缺失功能的反馈。

torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=False, check_is_coalesced=True, msg=None)[source]

断言 actualexpected 是接近的。

如果 actualexpected 是步长、非量化、实数值且有限的,那么它们被认为是 接近的,当

actualexpectedatol+rtolexpected\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert

并且它们具有相同的 device(如果 check_deviceTrue),相同的 dtype(如果 check_dtypeTrue),以及相同的步长(如果 check_strideTrue)。非有限值 (-infinf) 仅在它们相等时才被视为接近。NaN 仅在 equal_nanTrue 时才被视为彼此相等。

如果 actualexpected 是稀疏的(具有 COO 或 CSR 布局),则会分别检查它们的步长成员。索引,即 COO 的 indices 或 CSR 布局的 crow_indicescol_indices,始终检查是否相等,而值则根据上述定义检查是否接近。 只有当两个稀疏 COO 张量都为合并或未合并时(如果 check_is_coalesced 等于 True),才认为它们是接近的。

如果 actualexpected 被量化,当它们具有相同的 qscheme() 并且根据上述定义 dequantize() 的结果接近时,它们被认为是接近的。

actualexpected 可以是 Tensor 的或任何可以从中 torch.Tensor 构造的张量或标量类似的对象,使用 torch.as_tensor()。除了 Python 标量之外,输入类型 必须直接相关。此外,actualexpected 可以是 SequenceMapping,在这种情况下,如果它们的结构匹配,并且所有元素根据上述定义被认为是接近的,则认为它们是接近的。

注意

Python 标量是类型关系要求的一个例外,因为它们的 type(),即 intfloatcomplex,等同于张量类似的 dtype。因此, 不同类型的 Python 标量可以被检查,但需要将 check_dtype 设置为 False

Parameters
  • 实际值 (Any) – 实际输入。

  • 预期 (Any) – 预期输入。

  • allow_subclasses (bool) – 如果为 True(默认值),则除了 Python 标量外,允许直接相关类型的输入。否则需要类型完全一致。

  • rtol (Optional[float]) – 相对容差。如果指定了 atol,也必须指定。如果省略, 将根据 dtype 选择默认值,如下表所示。

  • atol (Optional[float]) – 绝对容差。如果指定了 rtol,也必须指定。如果省略, 将根据 dtype 选择默认值,如下表所示。

  • equal_nan (Union[bool, str]) – 如果 True,两个 NaN 值将被视为相等。

  • check_device (bool) – 如果为 True(默认值),则断言相应的张量位于同一 device 上。如果禁用此检查,则在比较之前,不同 device 上的张量会被移动到 CPU 上。

  • check_dtype (布尔值) – 如果 True(默认),断言对应的张量具有相同的 数据类型。如果禁用此检查,具有不同 数据类型 的张量将被提升为一个通用的 数据类型(根据 torch.promote_types())后再进行比较。

  • check_stride (bool) – 如果为 True 且相应的张量是步幅的,则断言它们具有相同的步长。

  • check_is_coalesced (bool) – 如果 True(默认)且对应的张量是稀疏COO格式,检查 actualexpected 是否都是合并的或未合并的。如果禁用此检查,则在比较之前将张量 coalesce()

  • msg (可选[Union[str, Callable[[Tensor, Tensor, Diagnostics], str]]]) – 如果对应张量的值不匹配时使用的可选错误信息。可以作为可调用对象传递,此时它将被调用并传入不匹配的张量以及关于不匹配的诊断信息的命名空间。有关详细信息,请参见下文。

Raises
  • ValueError – 如果无法从输入构建 torch.Tensor

  • ValueError – 如果仅指定了 rtolatol

  • AssertionError – 如果对应的输入不是Python标量且没有直接关系。

  • AssertionError – 如果 allow_subclassesFalse,但对应的输入不是Python标量并且 具有不同的类型。

  • AssertionError – 如果输入是 Sequence,但它们的长度不匹配。

  • AssertionError – 如果输入是 Mapping,但它们的键集合不匹配。

  • AssertionError – 如果对应的张量没有相同的shape

  • AssertionError – 如果对应的张量没有相同的layout

  • AssertionError – 如果对应的张量被量化,但具有不同的 qscheme()

  • 断言错误 – 如果 check_deviceTrue,但相应的张量不在同一个 设备 上。

  • 断言错误 – 如果 check_dtypeTrue,但相应的张量不具有相同的 dtype

  • AssertionError – 如果 check_strideTrue,但对应的带步幅张量的步幅不相同。

  • AssertionError – 如果 check_is_coalescedTrue,但对应的稀疏 COO 张量不是都为合并或都为非合并。

  • AssertionError – 如果对应张量的值根据上述定义不接近,则会引发此错误。

以下表格显示了不同dtype的默认rtolatol。如果出现不匹配的dtype,将使用两者的最大容差。

dtype

rtol

atol

float16

1e-3

1e-5

bfloat16

1.6e-2

1e-5

float32

1.3e-6

1e-5

float64

1e-7

1e-7

complex32

1e-3

1e-5

complex64

1.3e-6

1e-5

complex128

1e-7

1e-7

其他

0.0

0.0

将传递给 msg 的诊断命名空间,如果它是可调用的,具有以下属性:

  • number_of_elements (int): 每个被比较张量中的元素数量。

  • total_mismatches (int): 总不匹配数。

  • max_abs_diff (Union[int, float]): 输入的最大绝对差异。

  • max_abs_diff_idx (Union[int, Tuple[int, …]]): 最大绝对差异的索引。

  • atol (float): 允许的绝对容差。

  • max_rel_diff (Union[int, float]): 输入的最大相对差异。

  • max_rel_diff_idx (Union[int, Tuple[int, …]]): 最大相对差异的索引。

  • rtol (float): 允许的相对容差。

对于 max_abs_diffmax_rel_diff,类型取决于 dtype 的输入。

注意

assert_close() 配置高度灵活,默认设置严格。鼓励用户 partial() 它以适应其使用场景。例如,如果需要进行等值检查,可以 定义一个 assert_equal,默认情况下对每个 dtype 使用零容忍度:

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!

Absolute difference: 8.999999703829253e-10
Relative difference: 8.999999583666371

示例

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, type equality is required if
allow_subclasses=False, but got <class 'torch.nn.parameter.Parameter'> and
<class 'torch.Tensor'> instead.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, input types need to be directly
related, but got <class 'numpy.ndarray'> and <class 'torch.Tensor'> instead.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!

Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default mismatch message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # The error message can also created at runtime by passing a callable.
>>> def custom_msg(actual, expected, diagnostics):
...     ratio = diagnostics.total_mismatches / diagnostics.number_of_elements
...     return (
...         f"Argh, we found {diagnostics.total_mismatches} mismatches! "
...         f"That is {ratio:.1%}!"
...     )
>>> torch.testing.assert_close(actual, expected, msg=custom_msg)
Traceback (most recent call last):
...
AssertionError: Argh, we found 2 mismatches! That is 66.7%!
torch.testing.make_tensor(shape, device, dtype, *, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False)[source]

使用给定的 shape, devicedtype 创建一个张量,并用从 [low, high) 均匀抽取的值填充。

如果指定了 lowhigh 且它们超出 dtype 可表示有限值的范围,则它们将分别被限制为最低或最高可表示的有限值。 如果指定的是 None,则下表描述了 lowhigh 的默认值,这些值取决于 dtype

dtype

low

high

布尔类型

0

2

无符号整数类型

0

10

有符号整数类型

-9

10

浮点类型

-9

9

复杂类型

-9

9

Parameters
  • shape (元组[整数, ..]) – 一个定义输出张量形状的整数序列。

  • device (Union[str, torch.device]) – 返回张量的设备。

  • dtype (torch.dtype) – 返回张量的数据类型。

  • low (可选[数字]) – 设置给定范围的下限(包含)。如果提供了一个数字,它将被限制为给定数据类型可以表示的最小有限值。当 None(默认), 此值基于 dtype 确定(请参见上表)。默认值:None

  • high (Optional[Number]) – 设置给定范围的上界(不包含)。如果提供一个数字,它会被截断到给定 dtype 可表示的最大有限值。当 None(默认)时,此值根据 dtype 确定(见上表)。默认值: None

  • requires_grad (Optional[bool]) – 如果自动求导应该记录对返回张量的操作。默认值: False.

  • noncontiguous (Optional[bool]) – 如果为 True,返回的张量将是不连续的。如果构造的张量元素少于两个,则忽略此参数。

  • exclude_zero (Optional[bool]) – 如果 True,则零将根据 dtype 被替换为数据类型的小正数值。对于布尔型和整数类型,零被替换为一。对于浮点类型,它被替换为该数据类型最小的正常正数(即 dtypefinfo() 对象的“极小”值),而对于复数类型,则被替换为其实部和虚部都为该复数类型可表示的最小正常正数的复数。默认值 False

Raises

示例

>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源