注意
点击 这里 下载完整示例代码
雅可比矩阵、海森矩阵、hvp、vhp 及更多:组合函数变换¶
创建日期: 2023年3月15日 | 最后更新日期: 2023年4月18日 | 最后验证日期: 2024年11月5日
计算雅可比或黑塞矩阵在许多非传统的深度学习模型中非常有用。使用PyTorch的常规自动微分API(Tensor.backward(), torch.autograd.grad)高效地计算这些量是困难的(或令人烦恼的)。PyTorch的受JAX启发的函数变换API提供了高效计算各种高阶自动微分量的方法。
注意
本教程要求使用 PyTorch 2.0.0 或更高版本。
计算雅可比矩阵¶
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
让我们从一个我们希望计算雅可比矩阵的函数开始。 这是一个简单的线性函数,带有非线性激活。
让我们添加一些假数据:一个权重、一个偏置和一个特征向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
让我们将 predict 视作一个函数,它将输入 x 映射到 \(R^D \to R^D\)。
PyTorch Autograd 计算向量-雅可比乘积。为了计算这个 \(R^D \to R^D\) 函数的完整雅可比矩阵,我们每次都需要使用不同的单位向量逐行进行计算。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
Instead of 计算雅可比矩阵的每一行,我们可以使用 PyTorch 的
torch.vmap 函数转换来消除 for 循环并矢量化计算。我们不能直接将 vmap 应用于 torch.autograd.grad;
相反,PyTorch 提供了一个 torch.func.vjp 转换,可以与
torch.vmap 组合使用:
from torch.func import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在后续的教程中,反向模式AD与vmap的组合将为我们提供
单个样本梯度。
在这个教程中,反向模式AD与vmap的组合为我们提供了雅可比计算能力!
各种vmap和自动微分变换的组合可以为我们提供不同的有趣量。
PyTorch 提供了 torch.func.jacrev 作为方便函数,执行 vmap-vjp 组合以计算雅可比矩阵。jacrev 接受一个 argnums
参数,指定我们希望根据哪个参数计算雅可比矩阵。
from torch.func import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
让我们比较一下两种计算雅可比矩阵的方法的性能。 函数转换版本要快得多(并且随着输出数量的增加会更快)。
一般而言,我们期望通过 vmap 向量化可以有助于消除开销并更好地利用您的硬件。
vmap 通过将外部循环推入函数的基本操作来实现这一奇迹,从而获得更好的性能。
让我们快速编写一个函数来评估性能并处理微秒和毫秒的测量值:
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然后运行性能对比:
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c42779990>
compute_jac(xp)
2.77 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c30629660>
jacrev(predict, argnums=2)(weight, bias, x)
717.70 us
1 measurement, 500 runs , 1 thread
让我们对上述内容进行相对性能比较,使用我们的 get_perf 函数:
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 74.0459 percent improvement with vmap
此外,也很容易将问题反过来表述,即我们想要计算模型参数(权重、偏置)的雅可比矩阵,而不是输入的雅可比矩阵。
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩阵(jacrev) vs 正向模式雅可比矩阵(jacfwd)¶
我们提供两种计算雅可比矩阵的API:jacrev 和 jacfwd:
jacrev使用了反向模式自动微分。如您所见,它是vjp和vmap变换的组合。jacfwd使用了前向模式自动微分。它是由我们的jvp和vmap转换组成的。
jacfwd 和 jacrev 可以互相替换,但它们有不同的性能特性。
As a general rule of thumb, if you’re computing the jacobian of an \(R^N \to R^M\)
函数,且输出远多于输入(例如,\(M > N\)),则应优先使用 jacfwd,否则使用 jacrev。但此规则有例外,不过非严格的论据如下:
在反向模式自动微分中,我们逐行计算雅可比矩阵,在前向模式自动微分(用于计算雅可比-向量乘积)中,则是逐列计算。雅可比矩阵有M行和N列,因此如果矩阵更长或更宽,我们可能更倾向于处理较少行或列的方法。
首先,让我们用更多的输入比输出进行基准测试:
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c7c030970>
jacfwd(predict, argnums=2)(weight, bias, x)
1.31 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c3062a140>
jacrev(predict, argnums=2)(weight, bias, x)
9.64 ms
1 measurement, 500 runs , 1 thread
然后进行一个相对基准测试:
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 633.6084 percent improvement with jacrev
现在是相反的情况 - 输出(M)多于输入(N):
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c2bb1bfd0>
jacfwd(predict, argnums=2)(weight, bias, x)
6.43 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c30334070>
jacrev(predict, argnums=2)(weight, bias, x)
823.32 us
1 measurement, 500 runs , 1 thread
并且进行相对性能比较:
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 680.5701 percent improvement with jacfwd
functorch.hessian进行海森矩阵计算¶
我们提供了一种便捷的API来计算海森矩阵:torch.func.hessiani。
海森矩阵是雅可比的雅可比(或二阶偏导数的偏导数,即二次导数)。
这表明可以仅仅通过组合functorch雅可比变换来计算海森矩阵。
确实,在底层,hessian(f) 简单来说就是 jacfwd(jacrev(f))。
注意:为了提升性能:根据上述关于宽矩阵与高矩阵的经验法则,您的模型也可能需要使用 jacfwd(jacfwd(f)) 或 jacrev(jacrev(f)) 来计算海森矩阵。
from torch.func import hessian
# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
让我们验证使用海森矩阵API和
使用 jacfwd(jacfwd()) 是否得到相同的结果。
True
批量雅可比矩阵和批量海森矩阵¶
在上述示例中,我们一直在处理单一特征向量。
在某些情况下,您可能需要计算一批输出相对于一批输入的雅可比矩阵。也就是说,给定一批形状为(B, N)的输入和一个从\(R^N \to R^M\)到(B, M, N)的函数,我们希望得到一个形状为(B, M, N)的雅可比矩阵。
这是最简单的方法,使用 vmap:
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
如果你有一个从 (B, N) -> (B, M) 的函数,并且确定每个输入都会产生独立的输出,那么有时也可以通过求和输出然后计算该函数的雅可比矩阵来实现这一点,而不需要使用 vmap:
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果你有一个从 \(R^N \to R^M\) 输出但输入是批量的数据的功能函数,你可以将 vmap 与 jacrev 组合起来计算批量雅可比矩阵:
最后,批量海森矩阵也可以类似地计算。最容易想到的是通过使用 vmap 对海森矩阵计算进行批量处理,但在某些情况下,求和技巧也能奏效。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
计算海森矩阵向量积¶
计算海森矩阵-向量积(Hessian-vector product, HVP)的简单方法是先构建完整的海森矩阵,然后与一个向量进行点积运算。我们有更好的办法:实际上,我们不需要构建完整的海森矩阵来完成这一操作。我们将介绍两种(众多方法中的两种)不同的策略来计算海森矩阵-向量积: - 使用反向模式自动微分与反向模式自动微分相结合 - 使用反向模式自动微分与正向模式自动微分相结合
将反向模式自动微分与正向模式自动微分相结合(而不是反向模式与反向模式相结合)通常是一种更节省内存的方式来计算Hessian-向量积,因为正向模式自动微分不需要为反向构建Autograd图并保存中间结果。
以下是一些示例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 的前向自动微分不支持您的操作,那么我们可以将反向模式自动微分与反向模式自动微分相结合:
脚本总运行时间: ( 0 分钟 12.094 秒)