目录

自动求导机制

本笔记将概述autograd的工作原理以及如何记录操作。虽然严格来说,理解这些内容并非必要,但我们建议您熟悉它们,因为这将帮助您编写更高效、更简洁的程序,并且在调试时也能提供帮助。

自动微分如何记录历史

Autograd 是一个反向自动微分系统。从概念上讲, autograd 会在你执行操作时记录一个图,该图记录了创建数据的所有操作, 从而给你一个有向无环图,其叶子节点是输入张量,根节点是输出张量。 通过从根节点到叶子节点追踪这个图,你可以使用链式法则自动计算梯度。

在内部,autograd 将此图表示为 Function 个对象(实际上是表达式),这些对象可以被 apply() 计算以得出结果。在计算前向传递时,autograd 同时执行请求的计算并构建一个表示计算梯度的函数的图(每个 torch.Tensor.grad_fn 属性是进入此图的入口点)。当前向传递完成时,我们在反向传递中评估此图以计算梯度。

需要注意的是,图在每次迭代时都会从头开始重建,这正是允许使用任意Python控制流语句的原因,这些语句可以在每次迭代中改变图的总体形状和大小。你不需要在启动训练之前对所有可能的路径进行编码——你运行的就是你要进行微分的内容。

保存的张量

一些操作需要在前向传递期间保存中间结果以便执行反向传递。例如,该函数 xx2x\mapsto x^2 保存输入 xx 以计算梯度。

在定义自定义的Python Function时,你可以使用 save_for_backward() 在前向传播过程中保存 张量,并使用saved_tensors 在反向传播过程中 检索它们。更多信息请参阅扩展PyTorch

对于PyTorch定义的操作(例如 torch.pow()),张量会根据需要自动保存。你可以通过查找以 _saved 前缀开头的属性,来探索某个 grad_fn 保存了哪些张量(用于教育或调试目的)。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self))  # True
print(x is y.grad_fn._saved_self)  # True

在之前的代码中,y.grad_fn._saved_self 指的是与x相同的张量对象。 但这并不总是如此。例如:

x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result))  # True
print(y is y.grad_fn._saved_result)  # False

在内部,为了防止引用循环,PyTorch 在保存时将张量 打包 并在读取时将其 解包 到另一个张量中。在这里,你通过访问 y.grad_fn._saved_result 获得的张量与 y 是不同的张量对象(但它们仍然共享相同的存储)。

张量是否会打包到不同的张量对象中,取决于它是否是其自身的grad_fn输出,这是一个实现细节,可能会发生变化,用户不应依赖于此。

您可以使用 保存的张量钩子 控制 PyTorch 如何进行打包 / 解包。

非可微函数的梯度

使用自动微分进行梯度计算仅在所使用的每个基本函数都是可微的情况下才有效。 不幸的是,我们在实践中使用的许多函数都不具备这一特性(例如在 0 处的 relusqrt)。 为了尽量减少不可微函数的影响,我们通过依次应用以下规则来定义基本操作的梯度:

  1. 如果该函数是可微的,因此在当前点存在梯度,则使用它。

  2. 如果函数是凸的(至少在局部是),使用最小范数的次梯度(它是最陡下降方向)。

  3. 如果函数是凹函数(至少在局部范围内),使用最小范数的超梯度(考虑-f(x)并应用前一个点)。

  4. 如果函数已定义,则通过连续性在当前点定义梯度(注意这里 inf 是可能的,例如对于 sqrt(0))。如果有多个值是可能的,则任意选择一个。

  5. 如果函数未定义(例如输入为sqrt(-1)log(-1)或大多数函数在输入为NaN时),则用作梯度的值是任意的(我们可能会引发错误,但这不能保证)。大多数函数将使用NaN作为梯度,但由于性能原因,某些函数将使用其他值(例如log(-1))。

  6. 如果函数不是确定性映射(即它不是一个数学函数),它将被标记为不可微分。这将导致在反向传播时出错,如果它被用于需要梯度的张量上且不在no_grad环境中。

