目录

torch.testing

torch.testing 中。assert_close实际预期*allow_subclasses=rtol=atol=equal_nan=Falsecheck_device=Truecheck_dtype=Truecheck_layout=Truecheck_stride=Falsemsg=None[来源]

断言 并且很接近。actualexpected

如果 和 是跨步的、非量化的、实值和有限的,则它们被认为是接近的,如果actualexpected

实际预期阿托尔+RTOL预期\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert

非有限值 ( 和 ) 仅在相等时被视为接近值。的 是 如果 为 ,则仅被视为彼此相等。-infinfNaNequal_nanTrue

此外,只有当它们具有相同的

  • (如果是),check_deviceTrue

  • dtype(如果是),check_dtypeTrue

  • layout(如果是)和check_layoutTrue

  • stride (如果为 )。check_strideTrue

如果 or 是元张量,则仅执行属性检查。actualexpected

如果 和 是稀疏的(具有 COO、CSR、CSC、BSR 或 BSC 布局),则它们的跨步成员为 单独检查。指数,即 COO、CSR 和 BSR、 或 和 分别用于 CSC 和 BSC 布局, 始终检查是否相等,而根据上述定义检查值是否接近。actualexpectedindicescrow_indicescol_indicesccol_indicesrow_indices

如果 和 被量化,则如果它们具有相同的值,并且根据 定义。actualexpected

actual,可以是 ',也可以是 '或任何类似张量或标量的类张量或标量,其中 可以用 构造 。除 Python 标量外,输入类型 必须直接相关。此外,可以是 的 或 的 ,在这种情况下,如果它们的结构匹配并且所有 根据上述定义,它们的元素被认为是接近的。expectedactualexpected

注意

Python 标量是类型关系要求的例外,因为它们的 ,即 、 和 等价于类似张量的标量。因此 可以检查不同类型的 Python 标量,但需要 。type()dtypecheck_dtype=False

参数
  • actualAny) (实际输入) – 实际输入。

  • expectedAny) (预期输入) – 预期输入。

  • allow_subclassesbool) – If (默认) 并且除了 Python 标量之外,直接相关的类型的输入 都是允许的。否则,需要类型相等。True

  • rtolOptional[float]) - 相对容差。如果指定,还必须指定。如果省略,则默认 基于 的值是通过下表选择的。atoldtype

  • atolOptional[float]) - 绝对容差。如果指定,还必须指定。如果省略,则默认 基于 的值是通过下表选择的。rtoldtype

  • equal_nanUnion[boolstr]) – 如果 ,则认为两个值相等。TrueNaN

  • check_devicebool) – If (默认) – 断言相应的张量位于同一 .如果禁用此检查,则不同 上的张量将在比较之前移动到 CPU。True

  • check_dtypebool) - 如果 (默认) 断言相应的张量具有相同的 。如果此 check 时,具有不同 的张量将在比较之前提升为公共张量(根据)。Truedtypedtypedtype

  • check_layoutbool) - 如果 (默认) 断言相应的张量具有相同的 。如果此 check 时,具有不同 的张量会在 比较。Truelayoutlayout

  • check_stridebool) - 如果和相应的张量是跨步的,则断言它们具有相同的步幅。True

  • msgOptional[Union[str Callable[[str]str]]]) – 在期间发生故障时使用的可选错误消息 比较。也可以作为可调用对象传递,在这种情况下,它将与生成的消息一起调用 应返回新消息。

提高
  • ValueError – 如果为 no,则可以从输入构造。

  • ValueError – 如果仅指定了 or。rtolatol

  • AssertionError – 如果相应的输入不是 Python 标量并且没有直接关系。

  • AssertionError – 如果是 ,但相应的输入不是 Python 标量,并且具有 不同的类型。allow_subclassesFalse

  • AssertionError – 如果输入是 ,但它们的长度不匹配。

  • AssertionError – 如果输入是 ,但它们的键集不匹配。

  • AssertionError – 如果相应的张量不具有相同的 .

  • AssertionError – 如果为 ,但相应的张量不具有相同的 。check_layoutTruelayout

  • AssertionError – 如果只有一个相应的张量被量化。

  • AssertionError – 如果相应的张量已量化,但具有不同的 's。

  • AssertionError – 如果是 ,但相应的张量不在同一 .check_deviceTrue

  • AssertionError – 如果为 ,但相应的张量不具有相同的 。check_dtypeTruedtype

  • AssertionError – 如果为 ,但相应的跨步张量没有相同的步幅。check_strideTrue

  • AssertionError – 如果根据上述定义,相应张量的值不接近。

