目录

扩展PyTorch

在本说明中,我们将介绍扩展 torch.nn, torch.autograd, torch 的方法,以及利用我们的 C 库编写自定义 C 扩展的方法。

扩展 torch.autograd

autograd 添加操作需要为每个操作实现一个新的 Function 子类。回想一下,函数是 autograd 用来编码操作历史和计算梯度的方式。

何时使用

一般来说,如果你想在模型中执行不可微分的计算或依赖非PyTorch库(例如NumPy)的计算,但仍然希望你的操作能够与其他操作链接并使用自动梯度引擎,那么可以实现一个自定义函数。

在某些情况下,自定义函数也可以用于提高性能和 内存使用:如果你使用C++扩展实现了前向和后向传播, 你可以将它们包装在 Function 中以与自动求导 引擎接口。如果你想减少为反向传播保存的缓冲区数量, 可以使用自定义函数将操作组合在一起。

何时不应使用

如果你已经能够用PyTorch内置的操作来编写你的函数,那么它的反向图(很可能)已经被autograd记录下来了。在这种情况下,你不需要自己实现反向函数。考虑使用普通的Python函数。

如果你需要维护状态,即可训练参数,你应该(也)使用自定义模块。有关扩展 torch.nn 的更多信息,请参见下面的部分。

如果您想在反向传播过程中修改梯度或执行副作用,请考虑注册一个 张量模块钩子。

如何使用

按照以下步骤操作: 1. 继承 Function 并实现 forward()backward() 方法。 2. 在 ctx 参数上调用适当的方法。 3. 声明你的函数是否支持双重反向传播。 4. 使用 gradcheck 验证你的梯度是否正确。

步骤 1: 在继承 Function 之后,你需要定义 2 个方法:

  • forward() 是执行操作的代码。它可以接受任意数量的参数,其中一些可以是可选的,如果你指定了默认值。这里接受各种类型的Python对象。 Tensor 跟踪历史(即,带有 requires_grad=True)的参数将在调用之前转换为不跟踪历史的参数,并且它们的使用将被记录在图中。请注意,这种逻辑不会遍历列表/字典/任何其他数据结构,只会考虑作为调用直接参数的张量。你可以 返回一个单一的 Tensor 输出,或者如果存在多个输出,则返回一个 tuple 的 张量。此外,请参考 Function 的文档,以查找只能从 forward() 调用的有用方法的描述。

  • backward() 定义了梯度公式。它将获得与输出数量相同的 Tensor 参数,每个参数代表相对于该输出的梯度。非常重要的是,千万不要在原地修改这些值。它应该返回与输入数量相同的张量,每个张量包含相对于其对应输入的梯度。如果你的输入不需要梯度(needs_input_grad 是一个布尔元组,表示每个输入是否需要计算梯度),或者不是 Tensor 对象,你可以返回 None。此外,如果你的 forward() 有可选参数,你可以返回比输入更多的梯度,只要它们都是 None

步骤 2: 确保正确使用 forward 中的函数 ctx 以保证新的 Function 能够与 autograd 引擎正常配合工作。

  • save_for_backward() 必须在保存用于后续反向传播的前向传播输入或输出张量时使用。 其他内容,即非张量以及既不是输入也不是输出的张量,应直接存储在 ctx 上。

  • mark_dirty() 必须用于标记任何被前向函数就地修改的输入。

  • mark_non_differentiable() 必须 用于告知引擎输出是否不可微分。默认情况下,所有可微分类型的输出张量都将设置为需要梯度。非可微分类型(即整数类型)的张量永远不会被标记为需要梯度。

  • set_materialize_grads() 可以用于告诉自动梯度计算引擎在输出不依赖于输入的情况下优化梯度计算,方法是不在反向传播函数中生成给定的梯度张量。也就是说,如果设置为False,在Python中将不会将None对象或在C++中的“未定义张量”(对于x.defined()返回False的张量x)转换为填充零的张量,因此你的代码需要像处理填充零的张量一样处理这些对象。此设置的默认值为True。

步骤 3: 如果你的 Function 不支持反向传播两次 你应该通过装饰器显式声明这一点,使用 once_differentiable() 装饰反向函数。有了这个装饰器,尝试 通过你的函数进行两次反向传播将产生错误。 有关双反向传播的更多信息,请参阅我们的双反向传播教程。

