目录

自动求导机制

本笔记将概述autograd的工作原理以及如何记录操作。虽然严格来说,理解这些内容并非必要,但我们建议您熟悉它们,因为这将帮助您编写更高效、更简洁的程序,并且在调试时也能提供帮助。

自动微分如何记录历史

Autograd 是一个反向自动微分系统。概念上,当您执行操作时,autograd 会记录创建数据的所有操作的图,为您提供一个有向无环图,其叶子是输入张量,根是输出张量。通过从根到叶跟踪此图,您可以使用链式法则自动计算梯度。

在内部,autograd 将此图表示为由 Function 个对象(实际上是表达式)组成的图,这些对象可以被 apply() 用来计算评估图的结果。在计算前向传播时,autograd 同时执行请求的计算并构建一个表示梯度计算函数的图(每个 torch.Tensor.grad_fn 属性是进入该图的入口点)。当前向传播完成后,我们在反向传播中评估此图以计算梯度。

需要注意的是,图在每次迭代时都会从头开始重建,这正是允许使用任意Python控制流语句的原因,这些语句可以在每次迭代中改变图的总体形状和大小。你不需要在启动训练之前对所有可能的路径进行编码——你运行的就是你要进行微分的内容。

保存的张量

一些操作需要在前向传递期间保存中间结果以便执行反向传递。例如,该函数 xx2x\mapsto x^2 保存输入 xx 以计算梯度。

在定义自定义的Python Function时,你可以使用 save_for_backward() 在前向传播过程中保存 张量,并使用saved_tensors 在反向传播过程中 检索它们。更多信息请参阅扩展PyTorch

对于PyTorch定义的操作(例如 torch.pow()),张量会根据需要自动保存。你可以通过查找以 _saved 前缀开头的属性,来探索某个 grad_fn 保存了哪些张量(用于教育或调试目的)。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)  # True

在之前的代码中,y.grad_fn._saved_self 指的是与x相同的张量对象。 但这并不总是如此。例如:

x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

在内部,为了防止引用循环,PyTorch 在保存时将张量 打包 并在读取时将其 解包 到另一个张量中。在这里,你通过访问 y.grad_fn._saved_result 获得的张量与 y 是不同的张量对象(但它们仍然共享相同的存储)。

张量是否会打包到不同的张量对象中,取决于它是否是其自身的grad_fn输出,这是一个实现细节,可能会发生变化,用户不应依赖于此。

您可以使用 保存的张量钩子 控制 PyTorch 如何进行打包 / 解包。

非可微函数的梯度

使用自动微分进行梯度计算仅在所使用的每个基本函数都是可微的情况下才有效。 不幸的是,我们在实践中使用的许多函数都不具备这一特性(例如在 0 处的 relusqrt)。 为了尽量减少不可微函数的影响,我们通过依次应用以下规则来定义基本操作的梯度:

  1. 如果该函数是可微的,因此在当前点存在梯度,则使用它。

  2. 如果函数是凸的(至少在局部是),使用最小范数的次梯度(它是最陡下降方向)。

  3. 如果函数是凹函数(至少在局部范围内),使用最小范数的超梯度(考虑-f(x)并应用前一个点)。

  4. 如果函数已定义,则通过连续性在当前点定义梯度(注意这里 inf 是可能的,例如对于 sqrt(0))。如果有多个值是可能的,则任意选择一个。

  5. 如果函数未定义(例如输入为sqrt(-1)log(-1)或大多数函数在输入为NaN时),则用作梯度的值是任意的(我们可能会引发错误,但这不能保证)。大多数函数将使用NaN作为梯度,但由于性能原因,某些函数将使用其他值(例如log(-1))。

  6. 如果函数不是确定性映射(即它不是一个数学函数),它将被标记为不可微分。这将导致在反向传播时出错,如果它被用于需要梯度的张量上且不在no_grad环境中。

局部禁用梯度计算

有几种机制可以从Python中局部禁用梯度计算:

要禁用整个代码块的梯度,可以使用上下文管理器 如无梯度模式和推理模式。 对于更精细地排除子图的梯度计算, 可以设置张量的requires_grad字段。

在下面的内容中,除了讨论上述机制外,我们还描述了评估模式(nn.Module.eval()),这是一种不用于禁用梯度计算的方法,但由于其名称,经常与上述三种方法混淆。

设置 requires_grad

requires_grad 是一个标志,默认为false,除非被包裹在 nn.Parameter 中,它允许对梯度计算中的子图进行细粒度的排除。它在前向和后向传递中都生效:

在前向传递过程中,只有当其输入张量中至少有一个需要梯度时,才会将操作记录在反向图中。 在反向传递(.backward())期间,只有具有requires_grad=True的叶张量才会将其梯度累积到它们的.grad字段中。

需要注意的是,尽管每个张量都有这个标志, 设置它只对叶张量(没有 grad_fn的张量,例如,一个nn.Module的参数)有意义。 非叶张量(有grad_fn的张量)是具有与之关联的反向图的张量。因此,它们的梯度将作为中间结果来计算需要梯度的叶张量的梯度。从这个定义来看,所有非叶张量都将自动具有require_grad=True

设置 requires_grad 应该是你控制模型哪些部分参与梯度计算的主要方式,例如,如果你需要在模型微调期间冻结预训练模型的部分。

要冻结模型的部分,只需将 .requires_grad_(False) 应用于 你不想更新的参数。正如上面所述, 由于使用这些参数作为输入的计算不会在 前向传递中记录,因此它们不会在反向 传递中更新其 .grad 字段,因为它们从一开始就不会成为反向图的一部分,这正是我们所期望的。

由于这是非常常见的模式,requires_grad 也可以在模块级别设置为 nn.Module.requires_grad_()。 当应用于模块时,.requires_grad_() 对模块的所有参数生效(这些参数默认具有 requires_grad=True)。

梯度模式

除了设置 requires_grad 之外,还有三种可以从Python中选择的梯度模式,这些模式会影响PyTorch中的计算如何被autograd内部处理:默认模式(梯度模式)、无梯度模式和推理模式,所有这些都可以通过上下文管理器和装饰器进行切换。

模式

排除操作不被记录在反向图中

跳过额外的自动梯度跟踪开销

在启用模式下创建的张量可以在以后的梯度模式中使用。

示例

默认

前向传播

no-grad

优化器更新

推理

数据处理,模型评估

默认模式(梯度模式)

“默认模式”是指在没有启用其他模式(如无梯度模式和推理模式)时我们隐式处于的模式。为了与“无梯度模式”进行对比,默认模式有时也被称为“梯度模式”。

关于默认模式最重要的一点是,只有在这种模式下 requires_grad 才会生效。在其他两种模式中,requires_grad 总是被覆盖为 False

无梯度模式

在无梯度模式下的计算行为就像没有任何输入需要梯度一样。 换句话说,即使有输入具有 require_grad=True,无梯度模式下的计算也永远不会被记录在反向图中。

当你需要执行不应被自动梯度记录的操作,但你仍然希望在稍后以梯度模式使用这些计算的输出时,启用无梯度模式。这个上下文管理器使得在代码块或函数中禁用梯度变得方便,而无需临时将张量设置为 requires_grad=False,然后再恢复为 True

例如,在编写优化器时,无梯度模式可能会很有用:在执行训练更新时,你希望原地更新参数而不被自动求导记录。你还打算在下一次前向传播中使用更新后的参数进行计算,并且这些计算需要在有梯度模式下进行。

torch.nn.init 中的实现同样依赖于无梯度模式来初始化参数,以避免在原地更新已初始化参数时进行自动求导跟踪。

推理模式

推理模式是无梯度模式的极端版本。就像在无梯度模式下一样,推理模式中的计算不会被记录在反向图中,但启用推理模式将允许PyTorch进一步加速你的模型。这种更好的运行时性能伴随着一个缺点:在退出推理模式后,推理模式中创建的张量将无法用于由自动求导记录的计算中。