下表显示了 default 和 for different 's.如果 不匹配,则使用两个容差的最大值。rtolatoldtypedtype

dtype

rtol

atol

float16

1e-3

1e-5

bfloat16

1.6e-2

1e-5

float32

1.3e-6

1e-5

float64

1e-7

1e-7

complex32

1e-3

1e-5

complex64

1.3e-6

1e-5

complex128

1e-7

1e-7

quint8

1.3e-6

1e-5

quint2x4

1.3e-6

1e-5

quint4x2

1.3e-6

1e-5

qint8

1.3e-6

1e-5

qint32

1.3e-6

1e-5

其他

0.0

0.0

注意

具有严格的默认设置,具有高度可配置性。鼓励用户 以适应他们的用例。例如,如果需要进行相等性检查,则可能会 定义一个 默认情况下对 every 使用零容差:assert_equaldtype

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!

Expected 1e-10 but got 1e-09.
Absolute difference: 9.000000000000001e-10
Relative difference: 9.0

例子

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type
<class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
and <class 'torch.Tensor'>.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!

Expected nan but got nan.
Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default error message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # If msg is a callable, it can be used to augment the generated message with
>>> # extra information
>>> torch.testing.assert_close(
...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
... )
Traceback (most recent call last):
...
AssertionError: Header

Tensor-likes are not close!

Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)

Footer
torch.testing 中。make_tensor*shapedtypedevicelow=high=requires_grad=Falsenoncontiguous=Falseexclude_zero=Falsememory_format=None[来源]

创建具有给定 、 、 和 的张量,并填充有 统一从 中抽取的值。shapedevicedtype[low, high)

如果指定了 或 ,并且超出 的 的 可表示 finite 值,则它们分别被钳制为最低或最高可表示的有限值。 如果 ,则下表描述了 和 的默认值 , 它们依赖于 。lowhighdtypeNonelowhighdtype

dtype

low

high

布尔类型

0

2

unsigned 整数类型

0

10

有符号整型

-9

10

浮动类型

-9

9

复杂类型

-9

9

参数
  • shapeTuple[int...]) – 定义输出张量形状的单个整数或整数序列。

  • dtype) – 返回的张量的数据类型。

  • deviceUnion[strtorch.device]) – 返回张量的设备。

  • lowOptional[Number]) – 设置给定范围的下限(含)。如果提供了数字,则为 固定到给定 dtype 的最小可表示的有限值。When (默认)、 此值是根据 (请参阅上表) 确定的。违约:。NonedtypeNone

  • high可选 [Number]) –

    设置给定范围的上限 (不包括) 。如果提供了数字,则为 钳制到给定 dtype 的最大可表示有限值。当 (默认) 此值时 是根据 确定的(见上表)。违约:。NonedtypeNone

    2.1 版后已移除: 对于浮点类型或复杂类型,不推荐使用 to 从 2.1 开始,并将在 2.3 中删除。请改用low==high

  • requires_gradOptional[bool]) – autograd 是否应记录对返回的张量的操作。违约:。False

  • noncontiguousOptional[bool]) – 如果为 True,则返回的张量将是非连续的。这个参数是 如果构造的 Tensor 少于两个元素,则忽略。与 互斥。memory_format

  • exclude_zeroOptional[bool]) – 如果 then 将零替换为 dtype 的小正值 取决于 .对于 bool 和 integer 类型,0 将替换为 1。用于浮动 Point 类型,它被替换为 dtype 的最小正法线数( 对象的“tiny”值),而对于复杂类型,它被替换为复数 其实部和虚部都是复数可表示的最小正正规数 类型。违约。Truedtypedtypefinfo()False

  • memory_formatOptional[torch.memory_format]) – 返回的张量的内存格式。互斥 跟。noncontiguous

提高
  • ValueError – 如果为整型 dtype 传递requires_grad=True

  • ValueError – 如果 .low >= high

  • ValueError – 如果 或 是 。lowhighnan

  • ValueError – 如果同时传递 和 。noncontiguousmemory_format

  • TypeError – 如果此函数不支持。dtype

返回类型

张肌

例子

>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')
torch.testing 中。assert_allclose实际预期rtol=atol=equal_nan=msg=''[来源]

警告

已弃用,并将在未来发行版中删除。 请改用。您可以在此处找到详细的升级说明。1.12

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源