Gradcheck 机制¶
它将涵盖实数和复值函数以及高阶导数的正向和反向模式 AD。
此说明还涵盖了 gradcheck 的默认行为以及传递参数的情况(下面称为快速 gradcheck)。fast_mode=True
符号和背景信息¶
在本说明中,我们将使用以下约定:
,,,,,,和是实值向量,是一个复值向量,可以根据两个实值向量重写为.
和是两个整数,我们将分别用于输入和输出空间的维度。
是我们的基本实数到实数函数,使得.
是我们的基本复数到实数函数,使得.
对于简单的实数到实数情况,我们写为与大小. 此矩阵包含所有偏导数,使得 position 处的入场包含. 然后,向后模式 AD 正在计算给定向量大小、数量. 另一方面,正向模式 AD 是给定向量的计算大小、数量.
对于包含复杂值的函数,情况要复杂得多。我们在此处仅提供要点,完整的描述可以在 Autograd for Complex Numbers 中找到。
满足复微分性(Cauchy-Riemann 方程)的约束对于所有实值损失函数来说都太严格了,因此我们选择使用 Wirtinger 微积分。 在 Wirtinger 演算的基本设置中,链式规则需要访问 Wirtinger 导数(称为)和共轭 Wirtinger 导数(称为)。 双和需要传播,因为一般来说,尽管它们的名字,一个不是另一个的复杂共轭。
为了避免必须传播两个值,对于向后模式 AD,我们始终假设正在计算其导数的函数是实值函数或更大的实值函数的一部分。这个假设意味着我们在向后传递期间计算的所有中间梯度也与实值函数相关联。 在实践中,在进行优化时,此假设不受限制,因为此类问题需要实值目标(因为复数没有自然排序)。
在此假设下,使用和定义,我们可以证明(我们使用在这里表示复共轭),因此两个值中只有一个实际上需要 “backwarded through the graph” ,因为另一个值可以很容易地恢复。 为了简化内部计算,PyTorch 使用作为值,它会向后移动并在用户请求渐变时返回。 与实际情况类似,当 output 实际在,则向后模式 AD 不计算但只有对于给定的向量.
对于正向模式 AD,我们使用类似的逻辑,在本例中,假设该函数是输入位于.在这个假设下,我们可以做出类似的声明,即每个中介结果都对应于一个函数,其输入位于在这种情况下,使用和定义,我们可以证明用于中介功能。 为了确保 forward 和 backward 模式在一维函数的基本情况下计算相同的量,forward 模式还计算. 与实际情况类似,当 input 实际处于,则转发模式 AD 不计算但只有对于给定的向量.
默认向后模式 gradcheck 行为¶
实数到实数函数¶
测试函数,我们重建完整的雅可比矩阵大小以两种方式:分析和数值。 分析版本使用我们的后向模式 AD,而数值版本使用有限差分。 然后对两个重建的雅可比矩阵进行元素比较是否相等。
默认实数输入数值计算¶
如果我们考虑一维函数的基本情况 (),那么我们可以使用 Wikipedia 文章中的基本有限差分公式。我们使用 “central difference” 来获得更好的数值性质:
此公式很容易泛化为多个输出 () 替换为是大小为喜欢. 在这种情况下,上述公式可以按原样重复使用,并且只需对 user 函数进行两次计算即可近似完整的雅可比矩阵(即和).
处理具有多个输入 ().在这个场景中,我们一个接一个地遍历所有输入,并应用perturbation 的每个单元的一个接一个。这允许我们重建逐列矩阵。
默认实数输入分析评估¶
对于分析评估,我们使用如上所述的事实,即 backward mode AD 计算. 对于具有单个输出的函数,我们只需使用通过一次向后传递恢复完整的雅可比矩阵。
对于具有多个输出的函数,我们采用 for 循环,该循环迭代每个输出是一个 one-hot 向量,对应于一个接一个的输出。这允许重建逐行矩阵。
复数到实数函数¶
测试函数跟,我们重构包含.
默认复数输入数值计算¶
考虑基本情况,其中第一。我们从这篇研究论文的(第 3 章)中了解到:
请注意,和,在上面的方程中,是衍生物。 为了对这些进行数值评估,我们使用上述方法进行实数到实数的情况。 这允许我们计算矩阵,然后将其乘以.
请注意,截至撰写本文时,代码以略微复杂的方式计算此值:
# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above
ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.
默认复杂输入分析评估¶
由于向后模式 AD 的计算结果正好是导数,我们只需使用与此处的实数到实数情况相同的技巧,并在有多个实数输出时逐行重建矩阵。
具有复杂输出的函数¶
在这种情况下,用户提供的函数不遵循 autograd 的假设,即我们为其反向计算 AD 的函数是实值函数。 这意味着直接在这个函数上使用 autograd 没有明确定义。 为了解决这个问题,我们将替换函数的 test(其中可以是或),具有两个功能:和这样:
哪里. 然后,我们对两者进行基本的 gradcheck和使用上述 real-to-real 或 complex-to-real 情况,具体取决于.
请注意,在撰写本文时,代码并未显式创建这些函数,而是使用或函数,方法是将参数分配给不同的函数。 什么时候,那么我们正在考虑. 什么时候,那么我们正在考虑.
快速向后模式 gradcheck¶
虽然上述 gradcheck 的公式很棒,但为了确保正确性和可调试性,它非常慢,因为它重建了完整的雅可比矩阵。 本节介绍了一种在不影响其正确性的情况下以更快的方式执行 gradcheck 的方法。 当我们检测到错误时,可以通过添加特殊 logic 来恢复可调试性。在这种情况下,我们可以运行默认版本来重建完整矩阵,以向用户提供完整的详细信息。
这里的高级策略是找到一个标量,该标量可以通过数值和分析方法有效计算,并且足够好地表示慢速 gradcheck 计算的完整矩阵,以确保它能够捕获雅可比矩阵中的任何差异。
用于实数到实数函数的快速 gradcheck¶
我们在这里要计算的标量是对于给定的随机向量和一个随机单位范数向量.
对于数值计算,我们可以有效地计算
然后,我们执行此向量和获取感兴趣的标量值。
对于分析版本,我们可以使用向后模式 AD 来计算径直。然后,我们使用以获取预期值。
复杂到实际函数的快速 gradcheck¶
与实数到实数的情况类似,我们想要对完整矩阵进行约简。但是matrix 是复值,因此在这种情况下,我们将与复标量进行比较。
由于在数值情况下我们可以有效计算的内容受到一些限制,并且为了将数值计算的数量保持在最低限度,我们计算以下(尽管令人惊讶的)标量值:
哪里,和.
快速复数输入数值评估¶
我们首先考虑如何计算使用数值方法。为此,请记住我们正在考虑跟,以及,我们将其重写如下:
在这个公式中,我们可以看到和可以采用与 real-to-real 情况的快速版本相同的方式进行评估。 一旦计算了这些实值量,我们就可以在右侧重建复向量,并使用实值向量。
快速复杂输入分析评估¶
对于分析情况,事情更简单,我们将公式重写为:
因此,我们可以利用 backward mode AD 为我们提供了一种有效的计算方法然后执行实部的点积虚部与在重建最终的复数标量之前.
为什么不使用复杂¶
此时,您可能想知道为什么我们没有选择综合体然后就进行了缩减. 为了深入研究这一点,在本段中,我们将使用著名的. 使用这种复杂的,问题是在进行数值计算时,我们需要计算:
这将需要对实数到实数的有限差值进行四次评估(与上面提出的方法相比,这是两倍)。 由于这种方法没有更多的自由度(相同数量的实值变量),并且我们试图在这里获得尽可能快的评估,因此我们使用上面的另一个公式。
对具有复杂输出的函数进行快速 gradcheck¶
就像在慢速情况下一样,我们考虑两个实值函数,并为每个函数使用上面的适当规则。
Gradgradcheck 实现¶
PyTorch 还提供了一个实用程序来验证二阶梯度。这里的目标是确保 backward implementation 也是正确的可微分的,并计算出正确的东西。
此功能是通过考虑函数来实现的并在此函数上使用上面定义的 gradcheck。 请注意,在本例中,只是一个类型与.
gradgradcheck 的快速版本是通过在同一函数上使用 gradcheck 的快速版本来实现的.