局部禁用梯度计算

有几种机制可以从Python中局部禁用梯度计算:

要禁用整个代码块的梯度,可以使用上下文管理器 如无梯度模式和推理模式。 对于更精细地排除子图的梯度计算, 可以设置张量的requires_grad字段。

下面,除了讨论上述机制外,我们还描述了 评估模式 (nn.Module.eval()),这是一种实际上并未用于禁用梯度计算的方法,但由于其名称,常常与前三者混淆。

设置 requires_grad

requires_grad 是一个标志,默认为false,除非被包裹在 nn.Parameter 中,它允许对梯度计算中的子图进行细粒度的排除。它在前向和后向传递中都生效:

在前向传递过程中,只有当其输入张量中至少有一个需要梯度时,才会将操作记录在反向图中。 在反向传递(.backward())期间,只有具有requires_grad=True的叶张量才会将其梯度累积到它们的.grad字段中。

需要注意的是,尽管每个张量都有这个标志, 设置它只对叶张量(没有 grad_fn的张量,例如,一个nn.Module的参数)有意义。 非叶张量(有grad_fn的张量)是具有与之关联的反向图的张量。因此,它们的梯度将作为中间结果来计算需要梯度的叶张量的梯度。从这个定义来看,所有非叶张量都将自动具有require_grad=True

设置 requires_grad 应该是你控制模型哪些部分参与梯度计算的主要方式,例如,如果你需要在模型微调期间冻结预训练模型的部分。

要冻结模型的部分,只需将 .requires_grad_(False) 应用于 你不想更新的参数。正如上面所述, 由于使用这些参数作为输入的计算不会在 前向传递中记录,因此它们不会在反向 传递中更新其 .grad 字段,因为它们从一开始就不会成为反向图的一部分,这正是我们所期望的。

由于这是非常常见的模式,requires_grad 也可以在模块级别设置为 nn.Module.requires_grad_()。 当应用于模块时,.requires_grad_() 对模块的所有参数生效(这些参数默认具有 requires_grad=True)。

梯度模式

除了设置 requires_grad 之外,还有三种可以通过 Python 启用的模式,这些模式会影响 PyTorch 中 autograd 内部如何处理计算:默认模式(grad 模式)、无梯度模式和推理模式。所有这些模式都可以通过上下文管理器和装饰器进行切换。

默认模式(梯度模式)

“默认模式”实际上是在未启用其他模式(如 no-grad 模式和推理模式)时我们所隐式处于的模式。与“no-grad 模式”相对,“默认模式”有时也被称为“grad 模式”。

关于默认模式最重要的一点是,只有在这种模式下 requires_grad 才会生效。在其他两种模式中,requires_grad 总是被覆盖为 False

无梯度模式

在无梯度模式下的计算行为就像没有任何输入需要梯度一样。 换句话说,即使有输入具有 require_grad=True,无梯度模式下的计算也永远不会被记录在反向图中。

当你需要执行不应被自动梯度记录的操作,但你仍然希望在稍后以梯度模式使用这些计算的输出时,启用无梯度模式。这个上下文管理器使得在代码块或函数中禁用梯度变得方便,而无需临时将张量设置为 requires_grad=False,然后再恢复为 True

例如,在编写优化器时,无梯度模式可能会很有用:在执行训练更新时,你希望原地更新参数而不被自动求导记录。你还打算在下一次前向传播中使用更新后的参数进行计算,并且这些计算需要在有梯度模式下进行。

torch.nn.init 中的实现同样依赖于无梯度模式来初始化参数,以避免在原地更新已初始化参数时进行自动求导跟踪。

推理模式

推理模式是无梯度模式的极端版本。就像在无梯度模式下一样,推理模式中的计算不会被记录在反向图中,但启用推理模式将允许PyTorch进一步加速你的模型。这种更好的运行时性能伴随着一个缺点:在退出推理模式后,推理模式中创建的张量将无法用于由自动求导记录的计算中。