启用推理模式时,您正在进行的计算没有与自动梯度(autograd)交互,并且您不打算在以后由自动梯度记录的任何计算中使用在推理模式下创建的张量。

建议你在代码中不需要自动梯度跟踪的部分(例如数据处理和模型评估)尝试使用推理模式。如果它在你的使用场景中可以直接工作,那么这是一个免费的性能提升。如果你在启用推理模式后遇到错误,请检查你是否没有在退出推理模式后被自动梯度记录的计算中使用在推理模式下创建的张量。如果你无法避免在你的案例中这样使用,你可以随时切换回无梯度模式。

有关推理模式的详细信息,请参阅 推理模式

有关推理模式的实现细节,请参见 RFC-0011-InferenceMode

评估模式 (nn.Module.eval())

评估模式并不是一种局部禁用梯度计算的机制。 尽管如此,这里还是包含了它,因为它有时会被误认为是这样的机制。

功能上,module.eval()(或等效地module.train(False))与无梯度模式和推理模式完全正交。model.eval()如何影响你的模型完全取决于你的模型中使用的特定模块以及它们是否定义了任何特定于训练模式的行为。

您负责调用 model.eval()model.train(),如果您的 模型依赖于诸如 torch.nn.Dropouttorch.nn.BatchNorm2d 等模块, 这些模块在训练模式下可能会有不同的行为,例如,为了避免在验证数据上更新您的 BatchNorm 运行统计信息。

建议您在训练时始终使用 model.train(),在评估模型(验证/测试)时使用 model.eval(),即使您不确定您的模型是否具有特定于训练模式的行为,因为您使用的模块可能会更新为在训练和评估模式下表现出不同的行为。

原地操作与自动梯度

在自动梯度计算中支持原地操作是一件困难的事情,我们不建议在大多数情况下使用它们。自动梯度的激进缓冲区释放和重用使其非常高效,并且很少有情况下原地操作能显著降低内存使用量。除非你在承受巨大的内存压力,否则你可能永远不需要使用它们。

有两个主要原因限制了原地操作的适用性:

  1. 原地操作可能会覆盖计算梯度所需的数据。

  2. 每个原地操作都需要实现重写计算图。非原地版本只是分配新对象并保留对旧图的引用,而原地操作则需要将所有输入的创建者更改为Function表示此操作。这可能会很棘手,特别是如果有许多张量引用相同的存储(例如通过索引或转置创建),并且如果修改的输入的存储被任何其他Tensor引用,原地函数将引发错误。

原地正确性检查

每个张量都保留一个版本计数器,该计数器在每次标记为任何操作中的脏数据时都会递增。当一个函数保存任何用于反向传播的张量时,也会保存其包含张量的版本计数器。一旦你访问self.saved_tensors,它就会被检查,如果它大于保存的值,则会引发错误。这确保了如果你正在使用原地函数并且没有看到任何错误,你可以确信计算出的梯度是正确的。

多线程自动梯度

自动求导引擎负责运行所有必要的反向操作以计算反向传播。本节将描述所有可以帮助你在多线程环境中充分利用它的细节。(这仅适用于PyTorch 1.6+,因为之前版本的行为有所不同。)

用户可以使用多线程代码(例如 Hogwild 训练)来训练他们的模型,并且不会在并发的反向计算上阻塞,示例代码可以是:

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update


# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()

请注意,用户应了解某些行为:

CPU上的并发性

当你在多个线程中通过Python或C++ API 在CPU上运行backward()grad()时,你期望看到额外的并发性,而不是在执行期间以特定顺序序列化所有的反向调用(这是PyTorch 1.6之前的默认行为)。

Non-determinism

如果你正在从多个线程并发调用 backward() 并且有共享输入(例如 Hogwild CPU 训练),那么应该预期到非确定性行为。 这可能是因为参数在各个线程之间自动共享,因此,多个线程可能会访问并尝试在梯度累积期间累积相同的 .grad 属性。这在技术上是不安全的,并且可能导致竞争条件,结果可能无效无法使用。

