目录

扩展PyTorch

在本笔记中,我们将介绍扩展 torch.nn, torch.autograd, torch 以及编写自定义 C++ 扩展的方法。

添加新运算符

PyTorch 提供了一个大型的运算符库,这些运算符可以在张量上工作(例如 torch.add(), torch.sum() 等)。然而,你可能希望将一个新的自定义操作引入 PyTorch, 并使其像 PyTorch 的内置运算符一样工作。为了实现这一点,你必须 通过 Python torch.library 或 C++ TORCH_LIBRARY APIs 将自定义操作注册到 PyTorch。

请参阅PyTorch 自定义操作符登录页面以获取更多详细信息。

扩展 torch.autograd

autograd 添加操作需要为每个操作实现一个新的 Function 子类。回想一下,函数是 autograd 用来编码操作历史和计算梯度的方式。

本文档的第一部分专注于反向模式自动微分,因为这是最广泛使用的功能。在最后的部分讨论了前向模式自动微分的扩展。

何时使用

一般来说,如果你想在模型中执行不可微分的计算或依赖非PyTorch库(例如NumPy)的计算,但仍然希望你的操作能够与其他操作链接并使用自动梯度引擎,那么可以实现一个自定义函数。

在某些情况下,自定义函数也可以用于提高性能和 内存使用:如果你使用C++扩展实现了前向和后向传播, 你可以将它们包装在 Function 中以与自动求导 引擎接口。如果你想减少为反向传播保存的缓冲区数量, 可以使用自定义函数将操作组合在一起。

何时不应使用

如果你已经能够用PyTorch内置的操作来编写你的函数,那么它的反向图(很可能)已经被autograd记录下来了。在这种情况下,你不需要自己实现反向函数。考虑使用普通的Python函数。

如果你需要维护状态,即,可训练参数,你应该(也)使用一个自定义模块。有关扩展 torch.nn 的更多信息,请参见下面的部分。

如果您想在反向传播过程中修改梯度或执行副作用,请考虑注册一个 张量模块钩子。

如何使用

请按照以下步骤操作: 1. 继承Function并实现forward()方法, (可选)setup_context()backward()方法。 2. 在ctx参数上调用适当的方法。 3. 声明您的函数是否支持 反向传播两次。 4. 使用gradcheck验证梯度是否正确。

步骤1: 在继承 Function 后,您需要定义3个方法:

  • forward() 是执行操作的代码。它可以接受任意数量的参数,其中一些可以是可选的,如果你指定了默认值。这里接受各种类型的Python对象。 Tensor 跟踪历史(即,带有 requires_grad=True)的参数将在调用之前转换为不跟踪历史的参数,并且它们的使用将被记录在图中。请注意,这种逻辑不会遍历列表/字典/任何其他数据结构,只会考虑作为调用直接参数的张量。你可以 返回一个单一的 Tensor 输出,或者如果存在多个输出,则返回一个 tuple 的 张量。此外,请参考 Function 的文档,以查找只能从 forward() 调用的有用方法的描述。

  • setup_context() (可选)。可以编写一个“组合”的 forward(),它接受一个 ctx 对象,或者(从 PyTorch 2.0 开始)编写一个不接受 ctx 的独立 forward() 和一个 setup_context() 方法,在该方法中进行 ctx 修改。 forward() 应包含计算逻辑,而 setup_context() 只应负责 ctx 修改(并且不应包含任何计算逻辑)。 通常,独立的 forward()setup_context() 更接近于 PyTorch 原生操作的工作方式,因此与各种 PyTorch 子系统更易于组合。 有关更多详细信息,请参阅 组合或分离 forward() 和 setup_context()

  • backward()(或vjp())定义了梯度公式。 它将接收与输出数量相同的Tensor参数,每个参数代表相对于该输出的梯度。重要的是绝对不要原地修改这些参数。它应该返回与输入数量相同的张量,每个张量包含相对于其对应输入的梯度。如果你的输入不需要计算梯度(needs_input_grad是一个布尔值元组,表示每个输入是否需要梯度计算),或者它们是非Tensor对象,你可以返回python:None。此外,如果你有可选参数传递给forward(),你可以返回比输入更多的梯度,只要它们都是None