步骤4: 建议您使用 torch.autograd.gradcheck() 来检查您的反向函数是否正确计算了前向的梯度,通过使用您的反向函数计算雅可比矩阵,并与使用有限差分法数值计算的雅可比矩阵逐元素进行比较。

示例

下面你可以找到从 Lineartorch.nn 的函数代码,附有额外的注释:

# Inherit from Function
class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

现在,为了更方便地使用这些自定义操作,我们建议为它们的 apply 方法创建别名:

linear = LinearFunction.apply

在这里,我们提供了一个额外的例子,展示了一个由非张量参数化的函数:

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        # ctx is a context object that can be used to stash information
        # for backward computation
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

在这里,我们通过调用 set_materialize_grads(False) 来优化上述示例:

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.set_materialize_grads(False)
        ctx.constant = constant
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # Here we must handle None grad_output tensor. In this case we
        # can skip unnecessary computations and just return None.
        if grad_output is None:
            return None, None

        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

注意

Inputs to backward, i.e., grad_output, can also be tensors that track history. So if backward is implemented with differentiable operations, (e.g., invocation of another custom function), higher order derivatives will work. In this case, the tensors saved with save_for_backward can also be used in the backward and have gradients flowing back but tensors saved in the ctx won’t have gradients flowing back for them. If you need gradients to flow back for a Tensor saved in the ctx, you should make it an output of the custom Function and save it with save_for_backward.

你可能想要检查你实现的反向方法是否确实计算了你的函数的导数。这可以通过与使用小的有限差分进行数值逼近比较来实现:

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)

参见数值梯度检查以了解更多关于有限差分梯度比较的详细信息。 如果你的函数用于高阶导数(区分反向传播), 你可以使用同一包中的gradgradcheck函数来检查高阶导数。

扩展 torch.nn

nn 提供两种类型的接口 - 模块及其函数式版本。你可以通过这两种方式扩展它,但我们建议对于所有包含参数或缓冲区的层使用模块,并建议对无参数的操作(如激活函数、池化等)使用函数式形式。

添加操作的功能版本已经在上面的章节中完全涵盖了。

添加一个 Module

由于 nn 大量使用了 autograd,添加一个新的 Module 需要实现一个 Function 来执行该操作并能够计算梯度。从现在开始,假设我们想要实现一个 Linear 模块,并且我们已经按照上面的代码列表实现了该函数。要添加这个模块所需的代码非常少。现在,需要实现两个函数:

  • __init__ (可选) - 接受诸如内核大小、特征数量等参数,并初始化参数和缓冲区。

  • forward() - 实例化一个 Function 并 使用它来执行操作。这与上面显示的功能包装器非常相似。

这是如何实现一个Linear模块的方法:

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

扩展 torch

你可以通过定义一个自定义类,其方法与Tensor匹配,来创建自定义类型以模拟Tensor。但如果你希望将这些类型传递给顶级torch命名空间中的函数,如torch.add(),这些函数接受Tensor操作数,又该怎么办呢?

如果你的自定义Python类型定义了一个名为 __torch_function__ 的方法,当你的自定义类的一个实例传递给 torch 命名空间中的函数时,PyTorch 将调用你的 __torch_function__ 实现。这使得你可以为 torch 命名空间中的任何函数定义自定义实现,而你的 __torch_function__ 实现可以调用这些函数,从而使你的用户能够使用他们已经为 Tensor 编写的现有 PyTorch 工作流程。这不仅适用于与 Tensor 无关的“鸭子”类型,也适用于 Tensor 的用户自定义子类。

扩展 torch 以支持类似 Tensor 的类型

注意

此功能受NumPy __array_function__ 协议的启发。 详见 NumPy 文档NEP-0018 以获取更多详情。

为了具体说明这一点,让我们从一个简单的例子开始,该例子说明了API调度机制。我们将创建一个自定义类型,表示一个二维标量张量,由阶数N和对角线元素的值value参数化:

class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value

   def __repr__(self):
       return "DiagonalTensor(N={}, value={})".format(self._N, self._value)

   def tensor(self):
       return self._value * torch.eye(self._N)