开发具有共享参数的多线程模型的用户应考虑线程模型,并理解上述描述的问题。

功能API torch.autograd.grad() 可用于计算梯度,而不是使用backward()以避免非确定性。

保留图

如果自动梯度图的一部分在多个线程之间共享,即首先在单个线程中运行前半部分,然后在多个线程中运行后半部分,则图的前半部分是共享的。在这种情况下,不同的线程在同一图上执行grad()backward()可能会出现一个问题,即一个线程在运行时破坏了图,而另一个线程在这种情况下会崩溃。自动梯度将向用户报告类似于调用backward()两次而不使用retain_graph=True的错误,并告知用户他们应该使用retain_graph=True

自动求导节点的线程安全性

由于Autograd允许调用线程驱动其反向执行以实现潜在的并行性,因此我们必须确保在CPU上使用共享部分或全部GraphTask的并行backward()调用时的线程安全。

自定义的Python autograd.Function由于GIL是自动线程安全的。 对于内置的C++ Autograd节点(例如AccumulateGrad,CopySlices)和自定义 autograd::Function,Autograd引擎使用线程互斥锁来确保 在可能有状态读/写的autograd节点上的线程安全性。

C++ 钩子不保证线程安全

自动求导依赖于用户编写线程安全的C++钩子。如果你想在多线程环境中正确应用钩子,你需要编写适当的线程锁定代码以确保钩子是线程安全的。

复数的自动求导

简短版本:

  • 当你使用PyTorch对具有复杂定义域和/或值域的任何函数f(z)f(z)进行求导时, 梯度是在假设该函数是一个更大的实值损失函数的一部分的情况下计算的g(input)=Lg(input)=L。计算出的梯度是Lz\frac{\partial L}{\partial z^*} (注意z的共轭),其负值正好是梯度下降算法中使用的最陡下降方向。因此,现有的优化器可以通过复数参数直接工作。

  • 此惯例与TensorFlow的复数微分惯例匹配,但与JAX不同(JAX计算Lz\frac{\partial L}{\partial z})。

  • 如果你有一个实数到实数的函数,其内部使用了复杂数运算,这里的约定并不重要:你将始终得到与仅使用实数运算实现时相同的结果。

如果你对数学细节感到好奇,或者想知道如何在PyTorch中定义复杂的导数,请继续阅读。

什么是复数导数?

复数可微性的数学定义采用了导数的极限定义,并将其推广以适用于复数。考虑一个函数 f:CCf: ℂ → ℂ

f(z=x+yj)=u(x,y)+v(x,y)jf(z=x+yj) = u(x, y) + v(x, y)j

其中uuvv 是两个实值变量函数 而 jj 是虚数单位。

使用导数的定义,我们可以写成:

f(z)=limh0,hCf(z+h)f(z)hf'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

为了使这个极限存在,不仅 uuvv 必须是实可微的,而且 ff 还必须满足柯西-黎曼 方程。换句话说:用实数和虚数步长 (hh) 计算的极限必须相等。这是一个更为严格的条件。

复可微函数通常被称为全纯函数。它们性质良好,具有所有你在实可微函数中看到的优点,但在优化领域实际上毫无用处。对于优化问题,研究社区只使用实值目标函数,因为复数不属于任何有序域,因此拥有复值损失并没有太大意义。

事实证明,没有任何有趣的实值目标能满足柯西-黎曼方程。因此,基于全纯函数的理论不能用于优化,大多数人因此使用 Wirtinger 算子。

Wirtinger微积分登场……

因此,我们有一个关于复可微性和全纯函数的伟大理论,但我们无法使用其中的任何内容,因为许多常用函数都不是全纯的。一个可怜的数学家该怎么办呢?Well,Wirtinger观察到即使f(z)f(z)不是全纯的,也可以将其重写为两个变量的函数f(z,z)f(z, z*),该函数总是全纯的。这是因为zz的实部和虚部可以用zzzz^*表示:

Re(z)=z+z2Im(z)=zz2j\begin{aligned} \mathrm{Re}(z) &= \frac {z + z^*}{2} \\ \mathrm{Im}(z) &= \frac {z - z^*}{2j} \end{aligned}

Wirtinger微积分建议研究f(z,z)f(z, z^*),如果ff是实可微的,则其保证是全纯的(另一种思考方式是将其视为坐标系的变化,从f(x,y)f(x, y)f(z,z)f(z, z^*))。该函数具有偏导数z\frac{\partial }{\partial z}z\frac{\partial}{\partial z^{*}}。 我们可以使用链式法则来建立这些偏导数与zz的实部和虚部的偏导数之间的关系。

x=zxz+zxz=z+zy=zyz+zyz=1j(zz)\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * \left(\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}\right) \end{aligned}

从上述方程中,我们得到:

z=1/2(x1jy)z=1/2(x+1jy)\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * \left(\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}\right) \\ \frac{\partial }{\partial z^*} &= 1/2 * \left(\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}\right) \end{aligned}

这是您会在维基百科上找到的经典的Wirtinger微积分定义。

这一变化带来了许多美妙的结果。

  • 首先,Cauchy-Riemann 方程简单来说就是 fz=0\frac{\partial f}{\partial z^*} = 0 (也就是说,函数 ff 可以完全用 zz 表示,而无需引用 zz^*)。

  • 我们稍后会看到,另一个重要(且有些违背直觉)的结果是,在对实值损失进行优化时,我们在进行变量更新时应该采取的步骤由Lossz\frac{\partial Loss}{\partial z^*}(而不是Lossz\frac{\partial Loss}{\partial z})给出。

要进一步阅读,请查看:https://arxiv.org/pdf/0906.4835.pdf

变分法在优化中有什么用?

音频和其他领域的研究人员更常用梯度下降来优化具有复数变量的真实值损失函数。通常,这些人将实部和虚部分别视为可以更新的不同通道。对于步长α/2\alpha/2和损失LL,我们可以写出以下在R2ℝ^2中的方程:

xn+1=xn(α/2)Lxyn+1=yn(α/2)Ly\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}

这些方程如何转化为复数空间 C

zn+1=xn(α/2)Lx+1j(yn(α/2)Ly)=znα1/2(Lx+jLy)=znαLz\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * \left(\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}\right) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}

发生了一件非常有趣的事情:Wirtinger演算告诉我们,我们可以将上面的复变量更新公式简化为仅引用共轭Wirtinger导数 Lz\frac{\partial L}{\partial z^*},这正好是我们优化过程中所采取的步骤。

由于共轭 Wirtinger 导数正好为我们提供了实值损失函数的正确步长,因此 PyTorch 在对具有实值损失的函数求导时会给出这个导数。

PyTorch 如何计算共轭Wirtinger导数?

通常,我们的导数公式以grad_output作为输入, 代表我们已经计算出的传入向量-雅可比矩阵积,即Ls\frac{\partial L}{\partial s^*}, 其中LL是整个计算的损失(产生实际损失), 而ss是我们函数的输出。目标是计算Lz\frac{\partial L}{\partial z^*}, 其中zz是函数的输入。事实证明,在实际损失的情况下, 我们只需要计算Ls\frac{\partial L}{\partial s^*}, 尽管链式法则暗示我们也需要访问Ls\frac{\partial L}{\partial s}。如果你想要跳过这个推导, 请查看本节的最后一方程,然后跳到下一节。

让我们继续使用f:CCf: ℂ → ℂ,其被定义为 f(z)=f(x+yj)=u(x,y)+v(x,y)jf(z) = f(x+yj) = u(x, y) + v(x, y)j。正如上面讨论的那样, autograd 的梯度约定围绕实值损失函数的优化展开,因此我们假设ff 是更大的 实值损失函数 gg 的一部分。利用链式法则,我们可以写出:

(1)Lz=Luuz+Lvvz\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

现在使用Wirtinger导数定义,我们可以写:

Ls=1/2(LuLvj)Ls=1/2(Lu+Lvj)\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * \left(\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j\right) \\ \frac{\partial L}{\partial s^*} = 1/2 * \left(\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j\right) \end{aligned}

需要注意的是,由于uuvv是实函数,并且根据我们的假设ff是实值函数的一部分,LL也是实数,因此我们有:

(2)(Ls)=Ls\left( \frac{\partial L}{\partial s} \right)^* = \frac{\partial L}{\partial s^*}

即,Ls\frac{\partial L}{\partial s}等于grad_outputgrad\_output^*

解上述方程可得Lu\frac{\partial L}{\partial u}Lv\frac{\partial L}{\partial v}

(3)Lu=Ls+LsLv=1j(LsLs)\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) \end{aligned}

(3) 替换 (1),我们得到:

Lz=(Ls+Ls)uz+1j(LsLs)vz=Ls(uz+vzj)+Ls(uzvzj)=Ls(u+vj)z+Ls(u+vj)z=Lssz+Lssz\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}\right) * \frac{\partial u}{\partial z^*} + 1j * \left(\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}\right) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \left(\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j\right) + \frac{\partial L}{\partial s^*} * \left(\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j\right) \\ &= \frac{\partial L}{\partial s} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}

使用 (2),我们得到:

(4)Lz=(Ls)sz+Ls(sz)=(grad_output)sz+grad_output(sz)\begin{aligned} \frac{\partial L}{\partial z^*} &= \left(\frac{\partial L}{\partial s^*}\right)^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \left(\frac{\partial s}{\partial z}\right)^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * \left(\frac{\partial s}{\partial z}\right)^* } \\ \end{aligned}

最后一个方程是你编写自己的梯度时的重要公式, 因为它将我们的导数公式分解为一个更简单的形式, 便于手动计算。

如何为一个复杂函数编写自己的导数公式?

上述方程给出了所有复函数导数的一般公式。然而,我们仍然需要计算sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}。 你可以通过两种方式来完成这个计算:

  • The first way is to just use the definition of Wirtinger derivatives directly and calculate sz\frac{\partial s}{\partial z} and sz\frac{\partial s}{\partial z^*} by using sx\frac{\partial s}{\partial x} and sy\frac{\partial s}{\partial y} (which you can compute in the normal way).

  • The second way is to use the change of variables trick and rewrite f(z)f(z) as a two variable function f(z,z)f(z, z^*), and compute the conjugate Wirtinger derivatives by treating zz and zz^* as independent variables. This is often easier; for example, if the function in question is holomorphic, only zz will be used (and sz\frac{\partial s}{\partial z^*} will be zero).

让我们以函数f(z=x+yj)=cz=c(x+yj)f(z = x + yj) = c * z = c * (x+yj)为例,其中cRc \in ℝ

使用第一种方法计算Wirtinger导数,我们得到。

sz=1/2(sxsyj)=1/2(c(c1j)1j)=csz=1/2(sx+syj)=1/2(c+(c1j)1j)=0\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * \left(\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * \left(\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j\right) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}

使用(4),和grad_output = 1.0 (这是在PyTorch中调用backward()时用于标量输出的默认梯度输出值),我们得到:

Lz=10+1c=c\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

使用第二种方法计算Wirtinger导数,我们直接得到:

sz=(cz)z=csz=(cz)z=0\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}

再次使用 (4),我们得到 Lz=c\frac{\partial L}{\partial z^*} = c。正如你所见,第二种方法涉及较少的计算,并且更适合快速计算。

关于跨域函数呢?

有些函数是从复数输入映射到实数输出,或者反之。这些函数是(4)的一个特例,我们可以使用链式法则推导出来:

  • For f:CRf: ℂ → ℝ, we get:

    Lz=2grad_outputsz\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}
  • For f:RCf: ℝ → ℂ, we get:

    Lz=2Re(grad_outputsz)\frac{\partial L}{\partial z^*} = 2 * \mathrm{Re}(grad\_output^* * \frac{\partial s}{\partial z^{*}})