在执行不需要记录到反向图中的计算时启用推理模式,并且你不会在任何后续由 autograd 记录的计算中使用在推理模式下创建的张量。

建议你在代码中不需要自动梯度跟踪的部分(例如数据处理和模型评估)尝试使用推理模式。如果它在你的使用场景中可以直接工作,那么这是一个免费的性能提升。如果你在启用推理模式后遇到错误,请检查你是否没有在退出推理模式后被自动梯度记录的计算中使用在推理模式下创建的张量。如果你无法避免在你的案例中这样使用,你可以随时切换回无梯度模式。

有关推理模式的详细信息,请参阅 推理模式

有关推理模式的实现细节,请参见 RFC-0011-InferenceMode

评估模式 (nn.Module.eval())

评估模式实际上并不是一种用于局部禁用梯度计算的机制。 无论如何,这里包含它是因为有时会被误认为是这样的机制。

功能上,module.eval()(或等效地module.train(False))与无梯度模式和推理模式完全正交。model.eval()如何影响你的模型完全取决于你的模型中使用的特定模块以及它们是否定义了任何特定于训练模式的行为。

您负责调用 model.eval()model.train(),如果您的 模型依赖于诸如 torch.nn.Dropouttorch.nn.BatchNorm2d 等模块, 这些模块在训练模式下可能会有不同的行为,例如,为了避免在验证数据上更新您的 BatchNorm 运行统计信息。

建议您在训练时始终使用 model.train(),在评估模型(验证/测试)时使用 model.eval(),即使您不确定您的模型是否具有特定于训练模式的行为,因为您使用的模块可能会更新为在训练和评估模式下表现出不同的行为。

原地操作与自动梯度

支持就地操作在自动微分中是一个复杂的问题,我们不建议在大多数情况下使用它们。自动微分的激进缓冲区释放和重用使其非常高效,实际上能够显著降低内存使用的场合很少见。除非你在极高的内存压力下运行,否则你可能永远不需要使用它们。

有两个主要原因限制了原地操作的适用性:

  1. 原地操作可能会覆盖计算梯度所需的数据。

  2. 每个原地操作实际上都需要实现来重写计算图。非原地版本只是分配新的对象并保留对旧图的引用,而原地操作则需要将所有输入的创建者更改为表示此操作的Function。这可能会很棘手,特别是当有多个张量引用相同的存储(例如通过索引或转置创建的张量),如果修改的输入存储被任何其他Tensor引用,原地函数实际上会引发错误。

原地正确性检查

每个张量都保留一个版本计数器,该计数器在每次标记为任何操作中的脏数据时都会递增。当一个函数保存任何用于反向传播的张量时,也会保存其包含张量的版本计数器。一旦你访问self.saved_tensors,它就会被检查,如果它大于保存的值,则会引发错误。这确保了如果你正在使用原地函数并且没有看到任何错误,你可以确信计算出的梯度是正确的。

多线程自动梯度

自动求导引擎负责运行所有必要的反向操作以计算反向传播。本节将描述所有可以帮助你在多线程环境中充分利用它的细节。(这仅适用于PyTorch 1.6+,因为之前版本的行为有所不同。)

用户可以使用多线程代码(例如 Hogwild 训练)来训练他们的模型,并且不会在并发的反向计算上阻塞,示例代码可以是:

# Define a train function to be used in different threads
def train_fn():
    x = torch.ones(5, 5, requires_grad=True)
    # forward
    y = (x + 3) * (x + 4) * 0.5
    # backward
    y.sum().backward()
    # potential optimizer update


# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
    p = threading.Thread(target=train_fn, args=())
    p.start()
    threads.append(p)

for p in threads:
    p.join()

请注意,用户应了解某些行为:

CPU上的并发性