这个设计的第一版并不是很有用。ScalarTensor 的主要功能是提供一个比基础张量类更紧凑的标量张量字符串表示:

>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]])

如果我们尝试使用这个对象与torch API,我们将遇到问题:

>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor

__torch_function__ 添加一个实现到 ScalarTensor 使得上述操作能够成功。让我们重新进行我们的实现,这次添加一个 __torch_function__ 实现:

HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "DiagonalTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

The __torch_function__ 方法接受四个参数:func,一个指向正在被重写的torch API函数的引用,types,实现 __torch_function__ 的Tensor-like类型的列表,args,传递给函数的参数元组,以及 kwargs,传递给函数的关键字参数字典。它使用一个名为 HANDLED_FUNCTIONS 的全局调度表来存储自定义实现。该字典的键是 torch 命名空间中的函数,值是 ScalarTensor 的实现。

注意

使用全局调度表并不是__torch_function__ API的强制性部分,它只是一个有用的用于组织你的重写实现的设计模式。

这个类定义还不足以使 torch.mean 在我们传递一个 ScalarTensor 时做正确的事情 - 我们还需要为 torch.mean 定义一个实现,用于 ScalarTensor 操作数,并将该实现添加到 HANDLED_FUNCTIONS 分派表字典中。一种方法是定义一个装饰器:

import functools
def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

这可以应用于我们重写实现中:

@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N

通过这个更改,我们现在可以使用 torch.meanScalarTensor

>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4

当然,torch.mean 是一个最简单的函数覆盖示例,因为它只接受一个操作数。我们可以使用相同的机制来覆盖接受多个操作数的函数,其中任何一个都可能是定义 __torch_function__ 的张量或类似张量,例如对于 torch.add()

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))

此版本在两个操作数都是ScalarTensor实例时有一个快速路径,并且还有一个较慢的路径,当任一操作数不是ScalarTensor时会退化为将数据转换为张量。这使得覆盖函数在任一操作数是ScalarTensor或普通Tensor时能够正确工作:

>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
DiagonalTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
        [1., 3.]])

请注意,我们对add的实现不接受alphaout作为关键字参数,而torch.add()则接受:

>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'

为了速度和灵活性,__torch_function__ 分发机制不会检查覆盖函数的签名是否与 torch API 中被覆盖函数的签名匹配。对于某些应用来说,忽略可选参数可能是可以接受的,但为了确保与 Tensor 的完全兼容性,用户实现的 torch API 函数应仔细模拟被覆盖函数的 API。

Functions in the torch API that do not have explicit overrides will return NotImplemented from __torch_function__. If all operands with __torch_function__ defined on them return NotImplemented, PyTorch will raise a TypeError. This means that most of the time operations that do not have explicit overrides for a type will raise a TypeError when an instance of such a type is passed:

>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]

实际上这意味着,如果你想使用类似这样的__torch_function__实现来实现你的覆盖,你需要显式地实现完整的torch API 或者对你用例相关的 API 子集。这可能是一个艰巨的任务,因为完整的torch API 非常广泛。

另一个选择是,对于未处理的操作,不返回 NotImplemented,而是当没有可用覆盖时,将 Tensor 传递给原始的 torch 函数。例如,如果我们更改 __torch_function__ScalarTensor 实现为如下所示:

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
    return HANDLED_FUNCTIONS[func](*args, **kwargs)

然后 torch.mul() 将会正确工作,尽管返回类型将始终是 Tensor 而不是 ScalarTensor,即使两个操作数都是 ScalarTensor 实例:

>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
        [0., 4.]])

也请参见下面的 MetadataTensor 示例,这是该模式的另一种变体,但始终返回一个 MetadataTensor 以通过 torch API 中的操作传播元数据。

__torch_function__ 协议旨在全面覆盖 API, 部分覆盖可能导致不理想的结果,特别是某些 函数可能会抛出一个 TypeError。对于子类来说尤其如此, 其中 torch.addtorch.Tensor.__add__torch.Tensor.add 都必须被覆盖,即使它们返回完全相同的结果。未能做到这一点 也可能导致无限递归。如果需要实现来自 torch.Tensor 子类的函数, 则必须在其实现中使用 super().__torch_function__

继承 torch.Tensor

