torch.overrides¶
该模块为协议公开了各种辅助函数。有关协议的更多详细信息,请参阅扩展 torch Python API。__torch_function__
__torch_function__
功能¶
- torch.overrides 的get_ignored_functions()[来源]¶
返回不能被 覆盖的公共函数。
__torch_function__
- 返回
在 torch API 中公开可用但不能 被 覆盖。主要是因为 这些函数的参数是 Tensors 或 Tensor-likes。
__torch_function__
- 返回类型
Set[可调用]
例子
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions() True >>> torch.add in torch.overrides.get_ignored_functions() False
- torch.overrides 的get_overridable_functions()[来源]¶
列出可通过 __torch_function__ 覆盖的函数
- 返回
映射包含可重写函数的命名空间的字典 添加到该命名空间中可以覆盖的函数。
- 返回类型
dict[任意, 列表[可调用]]
- torch.overrides 的resolve_name(f)[来源]¶
获取传递给 __torch_function__
- 参数
f (Callable) – 要解析其名称的函数。
- 返回
函数的名称;如果 evaled,它应该返回输入 功能。
- 返回类型
- torch.overrides 的get_testing_overrides()[来源]¶
返回一个包含所有可重写函数的虚拟重写的 dict
- 返回
将 PyTorch API 中的可覆盖函数映射到 与实际函数具有相同签名的 Lambda 函数 并无条件返回 -1。这些 lambda 函数非常有用 用于测试定义 .
__torch_function__
- 返回类型
dict[可调用,可调用]
例子
>>> import inspect >>> my_add = torch.overrides.get_testing_overrides()[torch.add] >>> inspect.signature(my_add) <Signature (input, other, out=None)>
- torch.overrides 的handle_torch_function(public_api, relevant_args, *args, **kwargs)[来源]¶
实现一个检查覆盖的函数。
__torch_function__
参见 torch::autograd::handle_torch_function 了解这个 函数C++实现中。
- 参数
- 返回
调用或方法的结果(如果适用)。
implementation
__torch_function__
- 返回类型
:raises TypeError : 如果未找到实现。
例
>>> def func(a): ... if has_torch_function_unary(a): ... return handle_torch_function(func, (a,), a) ... return a + 0
- torch.overrides 的has_torch_function()¶
检查 iterable 元素中的 __torch_function__ 实现 或者启用了 __torch_function__ 模式。考虑精确 和 s 不可调度。使用 this 来保护对
;不要用它来测试某 类似 Tensor,请改用
。 :p aram relevant_args:要检查__torch_function__方法的可迭代对象或参数。 :type relevant_args: 可迭代
Tensor
Parameter
- 返回
如果 relevant_args 的任何元素具有 __torch_function__,则为 True implementations,否则为 False。
- 返回类型
另请参阅
torch.is_tensor_like
检查某物是否为 Tensor-like,包括精确的 .
Tensor
- torch.overrides 的is_tensor_like(inp)[来源]¶
如果传入的输入是类似 Tensor 的,则返回。
True
目前,只要输入类型有属性,就会发生这种情况。
__torch_function__
例子
tensor 的子类通常是类似 Tensor 的。
>>> class SubTensor(torch.Tensor): ... >>> is_tensor_like(SubTensor([0])) True
内置类型或用户类型通常不是类似 Tensor 的。
>>> is_tensor_like(6) False >>> is_tensor_like(None) False >>> class NotATensor: ... >>> is_tensor_like(NotATensor()) False
但是,它们可以通过实现 __torch_function__ 来变得类似 Tensor。
>>> class TensorLike: ... @classmethod ... def __torch_function__(cls, func, types, args, kwargs): ... return -1 >>> is_tensor_like(TensorLike()) True
- torch.overrides 的is_tensor_method_or_property(func)[来源]¶
如果传入的函数是 方法或属性属于 ,如传递的那样 到。
torch.Tensor
__torch_function__
注意
对于属性,必须传入其方法。
__get__
这可能是必要的,特别是出于以下原因:
方法/属性有时不包含__module__槽。
它们要求第一个传入的参数是一个实例 之。
torch.Tensor
例子
>>> is_tensor_method_or_property(torch.Tensor.add) True >>> is_tensor_method_or_property(torch.add) False
- 返回类型
- torch.overrides 的wrap_torch_function(调度员)[来源]¶
使用 相关功能包装给定函数。
__torch_function__
- 参数
dispatcher (Callable) – 一个可调用对象,它返回传递到函数中的 Tensor 类可迭代对象。
注意
此装饰器可能会降低代码的性能。一般来说,表达 您的代码是一系列本身支持 __torch_function__ 的函数。如果你 发现自己处于极少数情况下,情况并非如此,例如,如果您将 low-level 库,并且您还需要它适用于 Tensor-like,那么此函数可用。
例子
>>> def dispatcher(a): # Must have the same signature as func ... return (a,) >>> @torch.overrides.wrap_torch_function(dispatcher) >>> def func(a): # This will make func dispatchable by __torch_function__ ... return a + 0