步骤2: 您有责任正确使用ctx中的函数,以确保新的Function与自动梯度引擎正常工作。

  • save_for_backward() 必须用于保存将在反向传播中使用的任何张量。非张量应直接存储在ctx上。如果保存了既不是输入也不是输出的张量以备反向传播,那么您的Function可能不支持二次反向传播(参见步骤3)。

  • mark_dirty() 必须用于标记任何被前向函数就地修改的输入。

  • mark_non_differentiable() 必须 用于告知引擎输出是否不可微分。默认情况下,所有可微分类型的输出张量都将设置为需要梯度。非可微分类型(即整数类型)的张量永远不会被标记为需要梯度。

  • set_materialize_grads() 可以用于告诉自动梯度计算引擎在输出不依赖于输入的情况下优化梯度计算,方法是不在反向传播函数中生成给定的梯度张量。也就是说,如果设置为False,在Python中将不会将None对象或在C++中的“未定义张量”(对于x.defined()返回False的张量x)转换为填充零的张量,因此你的代码需要像处理填充零的张量一样处理这些对象。此设置的默认值为True。

步骤3: 如果您的 Function 不支持双重反向传播, 您应该通过使用 once_differentiable() 装饰器显式声明这一点。使用此装饰器后,尝试通过您的函数进行双重反向传播将产生错误。 有关双重反向传播的更多信息,请参阅我们的双重反向传播教程。

步骤4: 建议您使用 torch.autograd.gradcheck() 来检查您的反向函数是否正确计算了前向的梯度,通过使用您的反向函数计算雅可比矩阵,并与使用有限差分法数值计算的雅可比矩阵逐元素进行比较。

示例

在下面你可以找到一个Linear函数的代码,带有 额外的注释:

# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

现在,为了更方便地使用这些自定义操作,我们建议要么为它们创建别名,要么将它们封装在函数中。将它们封装在函数中可以让我们支持默认参数和关键字参数:

# Option 1: alias
linear = LinearFunction.apply

# Option 2: wrap in a function, to support default args and keyword args.
def linear(input, weight, bias=None):
    return LinearFunction.apply(input, weight, bias)

在这里,我们提供了一个额外的例子,展示了一个由非张量参数化的函数:

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx is a context object that can be used to stash information
        # for backward computation
        tensor, constant = inputs
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

在这里,我们通过调用 set_materialize_grads(False) 来优化上述示例:

class MulConstant(Function):
    @staticmethod
    def forward(tensor, constant):
        return tensor * constant

    @staticmethod
    def setup_context(ctx, inputs, output):
        tensor, constant = inputs
        ctx.set_materialize_grads(False)
        ctx.constant = constant

    @staticmethod
    def backward(ctx, grad_output):
        # Here we must handle None grad_output tensor. In this case we
        # can skip unnecessary computations and just return None.
        if grad_output is None:
            return None, None

        # We return as many input gradients as there were arguments.
        # Gradients of non-Tensor arguments to forward must be None.
        return grad_output * ctx.constant, None

如果您需要保存在forward()中计算的任何“中间”张量, 则必须将它们作为输出返回,或者结合使用forwardsetup_context() (请参阅组合或单独的 forward() 和 setup_context())。 请注意,这意味着如果您希望梯度通过这些中间值流动, 您需要为它们定义梯度公式(另请参阅 双重反向教程 ):

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        return result

# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

注意

Inputs to backward, 即,grad_output, 也可以是跟踪历史的张量。因此,如果 backward 是用可微分操作实现的(例如,调用另一个自定义 Function),高阶导数将起作用。 在这种情况下,使用 save_for_backward 保存的张量也可以用于反向传播,并且梯度可以回流,但在 ctx 中保存的张量不会有梯度回流。 如果你需要在 ctx 中保存的张量有梯度回流,你应该 将其作为自定义 Function 的输出,并使用 save_for_backward 保存它。