从1.7.0版本开始,在torch.Tensor上的方法和在公共torch.*命名空间中的函数应用于torch.Tensor子类时,将返回子类实例而不是torch.Tensor实例:

>>> class SubTensor(torch.Tensor):
...     pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'

如果存在多个子类,默认会选择层次结构中最低的那个。如果没有唯一的方法来确定这种情况,则会引发一个TypeError

>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]

如果希望为所有张量方法设置全局覆盖,可以使用 __torch_function__。以下是一个记录所有函数/方法调用的示例:

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
        if func is not torch.Tensor.__repr__:
            logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

然而,如果希望覆盖Tensor子类上的方法, 则可以通过直接覆盖该方法(通过为子类定义它)或使用__torch_function__并匹配 func来实现。

__torch_function__ 内,子类应始终调用 super().__torch_function__(func, ...) 而不是直接调用 func, 就像在 1.7.0 版本之前的情况一样。未能做到这一点可能会导致 func 递归回到 __torch_function__,从而导致无限递归。

扩展 torchTensor 包装类型

另一个有用的案例是包装一个 Tensor 的类型,无论是作为属性还是通过子类化。下面我们将实现这种类型的特例,即一个 MetadataTensor,它会将一个元数据字典附加到 Tensor 上,并通过 torch 操作进行传播。由于这是一种对完整 torch API 的通用包装方式,我们不需要单独实现每个重写,因此我们可以使 __torch_function__ 的实现对允许的操作更加宽松:

class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        args = [a._t if hasattr(a, '_t') else a for a in args]
        metadatas = tuple(a._metadata if hasattr(a, '_metadata') for a in args)
        assert len(metadatas) > 0
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])

这个简单的实现不一定能与torch API中的每个函数一起工作,但它足以捕捉大多数常见操作:

>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[2, 4],
        [4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[1, 4],
        [3, 8]])

在多种类型上定义的操作 __torch_function__

可以使用torch API与多种不同的类型,每种类型都有一个__torch_function__实现,但需要特别注意。在这种情况下,规则是:

  • 调度操作收集每个操作数的所有不同实现的 __torch_function__ 并按顺序调用它们:先调用子类再调用父类,否则在运算符表达式中从左到右调用。

  • 如果返回的值不是NotImplemented,则该值将作为结果返回。实现可以注册它们不支持某个操作,通过返回NotImplemented

  • 如果所有 __torch_function__ 个实现都返回 NotImplemented,PyTorch 将引发一个 TypeError

测试PyTorch API覆盖的重写

实现__torch_function__的一个麻烦之处是,如果某些操作有覆盖而其他操作没有,则用户在最坏的情况下会在运行时看到错误,当他们使用没有覆盖的函数时。为了简化这个过程,PyTorch提供了一个面向开发者的API,以确保对__torch_function__覆盖的全面支持。此API是私有的,并且将来可能会在没有任何警告的情况下进行更改。

首先,要获取所有可覆盖函数的列表,请使用 torch.overrides._get_overridable_functions。这将返回一个字典,其键是 PyTorch Python API 中的命名空间,其值是该命名空间中可以被覆盖的函数列表。例如,让我们打印出 torch.nn.functional 中前 5 个可被覆盖的函数名称:

>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
 'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']

这个函数列表使得可以迭代所有可重写的功能,但在实际操作中,这还不足以在不费力地手动复制每个测试的每个函数签名的情况下为所有这些功能编写测试。为了简化这个过程,torch.overrides._get_testing_overrides 函数返回一个字典,将 PyTorch API 中的可重写函数映射到具有与原始函数相同签名但无条件返回 -1 的虚拟 lambda 函数。这些函数最适用于与 inspect 一起使用,以分析原始 PyTorch 函数的函数签名:

>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>

最后,torch.overrides.get_ignored_functions 返回一个函数元组, 这些函数明确不能被 __torch_function__ 覆盖。这个列表可以 用来确认那些不在 get_overridable_functions 返回的字典中的函数 无法被覆盖。

编写自定义C++扩展

请参阅这个 PyTorch教程 以获取详细说明和示例。

文档可在 torch.utils.cpp_extension 中找到。

编写自定义C扩展

示例可在 此GitHub仓库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源