torch.testing¶
- torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[source]¶
断言
actual和expected是接近的。如果
actual和expected是步长的、非量化的、实值的和有限的,那么它们被认为是接近的非有限值(
-inf和inf)仅在它们相等时才被认为接近。NaN仅在equal_nan等于True时才被认为彼此相等。此外,只有当它们具有相同的
device(如果check_device是True),dtype(如果check_dtype是True),layout(如果check_layout是True),并且步长(如果
check_stride是True)。
如果
actual或expected是元张量,则仅执行属性检查。如果
actual和expected是稀疏的(具有 COO、CSR、CSC、BSR 或 BSC 布局),则它们的步幅成员将被单独检查。索引,即 COO 的indices,CSR 和 BSR 的crow_indices和col_indices,或 CSC 和 BSC 布局的ccol_indices和row_indices,总是会被检查是否相等,而值则根据上述定义检查是否接近。如果
actual和expected被量化,当它们具有相同的qscheme()并且根据上述定义dequantize()的结果接近时,它们被认为是接近的。actual和expected可以是Tensor或任何张量或标量类似的对象,可以通过torch.as_tensor()构造出torch.Tensor。除了 Python 标量外,输入类型必须直接相关。此外,actual和expected可以是Sequence或Mapping,在这种情况下,如果它们的结构匹配,并且所有元素都根据上述定义被认为是接近的,则认为它们是接近的。注意
Python 标量是类型关系要求的例外,因为它们的
type(),即int、float和complex,等同于张量类似结构的dtype。因此, 不同类型的 Python 标量可以被检查,但需要check_dtype=False。- Parameters
实际值 (Any) – 实际输入。
预期 (Any) – 预期输入。
allow_subclasses (bool) – 如果为
True(默认值),则除了 Python 标量外,允许直接相关类型的输入。否则需要类型完全一致。rtol (Optional[float]) – 相对容差。如果指定了
atol,也必须指定该参数。如果省略,则根据dtype选择默认值,并参考下表进行选择。atol (Optional[float]) – 绝对容差。如果指定了
rtol,也必须指定。如果省略,则根据dtype选择默认值,并使用下表进行选择。check_device (bool) – 如果为
True(默认值),则断言相应的张量位于同一device上。如果禁用此检查,则在比较之前,不同device上的张量会被移动到 CPU 上。check_dtype (布尔值) – 如果
True(默认),断言对应的张量具有相同的数据类型。如果禁用此检查,具有不同数据类型的张量将被提升为一个通用的数据类型(根据torch.promote_types())后再进行比较。check_layout (bool) – 如果
True(默认值),则断言相应的张量具有相同的layout。如果禁用此检查,则在比较之前,将不同layout的张量转换为步幅张量。check_stride (bool) – 如果为
True且相应的张量是步幅的,则断言它们具有相同的步长。msg (Optional[Union[str, Callable[[str], str]]]) – 如果在比较过程中发生失败,可以选择使用的错误信息。也可以作为可调用对象传递,在这种情况下,它将使用生成的消息进行调用,并应返回新的消息。
- Raises
ValueError – 如果无法从输入构建
torch.Tensor。ValueError – 如果仅指定了
rtol或atol。AssertionError – 如果对应的输入不是Python标量且没有直接关系。
断言错误 – 如果
allow_subclasses是False,但对应的输入不是 Python 标量且类型不同。AssertionError – 如果输入是
Sequence,但它们的长度不匹配。AssertionError – 如果输入是
Mapping,但它们的键集合不匹配。AssertionError – 如果对应的张量不具有相同的
shape。断言错误 – 如果
check_layout是True,但对应的张量没有相同的layout。AssertionError – 如果只有一个对应的张量被量化。
AssertionError – 如果对应的张量被量化,但具有不同的
qscheme()。断言错误 – 如果
check_dtype是True,但相应的张量不具有相同的dtype。断言错误 – 如果
check_stride是True,但对应的带步长的张量没有相同的步长。AssertionError – 如果对应张量的值根据上述定义不接近,则会引发此错误。
以下表格显示了不同
dtype的默认rtol和atol。如果出现不匹配的dtype,将使用两者的最大容差。dtypertolatolfloat161e-31e-5bfloat161.6e-21e-5float321.3e-61e-5float641e-71e-7complex321e-31e-5complex641.3e-61e-5complex1281e-71e-7quint81.3e-61e-5quint2x41.3e-61e-5quint4x21.3e-61e-5qint81.3e-61e-5qint321.3e-61e-5其他
0.00.0注意
assert_close()具有高度可配置性,并且具有严格的默认设置。鼓励用户根据其使用场景对partial()进行调整。例如,如果需要进行相等性检查,可以定义一个assert_equal,该检查在默认情况下对每个dtype使用零容差:>>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Expected 1e-10 but got 1e-09. Absolute difference: 9.000000000000001e-10 Relative difference: 9.0
示例
>>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> and <class 'torch.Tensor'>. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Expected nan but got nan. Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer
- torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source]¶
使用给定的
shape,device和dtype创建一个张量,并用从[low, high)均匀抽取的值填充。如果指定了
low或high且它们超出dtype可表示有限值的范围,则它们将分别被限制为最低或最高可表示的有限值。 如果指定的是None,则下表描述了low和high的默认值,这些值取决于dtype。dtypelowhigh布尔类型
02无符号整数类型
010有符号整数类型
-910浮点类型
-99复杂类型
-99- Parameters
shape (Tuple[int, ...]) – 单个整数或定义输出张量形状的一系列整数。
数据类型 (
torch.dtype) – 返回张量的数据类型。device (Union[str, torch.device]) – 返回张量的设备。
low (可选[数字]) – 设置给定范围的下限(包含)。如果提供了一个数字,它将被限制为给定数据类型可以表示的最小有限值。当
None(默认), 此值基于dtype确定(请参见上表)。默认值:None。高 (可选[数字]) –
设置给定范围的上限(不包括)。如果提供一个数字,它将被限制为给定 dtype 可表示的最大有限值。当
None(默认)时,该值根据dtype确定(见上表)。默认值:None。自 2.1 版本开始弃用:将
low==high传递给make_tensor()用于浮点或复数类型自 2.1 版本开始弃用,并将在 2.3 版本中移除。请使用torch.full()代替。requires_grad (Optional[bool]) – 如果自动求导应该记录对返回张量的操作。默认值:
False.noncontiguous (可选[bool]) – 如果为 True,则返回的张量将是非连续的。如果构造的张量元素少于两个,则忽略此参数。与
memory_format互斥。exclude_zero (Optional[bool]) – 如果
True,则零将根据dtype被替换为数据类型的小正数值。对于布尔型和整数类型,零被替换为一。对于浮点类型,它被替换为该数据类型最小的正常正数(即dtype的finfo()对象的“极小”值),而对于复数类型,则被替换为其实部和虚部都为该复数类型可表示的最小正常正数的复数。默认值False。memory_format (Optional[torch.memory_format]) – 返回张量的内存格式。与
noncontiguous互斥。
- Raises
ValueError – 如果为整数 dtype 传递了
requires_grad=True值错误 – 如果
low >= high。ValueError – 如果
low或high是nan。ValueError – 如果同时传递了
noncontiguous和memory_format。TypeError – 如果
dtype不被此函数支持。
- Return type
示例
>>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0')
- torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[source]¶
警告
torch.testing.assert_allclose()自1.12起已弃用,并将在将来的版本中移除。 请改用torch.testing.assert_close()。您可以在 此处 找到详细的升级说明。