目录

torch.overrides

该模块为协议公开了各种辅助函数。有关协议的更多详细信息,请参阅扩展 torch Python API__torch_function____torch_function__

功能

torch.overrides 的get_ignored_functions[来源]

返回不能被 覆盖的公共函数。__torch_function__

返回

在 torch API 中公开可用但不能 被 覆盖。主要是因为 这些函数的参数是 Tensors 或 Tensor-likes。__torch_function__

返回类型

Set[可调用]

例子

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides 的get_overridable_functions[来源]

列出可通过 __torch_function__ 覆盖的函数

返回

映射包含可重写函数的命名空间的字典 添加到该命名空间中可以覆盖的函数。

返回类型

dict[任意, 列表[可调用]]

torch.overrides 的resolve_namef[来源]

获取传递给 __torch_function__

参数

fCallable) – 要解析其名称的函数。

返回

函数的名称;如果 evaled,它应该返回输入 功能。

返回类型

str

torch.overrides 的get_testing_overrides[来源]

返回一个包含所有可重写函数的虚拟重写的 dict

返回

将 PyTorch API 中的可覆盖函数映射到 与实际函数具有相同签名的 Lambda 函数 并无条件返回 -1。这些 lambda 函数非常有用 用于测试定义 .__torch_function__

返回类型

dict[可调用,可调用]

例子

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides 的handle_torch_functionpublic_apirelevant_args*args**kwargs[来源]

实现一个检查覆盖的函数。__torch_function__

参见 torch::autograd::handle_torch_function 了解这个 函数C++实现中。

参数
  • public_apifunction) – 最初调用的公共 torch API 公开的函数,类似于现在正在调用的参数 检查。public_api(*args, **kwargs)

  • relevant_argsiterable) – 要检查__torch_function__方法的参数的可迭代对象。

  • argstuple) – 最初传入的任意位置参数。public_api

  • kwargstuple) – 最初传入的任意关键字参数。public_api

返回

调用或方法的结果(如果适用)。implementation__torch_function__

返回类型

对象

:raises TypeError : 如果未找到实现。

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides 的has_torch_function()

检查 iterable 元素中的 __torch_function__ 实现 或者启用了 __torch_function__ 模式。考虑精确 和 s 不可调度。使用 this 来保护对 ;不要用它来测试某 类似 Tensor,请改用。 :p aram relevant_args:要检查__torch_function__方法的可迭代对象或参数。 :type relevant_args: 可迭代TensorParameter

返回

如果 relevant_args 的任何元素具有 __torch_function__,则为 True implementations,否则为 False。

返回类型

布尔

另请参阅

torch.is_tensor_like

检查某物是否为 Tensor-like,包括精确的 .Tensor

torch.overrides 的is_tensor_likeinp[来源]

如果传入的输入是类似 Tensor 的,则返回。True

目前,只要输入类型有属性,就会发生这种情况。__torch_function__

例子

tensor 的子类通常是类似 Tensor 的。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

内置类型或用户类型通常不是类似 Tensor 的。

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,它们可以通过实现 __torch_function__ 来变得类似 Tensor。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides 的is_tensor_method_or_propertyfunc[来源]

如果传入的函数是 方法或属性属于 ,如传递的那样 到。torch.Tensor__torch_function__

注意

对于属性,必须传入其方法。__get__

这可能是必要的,特别是出于以下原因:

  1. 方法/属性有时不包含__module__槽。

  2. 它们要求第一个传入的参数是一个实例 之。torch.Tensor

例子

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
返回类型

布尔

torch.overrides 的wrap_torch_function调度员[来源]

使用 相关功能包装给定函数。__torch_function__

参数

dispatcherCallable) – 一个可调用对象,它返回传递到函数中的 Tensor 类可迭代对象。

注意

此装饰器可能会降低代码的性能。一般来说,表达 您的代码是一系列本身支持 __torch_function__ 的函数。如果你 发现自己处于极少数情况下,情况并非如此,例如,如果您将 low-level 库,并且您还需要它适用于 Tensor-like,那么此函数可用。

例子

>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源