torch.testing
Warning
This module is in a PROTOTYPE state. New functions are still being added, and the available functions may change in future PyTorch releases. We are actively looking for feedback for UI/UX improvements or missing functionalities.
torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=False, check_is_coalesced=True, msg=None)

Asserts that actual and expected are close.

If actual and expected are strided, non-quantized, real-valued, and finite, they are considered close if, elementwise,

$$\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert$$

and they have the same device (if check_device is True), the same dtype (if check_dtype is True), and the same stride (if check_stride is True). Non-finite values (-inf and inf) are considered close if and only if they are equal. NaN’s are only considered equal to each other if equal_nan is True.

If actual and expected are sparse (either having COO or CSR layout), their strided members are checked individually. Indices, namely indices for COO or crow_indices and col_indices for CSR layout, are always checked for equality, whereas the values are checked for closeness according to the definition above. Sparse COO tensors are only considered close if both are either coalesced or uncoalesced (if check_is_coalesced is True).

If actual and expected are quantized, they are considered close if they have the same qscheme() and the result of dequantize() is close according to the definition above.

actual and expected can be Tensor’s or any tensor-or-scalar-likes from which torch.Tensor’s can be constructed with torch.as_tensor(). Except for Python scalars, the input types have to be directly related, i.e. one type has to be a subclass of the other. In addition, actual and expected can be Sequence’s or Mapping’s, in which case they are considered close if their structure matches and all their elements are considered close according to the above definition.

Note
Python scalars are an exception to the type relation requirement, because their type(), i.e. int, float, and complex, is equivalent to the dtype of a tensor-like. Thus, Python scalars of different types can be checked, but require check_dtype to be set to False.

Parameters
actual (Any) – Actual input.
expected (Any) – Expected input.
allow_subclasses (bool) – If True (default) and except for Python scalars, inputs of directly related types are allowed. Otherwise, type equality is required.
rtol (Optional[float]) – Relative tolerance. If specified, atol must also be specified. If omitted, default values based on the dtype are selected with the table below.
atol (Optional[float]) – Absolute tolerance. If specified, rtol must also be specified. If omitted, default values based on the dtype are selected with the table below.
equal_nan (Union[bool, str]) – If True, two NaN values will be considered equal.
check_device (bool) – If True (default), asserts that corresponding tensors are on the same device. If this check is disabled, tensors on different device’s are moved to the CPU before being compared.
check_dtype (bool) – If True (default), asserts that corresponding tensors have the same dtype. If this check is disabled, tensors with different dtype’s are promoted to a common dtype (according to torch.promote_types()) before being compared.
check_stride (bool) – If True and corresponding tensors are strided, asserts that they have the same stride.
check_is_coalesced (bool) – If True (default) and corresponding tensors are sparse COO, checks that both actual and expected are either coalesced or uncoalesced. If this check is disabled, tensors are coalesce()’ed before being compared.
msg (Optional[Union[str, Callable[[Tensor, Tensor, Diagnostics], str]]]) – Optional error message to use if the values of corresponding tensors mismatch. Can be passed as a callable, in which case it will be called with the mismatching tensors and a namespace of diagnostics about the mismatches. See below for details.
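To see how rtol/atol and check_dtype interact, here is a small sketch on CPU tensors. The values (an offset of 1e-3, a float64 copy) are illustrative choices, and the tolerances quoted in the comments are the float32 defaults from the table below.

```python
import torch

expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected + 1e-3  # every element is off by roughly 1e-3

# float32 defaults (rtol=1.3e-6, atol=1e-5) reject an offset this large.
try:
    torch.testing.assert_close(actual, expected)
    default_passed = True
except AssertionError:
    default_passed = False

# rtol and atol are passed together; with rtol=0.0 only atol matters.
torch.testing.assert_close(actual, expected, rtol=0.0, atol=1e-2)

# Same values, different dtypes: rejected while check_dtype=True (default) ...
as_double = expected.double()
try:
    torch.testing.assert_close(as_double, expected)
    dtype_check_passed = True
except AssertionError:
    dtype_check_passed = False

# ... but accepted once check_dtype=False promotes both to a common dtype.
torch.testing.assert_close(as_double, expected, check_dtype=False)
print(default_passed, dtype_check_passed)  # False False
```

Setting rtol=0.0 together with atol makes the check purely absolute; passing only one of the two raises a ValueError, as listed under Raises.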
Raises
ValueError – If no torch.Tensor can be constructed from an input.
ValueError – If only rtol or atol is specified.
AssertionError – If corresponding inputs are not Python scalars and are not directly related.
AssertionError – If allow_subclasses is False, but corresponding inputs are not Python scalars and have different types.
AssertionError – If the inputs are Sequence’s, but their lengths do not match.
AssertionError – If the inputs are Mapping’s, but their sets of keys do not match.
AssertionError – If corresponding tensors do not have the same shape.
AssertionError – If corresponding tensors do not have the same layout.
AssertionError – If corresponding tensors are quantized, but have different qscheme()’s.
AssertionError – If check_device is True, but corresponding tensors are not on the same device.
AssertionError – If check_dtype is True, but corresponding tensors do not have the same dtype.
AssertionError – If check_stride is True, but corresponding strided tensors do not have the same stride.
AssertionError – If check_is_coalesced is True, but corresponding sparse COO tensors are not both either coalesced or uncoalesced.
AssertionError – If the values of corresponding tensors are not close according to the definition above.
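Several of the error types listed above can be observed directly. A sketch, assuming a CPU build of PyTorch; raised is a hypothetical helper that only reports which exception type a call produces.

```python
import torch


def raised(fn):
    """Hypothetical helper: return the exception type fn() raises, or None."""
    try:
        fn()
    except Exception as error:  # broad on purpose: only the type is reported
        return type(error)
    return None


a = torch.tensor([1.0, 2.0])

# Only one of rtol/atol given -> ValueError.
only_rtol = raised(lambda: torch.testing.assert_close(a, a, rtol=1e-3))
# Tensors with different shapes -> AssertionError.
shape_mismatch = raised(
    lambda: torch.testing.assert_close(a, torch.tensor([1.0, 2.0, 3.0]))
)
# Sequences of different length -> AssertionError.
length_mismatch = raised(lambda: torch.testing.assert_close([a, a], [a]))
# Mappings with different keys -> AssertionError.
key_mismatch = raised(lambda: torch.testing.assert_close({"x": a}, {"y": a}))

print(only_rtol, shape_mismatch, length_mismatch, key_mismatch)
```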
The following table displays the default rtol and atol for different dtype’s. In case of mismatching dtype’s, the maximum of both tolerances is used.

dtype        rtol     atol
float16      1e-3     1e-5
bfloat16     1.6e-2   1e-5
float32      1.3e-6   1e-5
float64      1e-7     1e-7
complex32    1e-3     1e-5
complex64    1.3e-6   1e-5
complex128   1e-7     1e-7
other        0.0      0.0

The namespace of diagnostics that will be passed to msg if it is a callable has the following attributes:

number_of_elements (int): Number of elements in each tensor being compared.
total_mismatches (int): Total number of mismatches.
max_abs_diff (Union[int, float]): Greatest absolute difference of the inputs.
max_abs_diff_idx (Union[int, Tuple[int, …]]): Index of greatest absolute difference.
atol (float): Allowed absolute tolerance.
max_rel_diff (Union[int, float]): Greatest relative difference of the inputs.
max_rel_diff_idx (Union[int, Tuple[int, …]]): Index of greatest relative difference.
rtol (float): Allowed relative tolerance.
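The closeness rule, the default-tolerance table, and two of the diagnostics can be mirrored in a few lines of plain Python. This is a simplified sketch for finite real numbers only, not PyTorch's implementation; DEFAULT_TOLERANCES, default_tolerances, count_mismatches, and max_abs_diff are hypothetical names. With the inputs from the custom-message example further below, it reproduces the 2 mismatches (66.7%).

```python
# Plain-Python sketch of the closeness rule and default tolerances.
# Finite real numbers only; this is not PyTorch's implementation.

# dtype name -> (rtol, atol), copied from the table above
DEFAULT_TOLERANCES = {
    "float16": (1e-3, 1e-5),
    "bfloat16": (1.6e-2, 1e-5),
    "float32": (1.3e-6, 1e-5),
    "float64": (1e-7, 1e-7),
    "complex32": (1e-3, 1e-5),
    "complex64": (1.3e-6, 1e-5),
    "complex128": (1e-7, 1e-7),
}


def default_tolerances(dtype_a, dtype_b):
    """Pick (rtol, atol); for mismatching dtypes take the maximum of each."""
    # Any dtype missing from the table ("other") falls back to zero tolerance.
    rtol_a, atol_a = DEFAULT_TOLERANCES.get(dtype_a, (0.0, 0.0))
    rtol_b, atol_b = DEFAULT_TOLERANCES.get(dtype_b, (0.0, 0.0))
    return max(rtol_a, rtol_b), max(atol_a, atol_b)


def count_mismatches(actual, expected, rtol, atol):
    """Count elements violating |actual - expected| <= atol + rtol * |expected|."""
    return sum(
        1 for a, e in zip(actual, expected) if abs(a - e) > atol + rtol * abs(e)
    )


def max_abs_diff(actual, expected):
    """Return the largest |actual - expected| and the index where it first occurs."""
    diffs = [abs(a - e) for a, e in zip(actual, expected)]
    idx = max(range(len(diffs)), key=diffs.__getitem__)
    return diffs[idx], idx


actual, expected = [1.0, 4.0, 5.0], [1.0, 2.0, 3.0]
rtol, atol = default_tolerances("float32", "float32")
mismatches = count_mismatches(actual, expected, rtol, atol)
print(mismatches, f"{mismatches / len(actual):.1%}")  # 2 66.7%
print(max_abs_diff(actual, expected))  # (2.0, 1)
```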
For max_abs_diff and max_rel_diff, the type depends on the dtype of the inputs.

Note
assert_close() is highly configurable with strict default settings. Users are encouraged to partial() it to fit their use case. For example, if an equality check is needed, one might define an assert_equal that uses zero tolerances for every dtype by default:

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!
Absolute difference: 8.999999703829253e-10
Relative difference: 8.999999583666371
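The same pattern works for relaxing the defaults. A sketch of a helper for comparing float16 results against float32 references; the name assert_close_fp16 and its tolerances are hypothetical choices for illustration, not recommended values.

```python
import functools

import torch

# Hypothetical helper: ignore the dtype difference and use looser tolerances.
assert_close_fp16 = functools.partial(
    torch.testing.assert_close, check_dtype=False, rtol=1e-2, atol=1e-3
)

reference = torch.linspace(0, 1, steps=11)  # float32 reference values
low_precision = reference.half()  # the same values rounded to float16

# float16 rounding error stays well inside rtol=1e-2, so this passes.
assert_close_fp16(low_precision, reference)

# The unmodified function still rejects the pair because the dtypes differ.
try:
    torch.testing.assert_close(low_precision, reference)
    strict_passed = True
except AssertionError:
    strict_passed = False
print(strict_passed)  # False
```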
Examples
>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, type equality is required if allow_subclasses=False, but got <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'> instead.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, input types need to be directly related, but got <class 'numpy.ndarray'> and <class 'torch.Tensor'> instead.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!
Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default mismatch message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # The error message can also be created at runtime by passing a callable.
>>> def custom_msg(actual, expected, diagnostics):
...     ratio = diagnostics.total_mismatches / diagnostics.number_of_elements
...     return (
...         f"Argh, we found {diagnostics.total_mismatches} mismatches! "
...         f"That is {ratio:.1%}!"
...     )
>>> torch.testing.assert_close(actual, expected, msg=custom_msg)
Traceback (most recent call last):
...
AssertionError: Argh, we found 2 mismatches! That is 66.7%!
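Because check_stride defaults to False, a transposed view with the same shape and values compares as close. A short sketch with illustrative tensors, showing how opting into the stride check changes the outcome:

```python
import torch

row_major = torch.zeros(2, 3)  # shape (2, 3), stride (3, 1)
transposed = torch.zeros(3, 2).t()  # shape (2, 3), stride (1, 2)

# Same shape and values: passes, because check_stride defaults to False.
torch.testing.assert_close(transposed, row_major)

# Opting into the stride check turns the layout difference into a failure.
try:
    torch.testing.assert_close(transposed, row_major, check_stride=True)
    stride_check_passed = True
except AssertionError:
    stride_check_passed = False
print(stride_check_passed)  # False
```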
torch.testing.make_tensor(shape, device, dtype, *, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False)

Creates a tensor with the given shape, device, and dtype, filled with values uniformly drawn from [low, high).

If low or high is specified and is outside the range of the dtype’s representable finite values, it is clamped to the lowest or highest representable finite value, respectively. If they are None, the following table lists the default values for low and high, which depend on dtype.

dtype                    low   high
boolean type             0     2
unsigned integral type   0     10
signed integral types    -9    10
floating types           -9    9
complex types            -9    9

Parameters
shape (Tuple[int, ..]) – A sequence of integers defining the shape of the output tensor.
device (Union[str, torch.device]) – The device of the returned tensor.
dtype (torch.dtype) – The data type of the returned tensor.
low (Optional[Number]) – Sets the lower limit (inclusive) of the given range. If the number provided is below the least representable finite value of the given dtype, it is clamped to that value. When None (default), this value is determined based on the dtype (see the table above). Default: None.
high (Optional[Number]) – Sets the upper limit (exclusive) of the given range. If the number provided is above the greatest representable finite value of the given dtype, it is clamped to that value. When None (default), this value is determined based on the dtype (see the table above). Default: None.
requires_grad (Optional[bool]) – If autograd should record operations on the returned tensor. Default: False.
noncontiguous (Optional[bool]) – If True, the returned tensor will be noncontiguous. This argument is ignored if the constructed tensor has fewer than two elements.
exclude_zero (Optional[bool]) – If True, zeros are replaced with a small nonzero value that depends on the dtype. For bool and integer types, zero is replaced with one. For floating point types, it is replaced with the dtype’s smallest positive normal number (the “tiny” value of the dtype’s finfo() object), and for complex types it is replaced with a complex number whose real and imaginary parts are both the smallest positive normal number representable by the complex type. Default: False.
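A sketch combining the optional flags above on CPU; the shapes and dtypes are arbitrary illustrative choices. The integer bounds checked at the end come from the signed-integer defaults [-9, 10) in the table above.

```python
import torch
from torch.testing import make_tensor

# Signed integers from the default range [-9, 10); exclude_zero swaps 0 for 1.
ints = make_tensor((1000,), device="cpu", dtype=torch.int64, exclude_zero=True)

# noncontiguous=True returns a tensor that is not contiguous in memory.
nc = make_tensor((4, 5), device="cpu", dtype=torch.float32, noncontiguous=True)

# requires_grad=True lets autograd track operations on a floating point tensor.
leaf = make_tensor((3,), device="cpu", dtype=torch.float64, requires_grad=True)
leaf.sum().backward()  # d(sum)/d(leaf) is 1 for every element

print(ints.min().item(), ints.max().item(), bool((ints != 0).all()))
print(tuple(nc.shape), nc.is_contiguous())
print(leaf.grad)
```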
Raises
ValueError – If low > high.
ValueError – If either low or high is nan.
TypeError – If dtype isn’t supported by this function.
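The two ValueError cases above can be checked directly. A sketch, assuming a CPU build of PyTorch; raises_value_error is a hypothetical helper that wraps make_tensor with fixed shape, device, and dtype.

```python
import torch
from torch.testing import make_tensor


def raises_value_error(**bounds):
    """Hypothetical helper: True if make_tensor rejects the given low/high."""
    try:
        make_tensor((2,), device="cpu", dtype=torch.float32, **bounds)
    except ValueError:
        return True
    return False


print(raises_value_error(low=5, high=1))  # low > high
print(raises_value_error(low=float("nan"), high=1))  # NaN bound
print(raises_value_error(low=-1, high=1))  # valid range
```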
Examples
>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')
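make_tensor and assert_close are often used together in tests. A sketch, assuming the default CPU generator: make_tensor samples with PyTorch's random functions, so seeding with torch.manual_seed makes the inputs repeatable, and a float32 result can then be checked against a float64 reference computed from the same inputs. The choice of torch.exp and the range [-2, 2) are illustrative.

```python
import torch
from torch.testing import make_tensor

# Same seed before each call -> identical inputs.
torch.manual_seed(0)
first = make_tensor((5,), device="cpu", dtype=torch.float32, low=-2, high=2)
torch.manual_seed(0)
second = make_tensor((5,), device="cpu", dtype=torch.float32, low=-2, high=2)
torch.testing.assert_close(first, second, rtol=0, atol=0)  # exact equality

# float32 result vs. float64 reference; check_dtype=False promotes both,
# and the larger (float32) default tolerances apply.
actual = torch.exp(first)
reference = torch.exp(first.double())
torch.testing.assert_close(actual, reference, check_dtype=False)
```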