Forward-mode Automatic Differentiation (Beta)
Created On: Dec 07, 2021 | Last Updated: Apr 18, 2023 | Last Verified: Nov 05, 2024
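This tutorial demonstrates how to use forward-mode AD to compute directional derivatives (or, equivalently, Jacobian-vector products).
The tutorial below uses some APIs only available in versions >= 1.11 (or nightly builds).
Also note that forward-mode AD is currently in beta. The API is subject to change, and operator coverage is still incomplete.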
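Basic Usage
Unlike reverse-mode AD, forward-mode AD computes gradients eagerly alongside the forward pass. We can use forward-mode AD to compute a directional derivative by performing the forward pass as before, except that we first associate our input with another tensor representing the direction of the directional derivative (or, equivalently, the v in a Jacobian-vector product J·v). When an input, which we call the "primal", is associated with a "direction" tensor, which we call the "tangent", the resulting new tensor object is called a "dual tensor" because of its connection to dual numbers [0].
As the forward pass is performed, if any input tensor is a dual tensor, extra computation is performed to propagate this "sensitivity" of the function.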
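To make the dual-number connection concrete, here is a toy, scalar-only sketch in plain Python. The `Dual` class is hypothetical and for illustration only; it is not how PyTorch implements forward AD, which works on tensors and uses a derivative formula registered for each operator. A dual number is written a + b·ε with ε² = 0, so when a function is evaluated on x + 1·ε, the ε coefficient of the result is exactly the derivative along that direction.

```python
class Dual:
    """Toy dual number primal + tangent * eps, with eps**2 == 0 (illustration only)."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    @staticmethod
    def _lift(value):
        # Plain numbers behave like duals with a zero tangent.
        return value if isinstance(value, Dual) else Dual(value)

    def __add__(self, other):
        other = Dual._lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual._lift(other)
        # (a + b eps)(c + d eps) = ac + (ad + bc) eps, because eps**2 == 0
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def fn(x, y):
    return x * x + y * y


# Tangent 1.0 on x; y is a plain number, i.e. it has an implicit zero tangent.
out = fn(Dual(3.0, 1.0), 2.0)
print(out.primal, out.tangent)  # 13.0 6.0
```

The primal part is f(3, 2) = 13 and the tangent part is ∂f/∂x · 1 = 2·3 = 6, mirroring how an input without a tangent is treated as having a zero tangent.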
```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# All forward AD computation must be performed in the context of
# a ``dual_level`` context. All dual tensors created in such a context
# will have their tangents destroyed upon exit. This is to ensure that
# if the output or intermediate results of this computation are reused
# in a future forward AD computation, their tangents (which are associated
# with this computation) won't be confused with tangents from the later
# computation.
with fwAD.dual_level():
    # To create a dual tensor we associate a tensor, which we call the
    # primal, with another tensor of the same size, which we call the tangent.
    # If the layout of the tangent is different from that of the primal,
    # the values of the tangent are copied into a new tensor with the same
    # metadata as the primal. Otherwise, the tangent itself is used as-is.
    #
    # It is also important to note that the dual tensor created by
    # ``make_dual`` is a view of the primal.
    dual_input = fwAD.make_dual(primal, tangent)
    assert fwAD.unpack_dual(dual_input).tangent is tangent

    # To demonstrate the case where the copy of the tangent happens,
    # we pass in a tangent with a layout different from that of the primal
    dual_input_alt = fwAD.make_dual(primal, tangent.T)
    assert fwAD.unpack_dual(dual_input_alt).tangent is not tangent

    # Tensors that do not have an associated tangent are automatically
    # considered to have a zero-filled tangent of the same shape.
    plain_tensor = torch.randn(10, 10)
    dual_output = fn(dual_input, plain_tensor)

    # Unpacking the dual returns a ``namedtuple`` with ``primal`` and ``tangent``
    # as attributes
    jvp = fwAD.unpack_dual(dual_output).tangent

# Outside the ``dual_level`` context, the tangent has been destroyed
assert fwAD.unpack_dual(dual_output).tangent is None
```
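For this particular `fn`, the JVP can be checked by hand: the Jacobian of x**2 + y**2 with respect to x is diagonal with entries 2x, and `plain_tensor` contributes a zero tangent, so the JVP should equal `2 * primal * tangent`. A minimal, self-contained sketch of that check (the fixed seed is an arbitrary choice):

```python
import torch
import torch.autograd.forward_ad as fwAD

torch.manual_seed(0)

def fn(x, y):
    return x ** 2 + y ** 2

primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
plain_tensor = torch.randn(10, 10)

with fwAD.dual_level():
    dual_output = fn(fwAD.make_dual(primal, tangent), plain_tensor)
    # ``unpack_dual`` returns a namedtuple, so it can be tuple-unpacked
    out_primal, jvp = fwAD.unpack_dual(dual_output)

# d/dx (x**2 + y**2) = 2x elementwise, and y carries a zero tangent
expected_jvp = 2 * primal * tangent
assert torch.allclose(jvp, expected_jvp)
```

The `jvp` tensor remains usable after the context exits; only its association with `dual_output` as a tangent is destroyed.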
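Usage with Modules
To use nn.Module with forward AD, replace the parameters of your model with dual tensors before performing the forward pass. At the time of writing, it is not possible to create dual tensor nn.Parameters. As a workaround, one must register the dual tensor as a non-parameter attribute of the module.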
```python
import torch.nn as nn

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        # Remove the nn.Parameter and register the dual tensor as a plain attribute
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent
```
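Because `nn.Linear` computes `input @ weight.T + bias`, and only the parameters carry tangents here, the JVP can also be written out by hand as `input @ weight_tangent.T + bias_tangent`. A self-contained sketch comparing the attribute-swapping approach above against that formula (the seed and tolerance are arbitrary choices, and `inp` is used instead of `input` to avoid shadowing the Python builtin):

```python
import torch
import torch.nn as nn
import torch.autograd.forward_ad as fwAD

torch.manual_seed(0)
model = nn.Linear(5, 5)
inp = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))
    out = model(inp)
    jvp = fwAD.unpack_dual(out).tangent

# Only weight and bias carry tangents, so d(out) = inp @ dW.T + db
expected = inp @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp, expected, atol=1e-5)
```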
Using the functional Module API (beta)
Another way to use nn.Module with forward AD is to utilize the functional Module API (also known as the stateless Module API).
```python
from torch.func import functional_call

# We need a fresh module because the functional call requires the
# model to have parameters registered.
model = nn.Linear(5, 5)

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)
```
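Custom autograd Function
Custom Functions also support forward-mode AD. To create a custom Function supporting forward-mode AD, register the jvp() static method. It is possible, but not mandatory, for custom Functions to support both forward and backward AD. See the documentation for more information.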
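As a sketch of a Function that supports both modes, the hypothetical `MySin` below implements both `backward()` and `jvp()`. It stashes its input with `ctx.save_for_backward` and `ctx.save_for_forward` and reads it back through `ctx.saved_tensors`, following the pattern shown in the `torch.autograd.Function` documentation for recent PyTorch versions; since this API is in beta, check the docs for your version. With both formulas implemented, `gradcheck` can verify forward and backward gradients in a single call.

```python
import torch
import torch.autograd.forward_ad as fwAD

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stash the input for the backward pass and for the jvp
        ctx.save_for_backward(x)
        ctx.save_for_forward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Reverse mode: vector-Jacobian product, grad_out * cos(x)
        return grad_out * torch.cos(x)

    @staticmethod
    def jvp(ctx, x_t):
        (x,) = ctx.saved_tensors
        # Forward mode: Jacobian-vector product, cos(x) * x_t
        return x_t * torch.cos(x)

x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
x_t = torch.randn(4, 4, dtype=torch.double)

with fwAD.dual_level():
    out = MySin.apply(fwAD.make_dual(x, x_t))
    jvp = fwAD.unpack_dual(out).tangent

# Numerically verify both the forward-mode and the reverse-mode formulas
ok = torch.autograd.gradcheck(MySin.apply, (x,), check_forward_ad=True)
```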
```python
class Fn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = torch.exp(foo)
        # Tensors stored in ``ctx`` can be used in the subsequent forward grad
        # computation.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, gI):
        gO = gI * ctx.result
        # If the tensor stored in ``ctx`` will not also be used in the backward pass,
        # one can manually free it using ``del``
        del ctx.result
        return gO

fn = Fn.apply

primal = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
tangent = torch.randn(10, 10)

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = fn(dual_input)
    jvp = fwAD.unpack_dual(dual_output).tangent

# It is important to use ``autograd.gradcheck`` to verify that your
# custom autograd Function computes the gradients correctly. By default,
# ``gradcheck`` only checks the backward-mode (reverse-mode) AD gradients. Specify
# ``check_forward_ad=True`` to also check forward grads. If you did not
# implement the backward formula for your function, you can also tell ``gradcheck``
# to skip the tests that require backward-mode AD by specifying
# ``check_backward_ad=False``, ``check_undefined_grad=False``, and
# ``check_batched_grad=False``.
torch.autograd.gradcheck(Fn.apply, (primal,), check_forward_ad=True,
                         check_backward_ad=False, check_undefined_grad=False,
                         check_batched_grad=False)
```
True
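To see why this check matters, the sketch below gives a hypothetical `BadExp` a deliberately wrong `jvp()` that drops the `exp(x)` factor. With `raise_exception=False`, `gradcheck` reports the mismatch between the analytical and numerical forward gradients by returning `False` instead of raising an error.

```python
import torch

class BadExp(torch.autograd.Function):
    """exp with a deliberately incorrect jvp, for demonstration only."""

    @staticmethod
    def forward(ctx, x):
        return torch.exp(x)

    @staticmethod
    def jvp(ctx, x_t):
        # Wrong on purpose: the correct JVP is x_t * exp(x)
        return x_t.clone()

x = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(BadExp.apply, (x,), check_forward_ad=True,
                              check_backward_ad=False, check_undefined_grad=False,
                              check_batched_grad=False, raise_exception=False)
print(ok)
```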
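Functional API (beta)
We also offer a higher-level functional API in functorch for computing Jacobian-vector products that you may find simpler to use, depending on your use case.
The benefit of the functional API is that there is no need to understand or use the lower-level dual tensor API, and that you can compose it with other functorch transforms (like vmap); the downside is that it offers you less control.
Note that the remainder of this tutorial requires functorch (https://github.com/pytorch/functorch) to run. Please find installation instructions at that link.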
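As an illustration of composing transforms, the sketch below uses `torch.func`, the module into which functorch was integrated as of PyTorch 2.0. It applies `vmap` to `jvp` over the rows of an identity matrix: the JVP along the basis vector e_i is column i of the Jacobian, so stacking these results and transposing recovers the full Jacobian, which is then compared with `torch.func.jacfwd`. The function `f` is an arbitrary example.

```python
import torch
from torch.func import jacfwd, jvp, vmap

def f(x):
    # Arbitrary example function from R^3 to R^3
    return torch.sin(x) * x.sum()

x = torch.randn(3)
basis = torch.eye(3)  # row i is the basis vector e_i

def jvp_along(v):
    # J(x) @ v, computed in forward mode
    return jvp(f, (x,), (v,))[1]

# Row i of ``columns`` is J(x) @ e_i, i.e. column i of the Jacobian
columns = vmap(jvp_along)(basis)
jacobian = columns.T

assert torch.allclose(jacobian, jacfwd(f)(x), atol=1e-6)
```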
```python
import functorch as ft

primal0 = torch.randn(10, 10)
tangent0 = torch.randn(10, 10)
primal1 = torch.randn(10, 10)
tangent1 = torch.randn(10, 10)

def fn(x, y):
    return x ** 2 + y ** 2

# Here is a basic example to compute the JVP of the above function.
# The ``jvp(func, primals, tangents)`` returns ``func(*primals)`` as well as the
# computed Jacobian-vector product (JVP). Each primal must be associated with a tangent of the same shape.
primal_out, tangent_out = ft.jvp(fn, (primal0, primal1), (tangent0, tangent1))

# ``functorch.jvp`` requires every primal to be associated with a tangent.
# If we only want to associate certain inputs to ``fn`` with tangents,
# then we'll need to create a new function that captures inputs without tangents:
primal = torch.randn(10, 10)
tangent = torch.randn(10, 10)
y = torch.randn(10, 10)

import functools
new_fn = functools.partial(fn, y=y)
primal_out, tangent_out = ft.jvp(new_fn, (primal,), (tangent,))
```
/var/lib/workspace/intermediate_source/forward_ad_usage.py:203: FutureWarning:
We've integrated functorch into PyTorch. As the final step of the integration, `functorch.jvp` is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use `torch.func.jvp` instead; see the PyTorch 2.0 release notes and/or the `torch.func` migration guide for more details https://pytorch.org/docs/main/func.migrating.html
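The warning above recommends `torch.func.jvp` as the replacement for `functorch.jvp`. A sketch of the same two calls using it, assuming PyTorch 2.0 or later. The results can be checked analytically, since the JVP of x**2 + y**2 is 2·x·dx + 2·y·dy.

```python
import functools

import torch
from torch.func import jvp

def fn(x, y):
    return x ** 2 + y ** 2

primal0, tangent0 = torch.randn(10, 10), torch.randn(10, 10)
primal1, tangent1 = torch.randn(10, 10), torch.randn(10, 10)

# Tangents on both inputs: J v = 2*x*dx + 2*y*dy
primal_out, tangent_out = jvp(fn, (primal0, primal1), (tangent0, tangent1))

# Tangent on x only: bind y with functools.partial, as in the functorch version
y = torch.randn(10, 10)
new_fn = functools.partial(fn, y=y)
primal_out2, tangent_out2 = jvp(new_fn, (primal0,), (tangent0,))
```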
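Using the functional API with Modules
To use nn.Module with functorch.jvp to compute Jacobian-vector products with respect to the model parameters, we need to reformulate the nn.Module as a function that accepts both the model parameters and the inputs to the module.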
```python
model = nn.Linear(5, 5)
input = torch.randn(16, 5)
tangents = tuple([torch.rand_like(p) for p in model.parameters()])

# Given a ``torch.nn.Module``, ``ft.make_functional_with_buffers`` extracts the state
# (``params`` and buffers) and returns a functional version of the model that
# can be invoked like a function.
# That is, the returned ``func`` can be invoked like
# ``func(params, buffers, input)``.
# ``ft.make_functional_with_buffers`` is analogous to the ``nn.Module`` stateless API
# that you saw previously, and we're working on consolidating the two.
func, params, buffers = ft.make_functional_with_buffers(model)

# Because ``jvp`` requires every input to be associated with a tangent, we need to
# create a new function that, when given the parameters, produces the output
def func_params_only(params):
    return func(params, buffers, input)

model_output, jvp_out = ft.jvp(func_params_only, (params,), (tangents,))
```
/var/lib/workspace/intermediate_source/forward_ad_usage.py:235: FutureWarning:
We've integrated functorch into PyTorch. As the final step of the integration, `functorch.make_functional_with_buffers` is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use `torch.func.functional_call` instead; see the PyTorch 2.0 release notes and/or the `torch.func` migration guide for more details https://pytorch.org/docs/main/func.migrating.html
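The warning above suggests `torch.func.functional_call` in place of `make_functional_with_buffers`. A sketch of the same parameter JVP using `torch.func.functional_call` together with `torch.func.jvp`, assuming PyTorch 2.0 or later. The parameters are passed as a dict (a pytree), so the tangents must be a dict with the same keys; detaching the parameters so that plain tensors are passed in is a precaution of this sketch, not something this tutorial requires.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jvp

torch.manual_seed(0)
model = nn.Linear(5, 5)
inp = torch.randn(16, 5)

# Plain (detached) copies of the parameters, keyed by name
params = {name: p.detach() for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

def func_params_only(params):
    # Run the module with the given parameters instead of its registered ones
    return functional_call(model, params, (inp,))

model_output, jvp_out = jvp(func_params_only, (params,), (tangents,))

# nn.Linear computes inp @ W.T + b, so the JVP is inp @ dW.T + db
expected = inp @ tangents["weight"].T + tangents["bias"]
assert torch.allclose(jvp_out, expected, atol=1e-5)
```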
[0] https://en.wikipedia.org/wiki/Dual_number
Total running time of the script: (0 minutes 0.163 seconds)