当你在多个线程中通过Python或C++ API 在CPU上运行backward()grad()时,你期望看到额外的并发性,而不是在执行期间以特定顺序序列化所有的反向调用(这是PyTorch 1.6之前的默认行为)。

Non-determinism

如果你在多个线程中并发调用 backward(),但使用了共享输入(即 Hogwild CPU 训练)。由于参数会在线程之间自动共享,在不同线程的反向传播调用中,梯度累积可能会变得不可预测,因为两个反向传播调用可能会访问并尝试累积相同的 .grad 属性。从技术上讲,这并不安全,可能导致竞态条件,并且结果可能无效。

但如果你使用多线程方法来驱动整个训练过程,同时使用共享参数,这种情况是预期的模式。使用多线程的用户应牢记线程模型,并应预期会出现这种情况。用户可以使用函数式 API torch.autograd.grad() 来计算梯度,而不是使用 backward() 以避免非确定性。

保留图

如果自动梯度图的一部分在多个线程之间共享,即首先在单个线程中运行前半部分,然后在多个线程中运行后半部分,则图的前半部分是共享的。在这种情况下,不同的线程在同一图上执行grad()backward()可能会出现一个问题,即一个线程在运行时破坏了图,而另一个线程在这种情况下会崩溃。自动梯度将向用户报告类似于调用backward()两次而不使用retain_graph=True的错误,并告知用户他们应该使用retain_graph=True

自动求导节点的线程安全性

由于 Autograd 允许调用线程驱动其反向执行以实现潜在的并行性,因此在 CPU 上确保线程安全性非常重要,特别是在共享部分或全部 GraphTask 的并行反向传播中。

自定义 Python autograd.Function 由于 GIL 的存在,会自动实现线程安全。 对于内置的 C++ Autograd 节点(例如 AccumulateGrad、CopySlices)和自定义 autograd::Function,Autograd 引擎使用线程互斥锁来确保对可能具有状态读写操作的 Autograd 节点的线程安全。

C++ 钩子不保证线程安全

自动求导依赖于用户编写线程安全的C++钩子。如果你想在多线程环境中正确应用钩子,你需要编写适当的线程锁定代码以确保钩子是线程安全的。

复数的自动求导

简短版本:

  • 当你使用PyTorch对具有复数域和/或陪域的任何函数 f(z)f(z) 进行微分时, 梯度计算是基于该函数是更大的实值损失函数 g(input)=Lg(input)=L 的一部分这一假设进行的。计算出的梯度为 Lz\frac{\partial L}{\partial z^*} (注意z的共轭),其负值正是梯度下降算法中使用的最陡下降方向。因此,所有现有的优化器都可以直接与复数参数配合使用。

  • 此惯例与TensorFlow的复数微分惯例匹配,但与JAX不同(JAX计算Lz\frac{\partial L}{\partial z})。

  • 如果你有一个实数到实数的函数,其内部使用了复杂数运算,这里的约定并不重要:你将始终得到与仅使用实数运算实现时相同的结果。

如果你对数学细节感到好奇,或者想知道如何在PyTorch中定义复杂的导数,请继续阅读。

什么是复数导数?

复数可微性的数学定义采用了导数的极限定义,并将其推广以适用于复数。考虑一个函数 f:CCf: ℂ → ℂ

f(z=x+yj)=u(x,y)+v(x,y)j`f(z=x+yj) = u(x, y) + v(x, y)j`

where uuvv 是两个变量实值函数。

使用导数的定义,我们可以写成:

f(z)=limh0,hCf(z+h)f(z)hf'(z) = \lim_{h \to 0, h \in C} \frac{f(z+h) - f(z)}{h}

为了使这个极限存在,不仅 uuvv 必须是实可微的,而且 ff 还必须满足柯西-黎曼 方程。换句话说:用实数和虚数步长 (hh) 计算的极限必须相等。这是一个更为严格的条件。

