目录

torch.overrides

此模块公开了__torch_function__协议的各种辅助函数。 详见扩展 torch Python API以获取更多关于__torch_function__协议的详细信息。

功能

torch.overrides.get_ignored_functions()[source]

返回不能被__torch_function__覆盖的公共函数。

Returns

一个函数元组,这些函数在torch API中公开可用但不能被__torch_function__覆盖。主要是因为这些函数的参数都不是张量或类似的张量。

Return type

Set[Callable]

示例

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()[source]

列出可通过 __torch_function__ 覆盖的函数。

Returns

一个字典,将包含可覆盖函数的命名空间映射到该命名空间中可以被覆盖的函数。

Return type

字典[任何, 列表[可调用]]

torch.overrides.resolve_name(f)[source]

获取传递给 __torch_function__ 的函数的人可读字符串名称。

Parameters

f (可调用对象) – 解析名称的函数。

Returns

函数名称;如果进行求值,应返回输入函数。

Return type

字符串

torch.overrides.get_testing_overrides()[source]

返回一个字典,其中包含所有可覆盖函数的模拟覆盖项。

Returns

一个字典,它将可覆盖的PyTorch API函数映射到具有与真实函数相同签名的lambda函数,并无条件返回-1。这些lambda函数对于测试定义了__torch_function__类型的API覆盖率很有用。

Return type

字典[可调用对象,可调用对象]

示例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)[source]

实现一个带有__torch_function__重载检查的函数。

参见 torch::autograd::handle_torch_function 以查看此函数在 C++ 实现中的等效实现。

Parameters
  • public_api (函数) – 由公共 torch API 提供的函数,最初被调用为 public_api(*args, **kwargs),现在正在检查其参数。

  • 相关参数 (可迭代对象) – 需要检查__torch_function__方法的参数可迭代对象。

  • args (元组) – 传递给 public_api 的任意位置参数。

  • kwargs (元组) – 传递给 public_api 的任意关键字参数。

Returns

调用implementation或适当情况下调用__torch_function__方法的结果。

Return type

对象

引发 TypeError 异常:如果没有找到实现。

示例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0
torch.overrides.has_torch_function()

检查可迭代对象的元素中是否存在__torch_function__实现,或者是否启用了__torch_function__模式。认为精确的Tensors和Parameters不可分派。使用此方法来保护对handle_torch_function()的调用;不要使用它来测试某个东西是否为张量类似,而是使用is_tensor_like()。 :param relevant_args: 可迭代对象或要检查的__torch_function__方法的参数。 :type relevant_args: 迭代器

Returns

如果相关参数 relevant_args 中有任何元素实现了 __torch_function__,则为真,否则为假。

Return type

布尔

另请参见

torch.is_tensor_like

检查某个东西是否为张量类似项,包括一个确切的 Tensor

torch.overrides.is_tensor_like(inp)[source]

返回 True 如果传入的输入是一个张量类似对象。

目前,这会在输入类型上存在__torch_function__属性时发生。

示例

张量的一个子类通常是一个类似于张量的对象。

>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

内置或用户定义的类型通常不是张量类型的。

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,通过实现__torch_function__,它们可以被制成类似于Tensor的形式。

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)[source]

如果传入的函数是一个属于torch.Tensor的方法或属性的处理程序,则返回True,如传递给__torch_function__

注意

对于属性,必须传入其 __get__ 方法。

这可能出于以下几个原因特别需要:

  1. 方法/属性有时不包含__module__槽。

  2. 它们要求传入的第一个参数是torch.Tensor的一个实例。

示例

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
Return type

布尔

torch.overrides.wrap_torch_function(dispatcher)[source]

用与__torch_function__相关的功能包装给定的函数。

Parameters

调度程序 (可调用对象) – 返回张量样式的可迭代对象的可调用对象,这些对象被传递到函数中。

注意

此装饰器可能会降低代码的性能。通常情况下,将你的代码表示为一系列支持 __torch_function__ 的函数就足够了。如果你发现自己处于罕见的情况,例如,你在封装一个底层库并且也需要它对张量类似物起作用,则可以使用该函数。

示例

>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源