torch.overrides¶
该模块为协议公开了各种辅助函数。有关协议的更多详细信息,请参阅扩展 torch。__torch_function__
__torch_function__
功能¶
-
torch.overrides.
get_ignored_functions
()[来源]¶ 返回不能被 覆盖的公共函数。
__torch_function__
- 返回
在 torch API 中公开可用但不能 被 覆盖。主要是因为 这些函数的参数是 Tensors 或 Tensor-likes。
__torch_function__
- 返回类型
Set[可调用]
例子
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False
-
torch.overrides.
get_overridable_functions
()[来源]¶ 列出可通过 __torch_function__ 覆盖的函数
- 返回
映射包含可重写函数的命名空间的字典 添加到该命名空间中可以覆盖的函数。
- 返回类型
dict[任意, 列表[可调用]]
-
torch.overrides.
get_testing_overrides
()[来源]¶ 返回一个包含所有可重写函数的虚拟重写的 dict
- 返回
将 PyTorch API 中的可覆盖函数映射到 与实际函数具有相同签名的 Lambda 函数 并无条件返回 -1。这些 lambda 函数非常有用 用于测试定义 .
__torch_function__
- 返回类型
dict[可调用,可调用]
例子
>>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)>
-
torch.overrides.
handle_torch_function
(public_api、relevant_args、*args、**kwargs)[来源]¶ 实现一个检查覆盖的函数。
__torch_function__
参见 torch::autograd::handle_torch_function 了解这个 函数C++实现中。
- 参数
- 返回
调用或方法的结果(如果适用)。
implementation
__torch_function__
- 返回类型
:raises TypeError : 如果未找到实现。
例
>>> def func(a): ... if type(a) is not torch.Tensor: # This will make func dispatchable by __torch_function__ ... return handle_torch_function(func, (a,), a) ... return a + 0
-
torch.overrides.
has_torch_function
()¶ 检查 iterable 元素中的 __torch_function__ 实现。 将确切的 s 和 s 视为不可调度的。 :p aram relevant_args:要检查__torch_function__方法的可迭代对象或 aguments。 :type relevant_args: 可迭代
Tensor
Parameter
- 返回
如果 relevant_args 的任何元素具有 __torch_function__,则为 True implementations,否则为 False。
- 返回类型
另请参阅
torch.is_tensor_like
检查某物是否为 Tensor-like,包括精确的 .
Tensor
-
torch.overrides.
is_tensor_like
(inp)[来源]¶ 如果传入的输入是类似 Tensor 的,则返回。
True
目前,只要输入类型有属性,就会发生这种情况。
__torch_function__
例子
tensor 的子类通常是类似 Tensor 的。
>>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True
内置类型或用户类型通常不是类似 Tensor 的。
>>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False
但是,它们可以通过实现 __torch_function__ 来变得类似 Tensor。
>>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True
-
torch.overrides.
is_tensor_method_or_property
(func)[来源]¶ 如果传入的函数是 方法或属性属于 ,如传递的那样 到。
torch.Tensor
__torch_function__
注意
对于属性,必须传入其方法。
__get__
这可能是必要的,特别是出于以下原因:
方法/属性有时不包含__module__槽。
它们要求第一个传入的参数是一个实例 之。
torch.Tensor
例子
>>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False
-
torch.overrides.
wrap_torch_function
(调度员)[来源]¶ 使用 相关功能包装给定函数。
__torch_function__
- 参数
dispatcher (Callable) – 一个可调用对象,它返回传递到函数中的 Tensor 类可迭代对象。
注意
此装饰器可能会降低代码的性能。一般来说,表达 您的代码是一系列本身支持 __torch_function__ 的函数。如果你 发现自己处于极少数情况下,情况并非如此,例如,如果您将 low-level 库,并且您还需要它适用于 Tensor-like,那么此函数可用。
例子
>>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0