目录

Gradcheck 机制

本文介绍了 gradcheck()gradgradcheck() 函数的工作原理概述。

它将涵盖实值和复值函数的前向和反向模式自动微分,以及高阶导数。 本说明还涵盖了gradcheck的默认行为以及传递fast_mode=True个参数的情况(以下称为快速gradcheck)。

符号和背景信息

在整个说明中,我们将使用以下约定:

  1. xx, yy, aa, bb, vv, uu, ururuiui 是实值向量,而 zz 是一个复值向量,可以重写为两个实值向量的形式,即 z=a+ibz = a + i b

  2. NNMM 是我们将用于输入和输出空间维度的两个整数。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我们的基本实到实函数,使得 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我们的基本复数到实数函数,使得 y=g(z)y = g(z)

对于简单的实到实情况,我们将其雅可比矩阵表示为JfJ_f,与ff相关,其大小为M×NM \times N。 该矩阵包含所有的偏导数,使得在位置(i,j)(i, j)处的条目包含yixj\frac{\partial y_i}{\partial x_j}。 反向模式自动微分(AD)然后计算给定大小为MM的向量vv的量vTJfv^T J_f。 另一方面,前向模式自动微分(AD)计算给定大小为NN的向量uu的量JfuJ_f u

对于包含复数值的函数,情况要复杂得多。我们在这里只提供要点,完整的描述可以在 复数的自动求导 找到。

满足复数可微分性的约束(柯西-黎曼方程)对所有实值损失函数来说过于严格,因此我们选择使用威廷格微积分。 在威廷格微积分的基本设置中,链式法则需要访问威廷格导数(如下WW所示)和共轭威廷格导数(如下CWCW所示)。 无论是WW还是CWCW都需要进行传播,因为一般来说,尽管它们的名字如此,一个并不是另一个的复共轭。

为了避免在反向模式自动微分中传播两个值,我们总是假设正在计算导数的函数要么是一个实值函数,要么是更大实值函数的一部分。这一假设意味着我们在反向传播过程中计算的所有中间梯度也与实值函数相关联。 实际上,在进行优化时,这一假设并不具有限制性,因为这类问题需要实值目标(因为复数没有自然的排序)。

在该假设下,使用WWCWCW定义,我们可以证明W=CWW = CW^*(我们用*表示复共轭)只需要其中一个值进行“反向传播”,因为另一个值可以很容易地恢复。 为了简化内部计算,PyTorch在用户请求梯度时,使用2CW2 * CW作为其反向传播并返回的值。 类似于实数情况,当输出实际上是RM\mathcal{R}^M时,反向模式自动微分不会计算2CW2 * CW,而只会计算给定向量vRMv \in \mathcal{R}^MvT(2CW)v^T (2 * CW)

对于前向模式AD,我们使用类似的逻辑,在这种情况下,假设该函数是一个更大函数的一部分,其输入在R\mathcal{R}中。在此假设下,我们可以得出类似的说法,即每个中间结果对应于一个输入在R\mathcal{R}中的函数,并且在这种情况下,使用WWCWCW的定义,我们可以证明对于中间函数W=CWW = CW。 为了确保在单变量函数的基本情况下,前向模式和后向模式计算相同的量,前向模式还计算2CW2 * CW。 类似于实际情况,当输入实际上在RN\mathcal{R}^N中时,前向模式AD不会计算2CW2 * CW,而只计算给定向量uRNu \in \mathcal{R}^N(2CW)u(2 * CW) u

默认反向模式的gradcheck行为

实到实函数

测试一个函数 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我们以两种方式重构完整的雅可比矩阵 JfJ_f,大小为 M×NM \times N:解析法和数值法。 解析版本使用我们的反向模式自动微分,而数值版本使用有限差分。 然后逐元素比较两个重构的雅可比矩阵是否相等。

默认实数输入数值评估

如果我们考虑一维函数的基本情况(N=M=1N = M = 1),那么我们可以使用来自维基百科文章的基本有限差分公式。为了获得更好的数值特性,我们使用“中心差分”:

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

此公式可以很容易地推广到多个输出(M>1M \gt 1)的情况,通过让yx\frac{\partial y}{\partial x}成为大小为M×1M \times 1的列向量,类似于f(x+eps)f(x + eps)。 在这种情况下,上述公式可以直接重用,并且只需两次用户函数评估(即f(x+eps)f(x + eps)f(xeps)f(x - eps))即可近似完整的雅可比矩阵。

处理多个输入的情况在计算上更昂贵(N>1N \gt 1)。在这种情况下,我们依次遍历所有的输入,并对每个元素依次应用epseps扰动。这使我们能够逐列重构JfJ_f矩阵。

默认实数输入分析评估

对于分析评估,我们使用上述描述的事实,即反向模式自动微分计算vTJfv^T J_f。 对于单输出函数,我们只需使用v=1v = 1在一次反向传递中恢复完整的雅可比矩阵。

对于具有多个输出的函数,我们使用for循环迭代每个输出,其中vv是一个与每个输出对应的one-hot向量,依次排列。这使得我们可以逐行重构JfJ_f矩阵。

复数到实数函数

测试函数 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to yz=a+ibz = a + i b,我们重构包含 2CW2 * CW 的(复数)矩阵。

默认复杂输入数值评估

考虑最基本的情况,先看N=M=1N = M = 1。我们从这篇研究论文(第三章)this research paper中得知:

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