复可微函数通常被称为全纯函数。它们性质良好,具有所有你在实可微函数中看到的优点,但在优化领域实际上毫无用处。对于优化问题,研究社区只使用实值目标函数,因为复数不属于任何有序域,因此拥有复值损失并没有太大意义。

事实证明,没有有趣的实值目标函数能满足柯西-黎曼方程。因此,全形函数的理论不能用于优化,大多数人因此使用微分学(Wirtinger calculus)。

Wirtinger Calculus comes in picture …

因此,我们有一个关于复可微性和全纯函数的伟大理论,但我们无法使用其中的任何内容,因为许多常用函数都不是全纯的。一个可怜的数学家该怎么办呢?Well,Wirtinger观察到即使f(z)f(z)不是全纯的,也可以将其重写为两个变量的函数f(z,z)f(z, z*),该函数总是全纯的。这是因为zz的实部和虚部可以用zzzz^*表示:

Re(z)=z+z2Im(z)=zz2j\begin{aligned} Re(z) &= \frac {z + z^*}{2} \\ Im(z) &= \frac {z - z^*}{2j} \end{aligned}

Wirtinger微积分建议研究f(z,z)f(z, z^*),如果ff是实可微的,则其保证是全纯的(另一种思考方式是将其视为坐标系的变化,从f(x,y)f(x, y)f(z,z)f(z, z^*))。该函数具有偏导数z\frac{\partial }{\partial z}z\frac{\partial}{\partial z^{*}}。 我们可以使用链式法则来建立这些偏导数与zz的实部和虚部的偏导数之间的关系。

x=zxz+zxz=z+zy=zyz+zyz=1j(zz)\begin{aligned} \frac{\partial }{\partial x} &= \frac{\partial z}{\partial x} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial x} * \frac{\partial }{\partial z^*} \\ &= \frac{\partial }{\partial z} + \frac{\partial }{\partial z^*} \\ \\ \frac{\partial }{\partial y} &= \frac{\partial z}{\partial y} * \frac{\partial }{\partial z} + \frac{\partial z^*}{\partial y} * \frac{\partial }{\partial z^*} \\ &= 1j * (\frac{\partial }{\partial z} - \frac{\partial }{\partial z^*}) \end{aligned}

从上述方程中,我们得到:

z=1/2(x1jy)z=1/2(x+1jy)\begin{aligned} \frac{\partial }{\partial z} &= 1/2 * (\frac{\partial }{\partial x} - 1j * \frac{\partial }{\partial y}) \\ \frac{\partial }{\partial z^*} &= 1/2 * (\frac{\partial }{\partial x} + 1j * \frac{\partial }{\partial y}) \end{aligned}

这是您会在维基百科上找到的经典的Wirtinger微积分定义。

这一变化带来了许多美妙的结果。

  • 首先,Cauchy-Riemann 方程简单来说就是 fz=0\frac{\partial f}{\partial z^*} = 0 (也就是说,函数 ff 可以完全用 zz 表示,而无需引用 zz^*)。

  • 我们稍后会看到,另一个重要(且有些违背直觉)的结果是,在对实值损失进行优化时,我们在进行变量更新时应该采取的步骤由Lossz\frac{\partial Loss}{\partial z^*}(而不是Lossz\frac{\partial Loss}{\partial z})给出。

要进一步阅读,请查看:https://arxiv.org/pdf/0906.4835.pdf

变分法在优化中有什么用?

音频和其他领域的研究人员更常用梯度下降来优化具有复数变量的真实值损失函数。通常,这些人将实部和虚部分别视为可以更新的不同通道。对于步长α/2\alpha/2和损失LL,我们可以写出以下在R2ℝ^2中的方程:

xn+1=xn(α/2)Lxyn+1=yn(α/2)Ly\begin{aligned} x_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} \\ y_{n+1} &= y_n - (\alpha/2) * \frac{\partial L}{\partial y} \end{aligned}

这些方程如何转化为复数空间 C

