注意
单击此处下载完整的示例代码
雅可比矩阵、Hessian矩阵、hvp、vhp 等:组合函数变换¶
创建时间: 2023年3月15日 |上次更新时间:2023 年 4 月 18 日 |上次验证: Nov 05, 2024
计算雅可比矩阵或 Hessian 矩阵在许多非传统
深度学习模型。计算这些量很困难(或很烦人)
高效使用 PyTorch 的常规 autodiff API
(, ).PyTorch 的 JAX 启发函数 transforms API 提供了计算各种高阶 autodiff 量的方法
有效。Tensor.backward()
torch.autograd.grad
注意
本教程需要 PyTorch 2.0.0 或更高版本。
计算雅可比矩阵¶
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
让我们从一个我们想要计算雅可比矩阵的函数开始。 这是一个具有非线性激活的简单线性函数。
让我们添加一些虚拟数据:权重、偏差和特征向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
让我们把这个函数想象成一个将 \(R^D \) 到 R^D\) 的输入映射的函数。
PyTorch Autograd 计算向量雅可比积。为了计算完整的
这个 \(R^D \to R^D\) 函数的雅可比行式,我们必须逐行计算它
每次使用不同的单位向量。predict
x
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
我们可以使用 PyTorch 的函数 transform 来摆脱 for 循环并矢量化
计算。我们不能直接申请 ;
相反,PyTorch 提供了一个转换,该转换由以下各项组成:torch.vmap
vmap
torch.autograd.grad
torch.func.vjp
torch.vmap
from torch.func import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在后面的教程中,反向模式 AD 的组合 将给我们
per-sample-gradients 的
在本教程中,编写反向模式 AD 并得到雅可比矩阵
计算!
和 autodiff 变换的各种组合可以给我们带来不同的
有趣的数量。vmap
vmap
vmap
PyTorch 提供了一个方便的函数,该函数执行
用于计算雅可比矩阵的组合。 接受一个参数,该参数表示我们想用哪个参数来计算雅可比矩阵
尊重。torch.func.jacrev
vmap-vjp
jacrev
argnums
from torch.func import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
让我们比较一下两种计算雅可比行列式的方法的性能。 函数 transform 版本要快得多(并且 有更多的输出)。
一般来说,我们希望矢量化 via 可以帮助消除开销
并更好地利用您的硬件。vmap
vmap
通过将外部循环向下推入函数的
primitive 操作以获得更好的性能。
让我们快速创建一个函数来评估性能并处理 微秒和毫秒测量:
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然后运行性能比较:
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c42779990>
compute_jac(xp)
2.77 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9c30629660>
jacrev(predict, argnums=2)(weight, bias, x)
717.70 us
1 measurement, 500 runs , 1 thread
让我们用我们的函数对上述内容进行相对性能比较:get_perf
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 74.0459 percent improvement with vmap
此外,很容易将问题反转过来,说我们想 计算模型参数(权重、偏差)的雅可比行列式,而不是输入
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩阵 () 与正向模式雅可比矩阵 (jacrev
jacfwd
)¶
我们提供两个 API 来计算雅可比矩阵:和 :jacrev
jacfwd
jacrev
使用反向模式 AD。正如你在上面看到的,它是 our 和 transforms 的组合。vjp
vmap
jacfwd
使用正向模式 AD。它是作为 our 和 transforms 的组合实现的。jvp
vmap
jacfwd
并且可以相互替换,但它们具有不同的
性能特征。jacrev
作为一般的经验法则,如果你正在计算 \(R^N \to R^M\) 函数的雅可比行列式,并且输出比输入多得多(例如,\(M > N\)),那么最好使用 ,否则使用 。此规则也有例外,
但对此的一个非严格的论点如下:jacfwd
jacrev
在反向模式 AD 中,我们逐行计算雅可比行式,而在 forward-mode AD(计算雅可比向量积),我们正在计算 它逐列。雅可比矩阵有 M 行和 N 列,因此如果它 更高或更宽的一种方式,我们可能更喜欢处理较少的方法 行或列。
首先,让我们使用比输出更多的输入进行基准测试:
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c7c030970>
jacfwd(predict, argnums=2)(weight, bias, x)
1.31 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c3062a140>
jacrev(predict, argnums=2)(weight, bias, x)
9.64 ms
1 measurement, 500 runs , 1 thread
然后执行 Relative Benchmark:
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 633.6084 percent improvement with jacrev
现在反过来 - 输出 (M) 多于输入 (N):
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c2bb1bfd0>
jacfwd(predict, argnums=2)(weight, bias, x)
6.43 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f9c30334070>
jacrev(predict, argnums=2)(weight, bias, x)
823.32 us
1 measurement, 500 runs , 1 thread
以及相对性能比较:
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 680.5701 percent improvement with jacfwd
使用 functorch.hessian 进行 Hessian 计算¶
我们提供了一个方便的 API 来计算 hessians:.
Hessian 矩阵是 jacobian 矩阵(或
偏导数,又名二阶)。torch.func.hessiani
这表明可以将 functorch jacobian 转换组合为
计算 Hessian 矩阵。
确实,在引擎盖下,就是 .hessian(f)
jacfwd(jacrev(f))
注意:要提高性能:根据您的型号,您可能还希望
使用 or 来计算 Hessian 矩阵
利用上述关于较宽与较高矩阵的经验法则。jacfwd(jacfwd(f))
jacrev(jacrev(f))
from torch.func import hessian
# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
让我们验证一下,无论使用 hessian API 还是
用。jacfwd(jacfwd())
True
Batch Jacobian 和 Batch Hessian¶
在上面的例子中,我们一直在使用单个特征向量进行操作。
在某些情况下,您可能希望采用一批输出的雅可比矩阵
对于一批输入。也就是说,给定一批
shape 和一个从 \(R^N \to R^M\) 的函数,我们希望
形状的雅可比行列式 。(B, N)
(B, M, N)
最简单的方法是使用 :vmap
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
如果你有一个从 (B, N) -> (B, M) 的函数,并且是
确定每个 input 都产生一个独立的 output,那么它也
有时可以通过对 Outputs 求和
然后计算该函数的雅可比行列式:vmap
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果你有一个从 \(R^N \到 R^M\) 的函数,但输入的
是 batched,您可以 compose 来计算 batched jacobians:vmap
jacrev
最后,可以类似地计算批处理 Hessian 矩阵。最容易思考
通过对 Hessian 计算进行批处理,但在某些
情况下,求和技巧也有效。vmap
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
计算 Hessian 向量积¶
计算 Hessian 向量积 (hvp) 的天真方法是实现 完整的 Hessian 矩阵,并使用向量执行点积。我们可以做得更好: 事实证明,我们不需要物化完整的 Hessian 矩阵来执行此操作。我们将 通过两种(多种)不同的策略来计算 Hessian 向量积: - 使用反向模式 AD 组合 - 使用正向模式 AD 组合反向模式 AD
使用正向模式 AD 编写反向模式 AD(而不是反向模式 with reverse-mode)通常是计算 hvp 的 Alpha 创建,因为前向模式 AD 不需要构造 Autograd 图,而 保存中间体以供向后使用:
以下是一些示例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch forward-AD 没有覆盖您的操作,那么我们可以 而是使用反向模式 AD 编写反向模式 AD:
脚本总运行时间:(0 分 12.094 秒)