你可能想要检查你实现的反向方法是否确实计算了你的函数的导数。这可以通过与使用小的有限差分进行数值逼近比较来实现:

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
print(test)

参见数值梯度检查以了解更多关于有限差分梯度比较的详细信息。 如果你的函数用于高阶导数(区分反向传播), 你可以使用同一包中的gradgradcheck函数来检查高阶导数。

组合或分开 forward()setup_context()

定义 Function 主要有两种方式。要么:

  • 定义一个 forward() 将前向计算逻辑与 setup_context() 结合起来

  • (自PyTorch 2.0起) 定义一个单独的 forward()setup_context()

我们推荐第二种选项(分别使用 forward()setup_context()) 因为这更接近于PyTorch原生操作的实现方式,并且可以与 torch.func 转换组合。然而,我们计划在未来支持这两种方法; 将 forward()setup_context(): 结合起来会带来更多的灵活性,因为 你可以保存中间结果而不需要将它们作为输出返回。

请参阅上一节了解如何定义 Function 并分别使用 forward()setup_context()

这是一个如何定义一个 Function 并结合 forward()setup_context() 的示例:

class LinearFunction(Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias=None):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

前向模式自动微分

覆盖前向模式自动微分公式具有非常相似的API,但有一些不同的细微差别。 你可以实现 jvp() 函数。

它将接收与输入数量相同的 Tensor 个参数,每个参数表示相对于该输入的梯度。它应该返回与输出数量相同的张量,每个张量包含相对于其对应输出的梯度。 jvp() 将在 forward() 方法之后立即被调用,在 apply() 返回之前。

jvp()backward() 函数有一些细微的差别:

  • 您可以使用 ctx 将任何数据从 forward() 传递到 jvp() 函数。 如果该状态对于 backward() 来说不需要, 您可以通过在 jvp() 函数末尾执行 del ctx.foo 来显式释放它。

  • Pytorch深度学习框架的jvp()实现必须是可反向微分的,或者明确检查给定的前向模式梯度中没有设置requires_grad

  • jvp() 函数必须与 forward() 的视图/原地行为相匹配。 例如,如果第 i 个输入被原地修改,则第 i 个梯度必须被原地更新。 同样,如果第 j 个输出是第 k 个输入的视图。那么返回的第 j 个输出梯度必须是给定的第 k 个输入梯度的视图。

  • 由于用户无法指定需要计算哪个梯度,因此jvp()函数应该 始终为所有输出计算梯度。

  • 正向模式梯度确实会遵循由set_materialize_grads() 设置的标志,并且在禁用此功能时,你可以获得None个输入梯度。

torch.func 转换和/或 torch.vmap()

请参阅 使用 autograd.Function 扩展 torch.func 以获取详细信息。

扩展 torch.nn

nn 提供了两种类型的接口 - 模块及其函数版本。你可以通过这两种方式扩展它,但我们建议在所有类型的层中使用模块,这些层包含任何参数或缓冲区,并建议在没有参数的操作中使用函数形式,如激活函数、池化等。

添加操作的功能版本已经在上面的章节中完全涵盖了。

添加一个 Module

自从 nn 大量使用了 autograd,添加一个新的 Module 需要实现一个 Function 来执行操作并能够计算梯度。从现在开始让我们假设我们要实现一个 Linear 模块,并且我们已经实现了如上列表中的函数。需要添加的代码非常少。现在,需要实现两个函数:

  • __init__ (可选) - 接受诸如内核大小、特征数量等参数,并初始化参数和缓冲区。

  • forward() - 实例化一个 Function 并 使用它来执行操作。这与上面显示的功能包装器非常相似。

这是如何实现一个Linear模块的方法:

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Tensor, that will get
        # automatically registered as Module's parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        # (Optional)Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

扩展 torch Python API

你可以创建自定义类型来模拟 Tensor,方法是定义一个具有与 Tensor 匹配的方法的自定义类。但是,如果你想能够将这些类型传递给顶级命名空间中的函数,如 torch.add(),该函数接受 Tensor 操作数怎么办?