zn+1=xn(α/2)Lx+1j(yn(α/2)Ly)=znα1/2(Lx+jLy)=znαLz\begin{aligned} z_{n+1} &= x_n - (\alpha/2) * \frac{\partial L}{\partial x} + 1j * (y_n - (\alpha/2) * \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * 1/2 * (\frac{\partial L}{\partial x} + j \frac{\partial L}{\partial y}) \\ &= z_n - \alpha * \frac{\partial L}{\partial z^*} \end{aligned}

发生了一件非常有趣的事情:Wirtinger演算告诉我们,我们可以将上面的复变量更新公式简化为仅引用共轭Wirtinger导数 Lz\frac{\partial L}{\partial z^*},这正好是我们优化过程中所采取的步骤。

由于共轭 Wirtinger 导数正好为我们提供了实值损失函数的正确步长,因此 PyTorch 在对具有实值损失的函数求导时会给出这个导数。

PyTorch 如何计算共轭Wirtinger导数?

通常,我们的导数公式以grad_output作为输入, 代表我们已经计算出的传入向量-雅可比矩阵积,即Ls\frac{\partial L}{\partial s^*}, 其中LL是整个计算的损失(产生实际损失), 而ss是我们函数的输出。目标是计算Lz\frac{\partial L}{\partial z^*}, 其中zz是函数的输入。事实证明,在实际损失的情况下, 我们只需要计算Lz\frac{\partial L}{\partial z^*}, 尽管链式法则暗示我们也需要访问Lz\frac{\partial L}{\partial z^*}。如果你想要跳过这个推导, 请查看本节的最后一方程,然后跳到下一节。

让我们继续使用f:CCf: ℂ → ℂ,其被定义为 f(z)=f(x+yj)=u(x,y)+v(x,y)jf(z) = f(x+yj) = u(x, y) + v(x, y)j。正如上面讨论的那样, autograd 的梯度约定围绕实值损失函数的优化展开,因此我们假设ff 是更大的 实值损失函数 gg 的一部分。利用链式法则,我们可以写出:

(1)Lz=Luuz+Lvvz\frac{\partial L}{\partial z^*} = \frac{\partial L}{\partial u} * \frac{\partial u}{\partial z^*} + \frac{\partial L}{\partial v} * \frac{\partial v}{\partial z^*}

现在使用Wirtinger导数定义,我们可以写:

Ls=1/2(LuLvj)Ls=1/2(Lu+Lvj)\begin{aligned} \frac{\partial L}{\partial s} = 1/2 * (\frac{\partial L}{\partial u} - \frac{\partial L}{\partial v} j) \\ \frac{\partial L}{\partial s^*} = 1/2 * (\frac{\partial L}{\partial u} + \frac{\partial L}{\partial v} j) \end{aligned}

需要注意的是,由于uuvv是实函数,并且根据我们的假设ff是实值函数的一部分,LL也是实数,因此我们有:

(2)(Ls)=Ls(\frac{\partial L}{\partial s})^* = \frac{\partial L}{\partial s^*}

即,Ls\frac{\partial L}{\partial s}等于grad_outputgrad\_output^*

解上述方程可得Lu\frac{\partial L}{\partial u}Lv\frac{\partial L}{\partial v}

(3)Lu=Ls+LsLv=1j(LsLs)\begin{aligned} \frac{\partial L}{\partial u} = \frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*} \\ \frac{\partial L}{\partial v} = -1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) \end{aligned}

(3) 替换 (1),我们得到:

Lz=(Ls+Ls)uz1j(LsLs)vz=Ls(uz+vzj)+Ls(uzvzj)=Ls(u+vj)z+Ls(u+vj)z=Lssz+Lssz\begin{aligned} \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s} + \frac{\partial L}{\partial s^*}) * \frac{\partial u}{\partial z^*} - 1j * (\frac{\partial L}{\partial s} - \frac{\partial L}{\partial s^*}) * \frac{\partial v}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * (\frac{\partial u}{\partial z^*} + \frac{\partial v}{\partial z^*} j) + \frac{\partial L}{\partial s^*} * (\frac{\partial u}{\partial z^*} - \frac{\partial v}{\partial z^*} j) \\ &= \frac{\partial L}{\partial s^*} * \frac{\partial (u + vj)}{\partial z^*} + \frac{\partial L}{\partial s} * \frac{\partial (u + vj)^*}{\partial z^*} \\ &= \frac{\partial L}{\partial s} * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * \frac{\partial s^*}{\partial z^*} \\ \end{aligned}

使用 (2),我们得到:

(4)Lz=(Ls)sz+Ls(sz)=(grad_output)sz+grad_output(sz)\begin{aligned} \frac{\partial L}{\partial z^*} &= (\frac{\partial L}{\partial s^*})^* * \frac{\partial s}{\partial z^*} + \frac{\partial L}{\partial s^*} * (\frac{\partial s}{\partial z})^* \\ &= \boxed{ (grad\_output)^* * \frac{\partial s}{\partial z^*} + grad\_output * {(\frac{\partial s}{\partial z})}^* } \\ \end{aligned}

最后一个方程是你编写自己的梯度时的重要公式, 因为它将我们的导数公式分解为一个更简单的形式, 便于手动计算。

如何为一个复杂函数编写自己的导数公式?

上述方程给出了所有复函数导数的一般公式。然而,我们仍然需要计算sz\frac{\partial s}{\partial z}sz\frac{\partial s}{\partial z^*}。 你可以通过两种方式来完成这个计算:

  • The first way is to just use the definition of Wirtinger derivatives directly and calculate sz\frac{\partial s}{\partial z} and sz\frac{\partial s}{\partial z^*} by using sx\frac{\partial s}{\partial x} and sy\frac{\partial s}{\partial y} (which you can compute in the normal way).

  • The second way is to use the change of variables trick and rewrite f(z)f(z) as a two variable function f(z,z)f(z, z^*), and compute the conjugate Wirtinger derivatives by treating zz and zz^* as independent variables. This is often easier; for example, if the function in question is holomorphic, only zz will be used (and sz\frac{\partial s}{\partial z^*} will be zero).

让我们以函数f(z=x+yj)=cz=c(x+yj)f(z = x + yj) = c * z = c * (x+yj)为例,其中cRc \in ℝ

使用第一种方法计算Wirtinger导数,我们得到。

sz=1/2(sxsyj)=1/2(c(c1j)1j)=csz=1/2(sx+syj)=1/2(c+(c1j)1j)=0\begin{aligned} \frac{\partial s}{\partial z} &= 1/2 * (\frac{\partial s}{\partial x} - \frac{\partial s}{\partial y} j) \\ &= 1/2 * (c - (c * 1j) * 1j) \\ &= c \\ \\ \\ \frac{\partial s}{\partial z^*} &= 1/2 * (\frac{\partial s}{\partial x} + \frac{\partial s}{\partial y} j) \\ &= 1/2 * (c + (c * 1j) * 1j) \\ &= 0 \\ \end{aligned}

使用(4),和grad_output = 1.0 (这是在PyTorch中调用backward()时用于标量输出的默认梯度输出值),我们得到:

Lz=10+1c=c\frac{\partial L}{\partial z^*} = 1 * 0 + 1 * c = c

使用第二种方法计算Wirtinger导数,我们直接得到:

sz=(cz)z=csz=(cz)z=0\begin{aligned} \frac{\partial s}{\partial z} &= \frac{\partial (c*z)}{\partial z} \\ &= c \\ \frac{\partial s}{\partial z^*} &= \frac{\partial (c*z)}{\partial z^*} \\ &= 0 \end{aligned}

再次使用 (4),我们得到 Lz=c\frac{\partial L}{\partial z^*} = c。正如你所见,第二种方法涉及较少的计算,并且更适合快速计算。

关于跨域函数呢?