保存的张量钩子

您可以控制保存的张量如何打包/解包,通过定义一对pack_hook / unpack_hook 钩子。 pack_hook 函数应将张量作为其唯一参数,但可以返回任何 Python 对象(例如另一个张量、元组,甚至是包含文件名的字符串)。 unpack_hook 函数将其唯一参数作为 pack_hook 的输出,并应返回一个在反向传播中使用的张量。由 unpack_hook 返回的张量只需要与传递给 pack_hook 的输入张量具有相同的内容。特别是,任何与 autograd 相关的元数据都可以忽略,因为它们将在解包过程中被覆盖。

这样的组对示例为:

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)

注意,unpack_hook 不应删除临时文件,因为它 可能会被多次调用:临时文件应该在返回的SelfDeletingTempFile对象存活期间保持有效。 在上述示例中, 我们通过在不再需要时关闭它来防止泄露临时文件(在删除SelfDeletingTempFile对象时)。

注意

我们保证pack_hook只会被调用一次,但unpack_hook可以根据反向传播的需求多次调用,并且我们期望它每次返回相同的数据。

警告

禁止对任何函数的输入执行就地操作,因为这可能会导致意想不到的副作用。如果对包钩的输入进行了就地修改,PyTorch 会抛出错误,但不会捕获对解包钩输入进行就地修改的情况。

为保存的张量注册钩子

你可以通过调用一个 register_hooks() 方法在一个 SavedTensor 对象上注册一对钩子。这些对象作为 grad_fn 的属性暴露出来,并以 _raw_saved_ 前缀开头。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)

pack_hook方法在配对注册后立即调用。 这unpack_hook方法每次需要访问保存的张量时都会被调用,无论是通过y.grad_fn._saved_self还是在反向传播过程中。

警告

如果你在保存的张量被释放(即反向传播被调用后)仍然保留对SavedTensor的引用,那么调用其register_hooks()是不允许的。 PyTorch 通常会抛出一个错误,但在某些情况下可能会失败,并且可能会出现未定义的行为。

为保存的张量注册默认钩子

或者,你可以使用上下文管理器 saved_tensors_hooks 来注册一对钩子, 这些钩子将应用于在该上下文中创建的所有保存的张量。

Example:

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
          # ... compute output
          output = x
        return output

model = Model()
net = nn.DataParallel(model)

在此上下文中定义的钩子是线程局部的。 因此,以下代码不会产生预期的效果,因为钩子不会经过DataParallel

# Example what NOT to do

net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)

请注意,使用这些钩子会禁用所有减少 Tensor 对象创建的优化。例如:

with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x

没有钩子时,xy.grad_fn._saved_selfy.grad_fn._saved_other 都是指向同一个张量对象。 有了钩子后,PyTorch 会将 x 打包和解包为两个新的张量对象, 这些对象与原始的 x 共享相同的存储(未进行复制)。

反向钩子执行

本节将讨论不同钩子在何时触发或不触发。 然后将讨论它们触发的顺序。 将涵盖的钩子有:通过torch.Tensor.register_hook()注册到张量的反向传播钩子, 通过torch.Tensor.register_post_accumulate_grad_hook()注册到张量的后累加梯度钩子, 通过torch.autograd.graph.Node.register_hook()注册到节点的后钩子, 以及通过torch.autograd.graph.Node.register_prehook()注册到节点的前钩子。

特定的钩子是否会被触发

通过torch.Tensor.register_hook()注册到张量上的钩子在计算该张量梯度时会被执行。(注意,这不需要执行张量的grad_fn。例如,如果张量作为inputs参数的一部分传递给torch.autograd.grad(),张量的grad_fn可能不会被执行,但注册到该张量上的钩子始终会被执行。)

