目录

Gradcheck 机制

此说明概述了 和 函数的工作原理gradgradcheck()

它将涵盖实数和复值函数以及高阶导数的正向和反向模式 AD。 此说明还涵盖了 gradcheck 的默认行为以及传递参数的情况(下面称为快速 gradcheck)。fast_mode=True

符号和背景信息

在本说明中,我们将使用以下约定:

  1. xx,yy,一个a,bb,vv,uu,ururuui是实值向量,zz是一个复值向量,可以根据两个实值向量重写为z=一个+bz = a + i b.

  2. NNMM是两个整数,我们将分别用于输入和输出空间的维度。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M是我们的基本实数到实数函数,使得y=f(x)y = f(x).

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M是我们的基本复数到实数函数,使得y=g(z)y = g(z).

对于简单的实数到实数情况,我们写为JfJ_fff大小M×NM \times N. 此矩阵包含所有偏导数,使得 position 处的入场(,j)(i, j)包含yxj\frac{\partial y_i}{\partial x_j}. 然后,向后模式 AD 正在计算给定向量vv大小MM、数量vTJfv^T J_f. 另一方面,正向模式 AD 是给定向量的计算uu大小NN、数量JfuJ_f u.

对于包含复杂值的函数,情况要复杂得多。我们在此处仅提供要点,完整的描述可以在 Autograd for Complex Numbers 中找到。

满足复微分性(Cauchy-Riemann 方程)的约束对于所有实值损失函数来说都太严格了,因此我们选择使用 Wirtinger 微积分。 在 Wirtinger 演算的基本设置中,链式规则需要访问 Wirtinger 导数(称为WW)和共轭 Wirtinger 导数(称为CWCW)。 双WWCWCW需要传播,因为一般来说,尽管它们的名字,一个不是另一个的复杂共轭。

为了避免必须传播两个值,对于向后模式 AD,我们始终假设正在计算其导数的函数是实值函数或更大的实值函数的一部分。这个假设意味着我们在向后传递期间计算的所有中间梯度也与实值函数相关联。 在实践中,在进行优化时,此假设不受限制,因为此类问题需要实值目标(因为复数没有自然排序)。

在此假设下,使用WWCWCW定义,我们可以证明W=CWW = CW^*(我们使用*在这里表示复共轭),因此两个值中只有一个实际上需要 “backwarded through the graph” ,因为另一个值可以很容易地恢复。 为了简化内部计算,PyTorch 使用2CW2 * CW作为值,它会向后移动并在用户请求渐变时返回。 与实际情况类似,当 output 实际在RM\mathcal{R}^M,则向后模式 AD 不计算2CW2 * CW但只有vT(2CW)v^T (2 * CW)对于给定的向量vRMv \in \mathcal{R}^M.

对于正向模式 AD,我们使用类似的逻辑,在本例中,假设该函数是输入位于R\mathcal{R}.在这个假设下,我们可以做出类似的声明,即每个中介结果都对应于一个函数,其输入位于R\mathcal{R}在这种情况下,使用WWCWCW定义,我们可以证明W=CWW = CW用于中介功能。 为了确保 forward 和 backward 模式在一维函数的基本情况下计算相同的量,forward 模式还计算2CW2 * CW. 与实际情况类似,当 input 实际处于RN\mathcal{R}^N,则转发模式 AD 不计算2CW2 * CW但只有(2CW)u(2 * CW) u对于给定的向量uRNu \in \mathcal{R}^N.

默认向后模式 gradcheck 行为

实数到实数函数

测试函数f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我们重建完整的雅可比矩阵JfJ_f大小M×NM \times N以两种方式:分析和数值。 分析版本使用我们的后向模式 AD,而数值版本使用有限差分。 然后对两个重建的雅可比矩阵进行元素比较是否相等。

默认实数输入数值计算

如果我们考虑一维函数的基本情况 (N=M=1N = M = 1),那么我们可以使用 Wikipedia 文章中的基本有限差分公式。我们使用 “central difference” 来获得更好的数值性质:

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

此公式很容易泛化为多个输出 (M>1M \gt 1) 替换为yx\frac{\partial y}{\partial x}是大小为M×1M \times 1喜欢f(x+eps)f(x + eps). 在这种情况下,上述公式可以按原样重复使用,并且只需对 user 函数进行两次计算即可近似完整的雅可比矩阵(即f(x+eps)f(x + eps)f(xeps)f(x - eps)).

处理具有多个输入 (N>1N \gt 1).在这个场景中,我们一个接一个地遍历所有输入,并应用epsepsperturbation 的每个单元的xx一个接一个。这允许我们重建JfJ_f逐列矩阵。

默认实数输入分析评估

对于分析评估,我们使用如上所述的事实,即 backward mode AD 计算vTJfv^T J_f. 对于具有单个输出的函数,我们只需使用v=1v = 1通过一次向后传递恢复完整的雅可比矩阵。

对于具有多个输出的函数,我们采用 for 循环,该循环迭代每个输出vv是一个 one-hot 向量,对应于一个接一个的输出。这允许重建JfJ_f逐行矩阵。

复数到实数函数

测试函数g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to yz=一个+bz = a + i b,我们重构包含2CW2 * CW.

默认复数输入数值计算

考虑基本情况,其中N=M=1N = M = 1第一。我们从这篇研究论文的(第 3 章)中了解到:

CW:=yz=12(y一个+yb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

请注意,y一个\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b},在上面的方程中,是RR\mathcal{R} \to \mathcal{R}衍生物。 为了对这些进行数值评估,我们使用上述方法进行实数到实数的情况。 这允许我们计算CWCW矩阵,然后将其乘以22.

请注意,截至撰写本文时,代码以略微复杂的方式计算此值:

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