有些函数是从复数输入映射到实数输出,或者反之。这些函数是(4)的一个特例,我们可以使用链式法则推导出来:

  • For f:CRf: ℂ → ℝ, we get:

    Lz=2grad_outputsz\frac{\partial L}{\partial z^*} = 2 * grad\_output * \frac{\partial s}{\partial z^{*}}
  • For f:RCf: ℝ → ℂ, we get:

    Lz=2Re(grad_outsz)\frac{\partial L}{\partial z^*} = 2 * Re(grad\_out^* * \frac{\partial s}{\partial z^{*}})

保存的张量钩子

您可以控制保存的张量如何打包/解包,通过定义一对pack_hook / unpack_hook 钩子。 pack_hook 函数应将张量作为其唯一参数,但可以返回任何 Python 对象(例如另一个张量、元组,甚至是包含文件名的字符串)。 unpack_hook 函数将其唯一参数作为 pack_hook 的输出,并应返回一个在反向传播中使用的张量。由 unpack_hook 返回的张量只需要与传递给 pack_hook 的输入张量具有相同的内容。特别是,任何与 autograd 相关的元数据都可以忽略,因为它们将在解包过程中被覆盖。

这样的组对示例为:

class SelfDeletingTempFile():
    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        os.remove(self.name)

def pack_hook(tensor):
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(temp_file):
    return torch.load(temp_file.name)

注意,unpack_hook 不应删除临时文件,因为它 可能会被多次调用:临时文件应该在返回的SelfDeletingTempFile对象存活期间保持有效。 在上述示例中, 我们通过在不再需要时关闭它来防止泄露临时文件(在删除SelfDeletingTempFile对象时)。

注意

我们保证pack_hook只会被调用一次,但unpack_hook可以根据反向传播的需求多次调用,并且我们期望它每次返回相同的数据。

警告

禁止对任何函数的输入执行就地操作,因为这可能会导致意想不到的副作用。如果对包钩的输入进行了就地修改,PyTorch 会抛出错误,但不会捕获对解包钩输入进行就地修改的情况。

为保存的张量注册钩子

你可以通过调用一个 register_hooks() 方法在一个 SavedTensor 对象上注册一对钩子。这些对象作为 grad_fn 的属性暴露出来,并以 _raw_saved_ 前缀开头。

x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)

pack_hook方法在配对注册后立即调用。 这unpack_hook方法每次需要访问保存的张量时都会被调用,无论是通过y.grad_fn._saved_self还是在反向传播过程中。

警告

如果你在保存的张量被释放(即反向传播被调用后)仍然保留对SavedTensor的引用,那么调用其register_hooks()是不允许的。 PyTorch 通常会抛出一个错误,但在某些情况下可能会失败,并且可能会出现未定义的行为。

为保存的张量注册默认钩子

或者,你可以使用上下文管理器 saved_tensors_hooks 来注册一对钩子, 这些钩子将应用于在该上下文中创建的所有保存的张量。

Example:

# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    torch.save(tensor, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    if isinstance(tensor_or_sctf, torch.Tensor):
        return tensor_or_sctf
    return torch.load(tensor_or_sctf.name)

class Model(nn.Module):
    def forward(self, x):
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
          # ... compute output
          output = x
        return output

model = Model()
net = nn.DataParallel(model)

在此上下文中定义的钩子是线程局部的。 因此,以下代码不会产生预期的效果,因为钩子不会经过DataParallel

# Example what NOT to do

net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    output = net(input)

请注意,使用这些钩子会禁用所有减少 Tensor 对象创建的优化。例如:

with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
    x = torch.randn(5, requires_grad=True)
    y = x * x

没有钩子时,xy.grad_fn._saved_selfy.grad_fn._saved_other 都是指向同一个张量对象。 有了钩子后,PyTorch 会将 x 打包和解包为两个新的张量对象, 这些对象与原始的 x 共享相同的存储(未进行复制)。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源