通过torch.Tensor.register_post_accumulate_grad_hook()注册的钩子在张量的梯度累积完成后执行,这意味着该张量的grad字段已被设置。而通过torch.Tensor.register_hook()注册的钩子在计算梯度时运行,通过torch.Tensor.register_post_accumulate_grad_hook()注册的钩子只有在autograd在反向传播结束时更新了张量的grad字段后才会被触发。因此,post-accumulate-grad钩子只能为叶子张量注册。在非叶子张量上通过torch.Tensor.register_post_accumulate_grad_hook()注册钩子会导致错误,即使你调用了backward(retain_graph=True)

注册到 torch.autograd.graph.Node 的 Hooks 只有在 torch.autograd.graph.Node.register_hook()torch.autograd.graph.Node.register_prehook() 被执行时才会触发。

某个特定节点是否被执行可能取决于反向传播是否被调用,使用的是 torch.autograd.grad()torch.autograd.backward()。 具体来说,在你为一个与传递给 torch.autograd.grad()torch.autograd.backward() 的张量对应的节点注册钩子时,你应该注意这些差异,作为 inputs 参数的一部分。

如果你使用的是torch.autograd.backward(),上述所有的钩子都会被执行, 无论你是否指定了inputs参数。这是因为.backward()会执行所有的节点, 即使它们对应于指定为输入的张量。 (请注意,与作为inputs传递的张量对应的额外节点的执行通常是不必要的,但仍会被执行。此行为可能会更改; 你不应依赖于此。)

另一方面,如果你使用的是torch.autograd.grad(),注册到传递给input张量对应的节点的反向钩子可能不会被执行,因为这些节点只有在有其他依赖于此节点梯度结果的输入时才会被执行。

不同钩子的触发顺序

事情发生的顺序是:

  1. 注册到张量的钩子被执行了

  2. 注册到节点的预钩子会在节点执行时运行(如果节点被执行)。

  3. 张量retain_grad字段更新了 .grad

  4. 节点在满足上述规则的情况下执行

  5. 对于累积值为 .grad 的叶子张量,在累积梯度后会执行后置钩子

  6. 注册到节点的后钩在节点执行时(如果有执行)也会被执行。

如果在同一个张量或节点上注册了多个相同类型的钩子, 它们将以注册的顺序执行。 较晚执行的钩子可以观察到由较早执行的钩子所做的梯度修改。

特殊的钩子

torch.autograd.graph.register_multi_grad_hook() 是使用注册到张量上的钩子实现的。每个单独的张量钩子按照上述定义的张量钩子顺序触发,当最后一个张量梯度计算完毕时,注册的多梯度钩子会被调用。

torch.nn.modules.module.register_module_full_backward_hook() 使用钩子在节点上实现。随着前向计算的进行,钩子会注册到与模块输入和输出对应的grad_fn。由于一个模块可能有多个输入和返回多个输出,因此在前向计算之前,首先对模块的输入应用一个虚拟的自定义自动求导函数,并在前向计算的输出返回之前对模块的输出应用该函数,以确保这些张量共享同一个grad_fn,然后我们可以将钩子附加到这个grad_fn上。

当张量就地修改时,张量钩子的行为

通常注册到张量上的钩子接收到相对于该张量的输出梯度,在反向传播计算时,张量的值被取为其当时的值。

然而,如果你对一个张量注册钩子,然后就地修改该张量,那么在就地修改之前注册的钩子同样会接收到关于输出相对于该张量的梯度,但张量的值会被视为其在就地修改之前的值。

如果你更倾向于前者的行为, 你应该在对张量进行所有就地修改之后,将其注册到该张量中。 例如:

t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()

此外,了解以下内容可能会有所帮助:在底层实现中, 当钩子注册到一个张量时,它们实际上会永久绑定到该张量的 grad_fn。 因此,如果该张量随后被就地修改, 尽管张量现在有了一个新的 grad_fn,但在就地修改之前注册的钩子仍将与旧的 grad_fn 相关联, 例如,当自动微分引擎在图中到达该张量的旧 grad_fn 时,这些钩子将会触发。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源