默认复杂输入分析评估

由于向后模式 AD 的计算结果正好是CWCW导数,我们只需使用与此处的实数到实数情况相同的技巧,并在有多个实数输出时逐行重建矩阵。

具有复杂输出的函数

在这种情况下,用户提供的函数不遵循 autograd 的假设,即我们为其反向计算 AD 的函数是实值函数。 这意味着直接在这个函数上使用 autograd 没有明确定义。 为了解决这个问题,我们将替换函数的 testh:PNCMh: \mathcal{P}^N \to \mathcal{C}^M(其中P\mathcal{P}可以是R\mathcal{R}C\mathcal{C}),具有两个功能:hrhrhhi这样:

hr(q):=re一个l(f(q))h(q):=m一个g(f(q))\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}

哪里qPq \in \mathcal{P}. 然后,我们对两者进行基本的 gradcheckhrhrhhi使用上述 real-to-real 或 complex-to-real 情况,具体取决于P\mathcal{P}.

请注意,在撰写本文时,代码并未显式创建这些函数,而是使用re一个lrealm一个gimag函数,方法是将grad_out\text{grad\_out}参数分配给不同的函数。 什么时候grad_out=1\text{grad\_out} = 1,那么我们正在考虑hrhr. 什么时候grad_out=1j\text{grad\_out} = 1j,那么我们正在考虑hhi.

快速向后模式 gradcheck

虽然上述 gradcheck 的公式很棒,但为了确保正确性和可调试性,它非常慢,因为它重建了完整的雅可比矩阵。 本节介绍了一种在不影响其正确性的情况下以更快的方式执行 gradcheck 的方法。 当我们检测到错误时,可以通过添加特殊 logic 来恢复可调试性。在这种情况下,我们可以运行默认版本来重建完整矩阵,以向用户提供完整的详细信息。

这里的高级策略是找到一个标量,该标量可以通过数值和分析方法有效计算,并且足够好地表示慢速 gradcheck 计算的完整矩阵,以确保它能够捕获雅可比矩阵中的任何差异。

用于实数到实数函数的快速 gradcheck

我们在这里要计算的标量是vTJfuv^T J_f u对于给定的随机向量vRMv \in \mathcal{R}^M和一个随机单位范数向量uRNu \in \mathcal{R}^N.

对于数值计算,我们可以有效地计算

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然后,我们执行此向量和vv获取感兴趣的标量值。

对于分析版本,我们可以使用向后模式 AD 来计算vTJfv^T J_f径直。然后,我们使用uu以获取预期值。

复杂到实际函数的快速 gradcheck

与实数到实数的情况类似,我们想要对完整矩阵进行约简。但是2CW2 * CWmatrix 是复值,因此在这种情况下,我们将与复标量进行比较。

由于在数值情况下我们可以有效计算的内容受到一些限制,并且为了将数值计算的数量保持在最低限度,我们计算以下(尽管令人惊讶的)标量值:

s:=2vT(re一个l(CW)ur+m一个g(CW)u)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

哪里vRMv \in \mathcal{R}^M,urRNur \in \mathcal{R}^NuRNui \in \mathcal{R}^N.

快速复数输入数值评估

我们首先考虑如何计算ss使用数值方法。为此,请记住我们正在考虑g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to yz=一个+bz = a + i b,以及CW=12(y一个+yb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我们将其重写如下:

s=2vT(re一个l(CW)ur+m一个g(CW)u)=2vT(12y一个ur+12ybu)=vT(y一个ur+ybu)=vT((y一个ur)+(ybu))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在这个公式中,我们可以看到y一个ur\frac{\partial y}{\partial a} urybu\frac{\partial y}{\partial b} ui可以采用与 real-to-real 情况的快速版本相同的方式进行评估。 一旦计算了这些实值量,我们就可以在右侧重建复向量,并使用实值vv向量。

快速复杂输入分析评估

对于分析情况,事情更简单,我们将公式重写为:

s=2vT(re一个l(CW)ur+m一个g(CW)u)=vTre一个l(2CW)ur+vTm一个g(2CW)u)=re一个l(vT(2CW))ur+m一个g(vT(2CW))u\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我们可以利用 backward mode AD 为我们提供了一种有效的计算方法vT(2CW)v^T (2 * CW)然后执行实部的点积urur虚部与uui在重建最终的复数标量之前ss.

为什么不使用复杂uu

此时,您可能想知道为什么我们没有选择综合体uu然后就进行了缩减2vTCWu2 * v^T CW u'. 为了深入研究这一点,在本段中,我们将使用uu著名的u=ur+uu' = ur' + i ui'. 使用这种复杂的uu',问题是在进行数值计算时,我们需要计算:

2CWu=(y一个+yb)(ur+u)=y一个ur+y一个u+yburybu\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

这将需要对实数到实数的有限差值进行四次评估(与上面提出的方法相比,这是两倍)。 由于这种方法没有更多的自由度(相同数量的实值变量),并且我们试图在这里获得尽可能快的评估,因此我们使用上面的另一个公式。

对具有复杂输出的函数进行快速 gradcheck

就像在慢速情况下一样,我们考虑两个实值函数,并为每个函数使用上面的适当规则。

Gradgradcheck 实现

PyTorch 还提供了一个实用程序来验证二阶梯度。这里的目标是确保 backward implementation 也是正确的可微分的,并计算出正确的东西。

此功能是通过考虑函数来实现的F:x,vvTJfF: x, v \to v^T J_f并在此函数上使用上面定义的 gradcheck。 请注意,vv在本例中,只是一个类型与f(x)f(x).

gradgradcheck 的快速版本是通过在同一函数上使用 gradcheck 的快速版本来实现的FF.

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源