注意,在上述方程中,ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b}RR\mathcal{R} \to \mathcal{R} 的导数。 为了数值计算这些值,我们使用了上面描述的实数到实数的情况的方法。 这使我们能够计算出 CWCW 矩阵,然后将其乘以 22

请注意,截至撰写本文时,该代码以一种稍微复杂的方式计算此值:

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

默认复杂输入分析评估

由于反向模式自动微分已经精确计算了CWCW阶导数的两倍,我们在这里简单地使用与实数到实数情况相同的技巧,并在存在多个实数输出时逐行重构矩阵。

具有复杂输出的函数

在这种情况下,用户提供的函数不符合自动求导的假设,即我们为其计算反向自动微分的函数是实值函数。 这意味着直接在这个函数上使用自动求导是没有明确定义的。 为了解决这个问题,我们将替换函数 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M 的测试(其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C}),用两个函数:hrhrhihi,使得:

hr(q):=real(f(q))hi(q):=imag(f(q))\begin{aligned} hr(q) &:= real(f(q)) \\ hi(q) &:= imag(f(q)) \end{aligned}

qPq \in \mathcal{P}处。 然后我们对hrhrhihi进行基本的梯度检查,具体使用实到实或复到实的情况,取决于P\mathcal{P}

注意,截至编写之时,该代码并未显式创建这些函数,而是通过手动传递realrealimagimag函数的grad_out\text{grad\_out}参数来执行链式法则。 当grad_out=1\text{grad\_out} = 1时,我们正在考虑hrhr。 当grad_out=1j\text{grad\_out} = 1j时,我们正在考虑hihi

快速反向模式梯度检查

虽然上述的 gradcheck 表述在确保正确性和可调试性方面都很出色,但它非常慢,因为它重建了完整的雅可比矩阵。 本节介绍了一种以更快的方式执行 gradcheck 的方法,而不影响其正确性。 通过在检测到错误时添加特殊逻辑,可以恢复可调试性。在这种情况下,我们可以运行默认版本来重建完整矩阵,以便为用户提供详细信息。

这里的高级策略是找到一个标量量值,该量值可以通过数值方法和解析方法高效计算,并且能够很好地表示由缓慢的 gradcheck 计算得到的整个矩阵,从而确保可以捕捉到雅可比矩阵中的任何差异。

快速梯度检查适用于实数到实数的函数

我们在这里要计算的标量量是 vTJfuv^T J_f u,给定一个随机向量 vRMv \in \mathcal{R}^M 和一个随机单位范数向量 uRNu \in \mathcal{R}^N

对于数值评估,我们可以高效地计算

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然后,我们计算该向量与vv的点积,以获得感兴趣的标量值。

对于分析版本,我们可以使用反向模式自动微分直接计算vTJfv^T J_f。然后,我们对其进行点积运算与uu,以获得期望值。

快速梯度检查适用于复数到实数的函数

类似于真实到真实的案例,我们希望对整个矩阵进行约简。但2CW2 * CW矩阵是复数形式的,所以在这种情况下,我们将与复数标量进行比较。

由于在数值计算情况下的一些限制,以及为了将数值评估的数量降到最低,我们计算了以下(尽管令人惊讶)的标量值:

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^MurRNur \in \mathcal{R}^NuiRNui \in \mathcal{R}^N

快速复杂输入数值评估

我们首先考虑如何用数值方法计算ss。为此,牢记我们在考虑g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to yz=a+ibz = a + i b,并且CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我们将其重写如下:

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在该公式中,我们可以看到yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui可以与实值到实值情况下的快速版本以相同的方式进行评估。 一旦计算出这些实数值量,我们就可以重构右侧的复向量,并与实值vv向量进行点积运算。

快速复杂输入分析评估

对于分析案例,情况要简单得多,我们将其公式重写为:

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我们可以利用反向模式自动微分为我们提供了一种高效的方式来计算vT(2CW)v^T (2 * CW),然后将实部与urur进行点积运算,将虚部与uiui进行点积运算,最后重构最终的复数标量ss

为什么不使用一个复杂的 uu

到此为止,你可能会想知道为什么我们没有选择一个复杂的uu,而只是进行了简化2vTCWu2 * v^T CW u'。 要深入了解这一点,在这一段中,我们将使用复杂的版本的uu,其注解为u=ur+iuiu' = ur' + i ui'。 使用这种复杂的uu',问题在于在进行数值评估时,我们需要计算:

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

这将需要四次实数到实数有限差分的计算(相比上述提出的方法多一倍)。 由于该方法没有更多的自由度(实数值变量的数量相同),并且我们试图在这里获得尽可能快的评估,因此我们使用了上面的另一种公式。

快速梯度检查适用于具有复数输出的函数

就像在慢速情况下一样,我们考虑两个实值函数,并为每个函数使用上述适当的规则。

gradgradcheck 实现

PyTorch 还提供了一个工具来验证二阶导数。这里的目的是确保反向实现也是可微分的,并正确计算结果。

此功能通过考虑函数 F:x,vvTJfF: x, v \to v^T J_f 并在此函数上使用上述定义的gradcheck来实现。 请注意,在这种情况下,vv 只是一个与f(x)f(x) 同类型的随机向量。

通过在相同函数上使用快速版本的gradcheck来实现gradgradcheck的快速版本 FF

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源