如果您的自定义 Python 类型定义了一个名为 __torch_function__ 的方法,当您的自定义类实例传递给 torch 命名空间中的函数时,PyTorch 将调用您的 __torch_function__ 实现。这使得您可以为 torch 命名空间中的任何函数定义自定义实现,您的 __torch_function__ 实现可以调用这些函数,从而使您的用户能够将您的自定义类型与现有的 PyTorch 工作流一起使用,这些工作流是为 Tensor 编写的。这种方法不仅适用于与 Tensor 无关的“鸭子”类型,也适用于 Tensor 的用户定义子类。

扩展 torch 使用类似 Tensor 的类型

注意

此功能受NumPy __array_function__ 协议的启发。 详见 NumPy 文档NEP-0018 以获取更多详情。

为了具体说明这一点,让我们从一个简单的例子开始,该例子说明了API调度机制。我们将创建一个自定义类型,表示一个二维标量张量,由阶数N和对角线元素的值value参数化:

class ScalarTensor(object):
   def __init__(self, N, value):
       self._N = N
       self._value = value

   def __repr__(self):
       return "ScalarTensor(N={}, value={})".format(self._N, self._value)

   def tensor(self):
       return self._value * torch.eye(self._N)

这个设计的第一版并不是很有用。ScalarTensor 的主要功能是提供一个比基础张量类更紧凑的标量张量字符串表示:

>>> d = ScalarTensor(5, 2)
>>> d
ScalarTensor(N=5, value=2)
>>> d.tensor()
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]])

如果我们尝试使用此对象与 torch API,我们将遇到问题:

>>> import torch
>>> torch.mean(d)
TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor

__torch_function__ 添加一个实现到 ScalarTensor 使得上述操作能够成功。让我们重新进行我们的实现,这次添加一个 __torch_function__ 实现:

HANDLED_FUNCTIONS = {}
class ScalarTensor(object):
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def __repr__(self):
        return "ScalarTensor(N={}, value={})".format(self._N, self._value)

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

The __torch_function__ 方法接受四个参数:func,一个指向正在被重写的torch API函数的引用,types,实现 __torch_function__ 的Tensor-like类型的列表,args,传递给函数的参数元组,以及 kwargs,传递给函数的关键字参数字典。它使用一个名为 HANDLED_FUNCTIONS 的全局调度表来存储自定义实现。该字典的键是 torch 命名空间中的函数,值是 ScalarTensor 的实现。

注意

使用全局调度表并不是__torch_function__ API的强制性部分,它只是一个有用的用于组织你的重写实现的设计模式。

这个类定义还不足以使 torch.mean 在我们传递一个 ScalarTensor 时做正确的事情 - 我们还需要为 torch.mean 定义一个实现,用于 ScalarTensor 操作数,并将该实现添加到 HANDLED_FUNCTIONS 分派表字典中。一种方法是定义一个装饰器:

import functools
def implements(torch_function):
    """Register a torch function override for ScalarTensor"""
    def decorator(func):
        functools.update_wrapper(func, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

这可以应用于我们重写实现中:

@implements(torch.mean)
def mean(input):
    return float(input._value) / input._N

通过这个更改,我们现在可以使用 torch.meanScalarTensor

>>> d = ScalarTensor(5, 2)
>>> torch.mean(d)
0.4

当然,torch.mean 是一个最简单的函数覆盖示例,因为它只接受一个操作数。我们可以使用相同的机制来覆盖接受多个操作数的函数,其中任何一个都可能是定义 __torch_function__ 的张量或类似张量,例如对于 torch.add()

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

@implements(torch.add)
def add(input, other):
   try:
       if input._N == other._N:
           return ScalarTensor(input._N, input._value + other._value)
       else:
           raise ValueError("Shape mismatch!")
   except AttributeError:
       return torch.add(ensure_tensor(input), ensure_tensor(other))

此版本在两个操作数都是ScalarTensor实例时有一个快速路径,并且还有一个较慢的路径,当任一操作数不是ScalarTensor时会退化为将数据转换为张量。这使得覆盖函数在任一操作数是ScalarTensor或普通Tensor时能够正确工作:

>>> s = ScalarTensor(2, 2)
>>> torch.add(s, s)
ScalarTensor(N=2, value=4)
>>> t = torch.tensor([[1, 1,], [1, 1]])
>>> torch.add(s, t)
tensor([[3., 1.],
        [1., 3.]])

请注意,我们对add的实现不接受alphaout作为关键字参数,而torch.add()则接受:

>>> torch.add(s, s, alpha=2)
TypeError: add() got an unexpected keyword argument 'alpha'

为了速度和灵活性,__torch_function__ 分发机制不会检查覆盖函数的签名是否与在 torch API 中被覆盖的函数的签名匹配。对于某些应用来说,忽略可选参数是可以接受的,但为了确保与 Tensor 的完全兼容性,用户实现的 torch API 函数应该注意精确模仿被覆盖函数的 API。

torch API 中没有显式重写的函数将 从 __torch_function__ 返回 NotImplemented。如果所有定义了 __torch_function__ 的操作数都返回 NotImplemented,PyTorch 将 抛出一个 TypeError。这意味着大多数情况下,对于没有 为类型显式重写的操作,在传递该类型的实例时将抛出一个 TypeError

>>> torch.mul(s, 3)
TypeError: no implementation found for 'torch.mul' on types that
implement __torch_function__: [ScalarTensor]

在实际操作中,这意味着如果你想使用类似于__torch_function__的实现来实现你的重写,你需要明确地实现完整的torch API或你关心的API子集以满足你的使用场景。这可能是一个艰巨的任务,因为完整的torch API相当广泛。

另一种选择是对于未处理的操作不返回 NotImplemented,而是当没有可用的重写时,将一个 Tensor 传递给原始的 torch 函数。例如,如果我们更改 __torch_function__ 对于 ScalarTensor 的实现为以下内容:

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    if kwargs is None:
        kwargs = {}
    if func not in HANDLED_FUNCTIONS or not all(
            issubclass(t, (torch.Tensor, ScalarTensor))
            for t in types
        ):
        args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
        return func(*args, **kwargs)
    return HANDLED_FUNCTIONS[func](*args, **kwargs)

然后 torch.mul() 将会正确工作,尽管返回类型将始终是 Tensor 而不是 ScalarTensor,即使两个操作数都是 ScalarTensor 实例:

>>> s = ScalarTensor(2, 2)
>>> torch.mul(s, s)
tensor([[4., 0.],
        [0., 4.]])

还可以参见下面的 MetadataTensor 示例,了解这种模式的另一种变化,但始终返回一个 MetadataTensor 以通过操作传播元数据到 torch API 中。

__torch_function__ 协议旨在全面覆盖 API, 部分覆盖可能导致不理想的结果,特别是某些 函数可能会抛出一个 TypeError。对于子类来说尤其如此, 其中 torch.addtorch.Tensor.__add__torch.Tensor.add 都必须被覆盖,即使它们返回完全相同的结果。未能做到这一点 也可能导致无限递归。如果需要实现来自 torch.Tensor 子类的函数, 则必须在其实现中使用 super().__torch_function__

继承 torch.Tensor

从1.7.0版本开始,在torch.Tensor上的方法和在公共torch.*命名空间中的函数应用于torch.Tensor子类时,将返回子类实例而不是torch.Tensor实例:

>>> class SubTensor(torch.Tensor):
...     pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.tensor([1]))).__name__
'SubTensor'

如果存在多个子类,默认会选择层次结构中最低的那个。如果没有唯一的方法来确定这种情况,则会引发一个TypeError

>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]

如果希望为所有张量方法设置全局覆盖,可以使用 __torch_function__。以下是一个记录所有函数/方法调用的示例:

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
        if func is not torch.Tensor.__repr__:
            logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

然而,如果希望覆盖Tensor子类上的方法, 则可以通过直接覆盖该方法(通过为子类定义它)或使用__torch_function__并匹配 func来实现。

__torch_function__ 内,子类应始终调用 super().__torch_function__(func, ...) 而不是直接调用 func, 就像在 1.7.0 版本之前的情况一样。未能做到这一点可能会导致 func 递归回到 __torch_function__,从而导致无限递归。

扩展 torch 使用一个 Tensor 包装器类型

另一个有用的案例是包装一个 Tensor 的类型,无论是作为属性还是通过子类化。下面我们将实现这种类型的特殊情况,即一个 MetadataTensor ,它将一个元数据字典附加到一个 Tensor 上,并在 torch 操作中传播。由于这是对整个 torch API 的通用包装,我们不需要单独实现每个重写,因此我们可以使 __torch_function__ 实现对允许的操作更加宽容:

class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        metadatas = tuple(a._metadata for a in args if hasattr(a, '_metadata'))
        args = [getattr(a, '_t', a) for a in args]
        assert len(metadatas) > 0
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])

这个简单的实现不一定适用于 torch API 中的每个函数,但它足以捕捉大多数常见的操作:

>>> metadata = {'owner': 'Ministry of Silly Walks'}
>>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
>>> t = torch.tensor([[1, 2], [1, 2]])
>>> torch.add(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[2, 4],
        [4, 6]])
>>> torch.mul(t, m)
Metadata:
{'owner': 'Ministry of Silly Walks'}

data:
tensor([[1, 4],
        [3, 8]])

在多种类型上定义的操作 __torch_function__

可以使用torch API与多种不同的类型,每种类型都有一个__torch_function__实现,但需要特别注意。在这种情况下,规则是:

  • 调度操作收集每个操作数的所有不同实现的 __torch_function__ 并按顺序调用它们:先调用子类再调用父类,否则在运算符表达式中从左到右调用。

  • 如果返回的值不是NotImplemented,则该值将作为结果返回。实现可以注册它们不支持某个操作,通过返回NotImplemented

  • 如果所有 __torch_function__ 个实现都返回 NotImplemented,PyTorch 将引发一个 TypeError

测试PyTorch API覆盖的重写

实现__torch_function__的一个麻烦之处是,如果某些操作有覆盖而其他操作没有,则用户在最坏的情况下会在运行时看到错误,当他们使用没有覆盖的函数时。为了简化这个过程,PyTorch提供了一个面向开发者的API,以确保对__torch_function__覆盖的全面支持。此API是私有的,并且将来可能会在没有任何警告的情况下进行更改。

首先,要获取所有可重写函数的列表,请使用 torch.overrides._get_overridable_functions。这将返回一个字典,其键是 PyTorch Python API 中的命名空间,值是该命名空间中可以被重写的函数列表。例如,让我们打印出 torch.nn.functional 中可以被重写的前 5 个函数的名称:

>>> from torch.overrides import get_overridable_functions
>>> func_dict = get_overridable_functions()
>>> nn_funcs = func_dict[torch.nn.functional]
>>> print([f.__name__ for f in nn_funcs[:5])
['adaptive_avg_pool1d', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d',
 'adaptive_max_pool1d', 'adaptive_max_pool1d_with_indices']

这个函数列表使得可以迭代所有可重写的功能,但在实际操作中,这还不足以在不费力地手动复制每个测试的每个函数签名的情况下为所有这些功能编写测试。为了简化这个过程,torch.overrides._get_testing_overrides 函数返回一个字典,将 PyTorch API 中的可重写函数映射到具有与原始函数相同签名但无条件返回 -1 的虚拟 lambda 函数。这些函数最适用于与 inspect 一起使用,以分析原始 PyTorch 函数的函数签名:

>>> import inspect
>>> from torch.overrides import get_testing_overrides
>>> override_dict = get_testing_overrides()
>>> dummy_add = override_dict[torch.add]
>>> inspect.signature(dummy_add)
<Signature (input, other, out=None)>

最后,torch.overrides.get_ignored_functions 返回一个函数元组 这些函数明确不能被 __torch_function__ 覆盖。这个列表可以 用于确认不在由 get_overridable_functions 返回的字典中的函数 不能被覆盖。

扩展 torch 本地API

虽然 __torch_function__ 允许有效地扩展PyTorch的纯Python组件的行为,但它不允许扩展用C++实现的部分。为此,Tensor 子类还可以定义 __torch_dispatch__,这将能够在C++级别上覆盖行为。

要有效使用此功能,了解PyTorch的本地部分是如何实现的是很重要的。其中最重要的组件是我们所谓的“调度器”(最好的描述可以在这篇博客文章中找到,尽管它有些过时)。顾名思义,它负责为特定函数调用选择正确的后端函数。例如,在调用torch.add(a, b)时,调度器会检查两个参数,确定应为此特定调用使用哪种“功能”(自动求导、自动转换、函数化等)和哪种“后端”(CPU、CUDA、MPS等),并最终调用所有正确的内核。 一个非常常见的操作是内核进行“重新调度”。例如,当在GPU上使用自动转换运行神经网络时,第一次调用将是自动转换内核,它将处理任何潜在的自动转换逻辑并向下重新调度。接下来的功能是自动求导,它将正确创建自动求导图并再次向下重新调度。最后,我们到达CUDA的后端内核,它将启动正确的CUDA内核并返回最终结果。在退出时,自动求导会将图附加到输出上,最后,自动转换有机会在退出时进行任何必要的更新。

调度器的一种配置是调用所有这些特征和后端键的顺序。最新的列表及其顺序可以在DispatchKey.h中的DispatchKey枚举中找到。为了扩展torch,对于本次讨论而言,重要的是排序的一个子集:

vmap -> 自动类型转换 -> 自动梯度 -> 零张量 -> 取负/共轭 -> 功能化 -> Python -> 后端

对于本次讨论而言,最重要的键是 Python,因为每个定义了 __torch_dispatch__ 方法的 Tensor 子类都会调用此功能。从那里开始,将调用用户定义的方法,并且可以任意重写行为。从那里再次调用提供的 func 将执行“重新调度”。

此实现的一些重要含义是:

  • 这段代码运行在“所有功能之下”。因此,它仅负责生成每个张量的输出值(就像一个普通的后端一样),并且可以(也应该)忽略所有高级功能,如自动梯度计算、自动类型转换等。

  • 如果任何高级功能在不重新调度的情况下实现了给定的功能,它将永远不会到达Python键,因此__torch_dispatch__回调将永远不会被触发。这种情况特别发生在CompositeImplicitAutograd函数中,这些函数在自动梯度级别上进行评估而不重新调度。这是因为CompositeImplicitAutograd函数通过隐式调用其他原生操作来指定其自动梯度公式,所以在自动梯度级别上,该函数被分解为其原生操作并进行评估。

  • 在回调到Python以及包装结果时,使用的转换与常规的PyTorch Python/C++绑定相同。特别是,某些对象无法在Python中表示,需要特殊处理(例如,未定义的张量会变成None)。

  • 我们的本地函数被惰性填充为 torch.ops.{namespace}.{func_name}.{overload_name} 个可调用的 Python 对象,以便能够从 Python 轻松与它们交互。传递给 __torch_dispatch__func 对象始终是此命名空间中的一个条目。此命名空间可以直接调用本地操作,并绕过通常的 Python API 和绑定代码。

__torch_function__能够拦截所有torch的Python API和Tensor方法类似,__torch_dispatch__能够拦截所有调用到aten原生API。请注意,在进入调度器之前,所有Tensor上的方法都会被转换为函数调用,因此在这里将显示为函数调用:torch.add(a, 2)a + 2将导致完全相同的aten调用。 这些函数中的大多数在native_functions.yaml中定义,该文件指定了这些函数的属性及其后端实现。它们的实现以及指定的功能然后通过代码生成自动注册。 一些更特殊的函数或功能也在C++代码库的其他位置或用户自定义的C++扩展中注册。

也可以使用torch.library添加new个原生函数。此Python特性允许定义和/或添加新的实现到原生函数中。这可以用于添加缺失的内核,替换现有的内核或定义全新的原生函数。

您可以在__torch_dispatch__子类动物园仓库中找到许多基于该子类的示例。

__torch_dispatch__ 调用约定

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    pass

当用户使用具有 __torch_dispatch__ 的输入调用操作符时,该调用 可能会被转发到 __torch_dispatch__。在调用 __torch_dispatch__ 之前,args 和 kwargs 会被规范化,也就是说:

  • kwargs 包含运算符模式中的仅关键字参数。 如果一个关键字参数等于其默认值(在模式中),则不会传递它。

  • the args 包含所有其他参数,无论它们是如何传递给操作符的(位置参数与关键字参数)。 如果一个参数等于其默认值,并且 它是最右边的位置参数或者它右边的所有参数都没有被传递,那么它将不会被传递。

使用模式扩展所有 torch API

不幸的是,有些函数不接受张量输入。这意味着上述的子类方法无法用于覆盖PyTorch所有函数的行为。此外,如果使用场景需要拦截每个函数调用,将每个张量都改为子类可能会过于侵入性。

为了解决这个用例,我们引入了“模式”的概念。这些存在于__torch_function____torch_dispatch__的覆盖中,分别通过子类化torch.overrides.TorchFunctionModetorch.utils._python_dispatch.TorchDispatchMode创建,并作为上下文管理器使用。

为了简化其与子类和其他模式交互的描述,每当进入某个模式的上下文管理器时,每个函数的行为就好像在参数列表的开头有一个额外的Tensor参数,该参数以模式作为子类。 这意味着特别是所有模式处理程序将在任何子类处理程序之前被调用,并且与内部上下文管理器对应的模式将始终首先运行。

需要注意的是,在给定的模式处理器中,此特定模式被禁用,并且可以通过执行 with self: 手动重新启用。

这是一个示例,展示了每种类型的日志记录模式:

import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode

class FunctionLog(TorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        return func(*args, **(kwargs or {}))

def f():
    a = torch.rand(10, requires_grad=True)
    b = a * 2
    b.sum().backward()

print("TorchFunctionMode logging:")
with FunctionLog():
    f()

print("TorchDispatchMode logging:")
with DispatchLog():
    f()

这将打印以下内容,并附带额外的注释:

TorchFunctionMode logging:
Function Log: torch.rand(*(10,), **{'requires_grad': True})
Function Log: torch.Tensor.mul(*(tensor([0.7164, 0.9897, 0.1745, 0.9336, 0.4287, 0.7989, 0.2169, 0.7474, 0.5624,
        0.5970], requires_grad=True), 2), **None)
Function Log: torch.Tensor.sum(*(tensor([1.4328, 1.9794, 0.3490, 1.8671, 0.8573, 1.5977, 0.4338, 1.4948, 1.1249,
        1.1939], grad_fn=<MulBackward0>),), **None)
# Note that at the python level, we only see the call to backward but not what happens in the autograd engine.
Function Log: torch.Tensor.backward(*(tensor(12.3307, grad_fn=<SumBackward0>),), **{'gradient': None, 'retain_graph': None, 'create_graph': False, 'inputs': None})

TorchDispatchMode logging:
# Here the requires_grad flag from autograd is removed while default arguments were populated.
Dispatch Log: aten.rand.default(*([10],), **{'device': device(type='cpu'), 'pin_memory': False})
Dispatch Log: aten.mul.Tensor(*(tensor([0.2151, 0.6018, 0.8415, 0.9060, 0.2974, 0.7708, 0.6668, 0.0352, 0.7948,
        0.6023], requires_grad=True), 2), **{})
Dispatch Log: aten.sum.default(*(tensor([0.4303, 1.2036, 1.6831, 1.8120, 0.5949, 1.5416, 1.3335, 0.0705, 1.5897,
        1.2046], grad_fn=<MulBackward0>),), **{})
# Here we don't see the call to backward itself, but its constituents. Starting here with the factory function that creates the initial gradient.
Dispatch Log: aten.ones_like.default(*(tensor(11.4637, grad_fn=<SumBackward0>),), **{'pin_memory': False, 'memory_format': torch.preserve_format})
# This is the backward of the sum
Dispatch Log: aten.expand.default(*(tensor(1.), [10]), **{})
Dispatch Log: aten.mul.Tensor(*(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), 2), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})
Dispatch Log: aten.detach.default(*(tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]),), **{})

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源