使用autograd.Function扩展torch.func¶
所以你想使用 torch.autograd.Function 与 torch.func
转换器如 torch.vmap(), torch.func.grad() 等。
有两个主要的使用场景:
您希望调用不包含PyTorch操作的代码,并使其与函数转换一起工作。也就是说,
torch.autograd.Function的 前向/后向等调用进入其他系统(如C++、CUDA、numpy)中的函数。您希望指定自定义梯度规则,例如 JAX的 custom_vjp/custom_jvp
PyTorch 将这两个概念结合到 torch.autograd.Function 中。
基本用法¶
本指南假设您熟悉 扩展 torch.autograd,
该指南解释了如何使用 torch.autograd.Function。
torch.autograd.Function 可以有一个 forward(),它接受一个 ctx 对象,
或者它可以有单独的 forward()(不接受 ctx) 和一个 setup_context()
静态方法来修改 ctx 对象。
仅支持后者使用函数转换:
forward()是执行操作的代码,它不应该接受一个ctx对象。setup_context(ctx, inputs, output)是你可以调用ctx上方法的代码。这里是你应该保存用于反向传播的张量的地方(通过调用ctx.save_for_backward(*tensors)),或者保存非张量(通过将它们分配给ctx对象)。
因为 setup_context() 只接受 inputs 和 output,
唯一可以保存的数量要么是输入或输出中的对象(如张量),
要么是从它们派生出的量(如 Tensor.shape)。
如果你想在反向传播时保存一个非输入的中间激活值来自
Function.forward(),那么你需要将其作为输出返回从 forward() 以便它被传递到
setup_context()。
根据转换方式,
为了支持反向模式自动微分(
torch.func.grad(),torch.func.vjp()),torch.autograd.Function需要一个backward()类静态方法。为了支持
torch.vmap(),torch.autograd.Function需要一个vmap()静态方法。为了支持
torch.func.jvp(),torch.autograd.Function需要一个jvp()静态方法。支持变换组合(如
torch.func.jacrev(),torch.func.jacfwd(),torch.func.hessian()) – 你可能需要多个 上述内容。
为了使 torch.autograd.Function 能够与函数变换任意组合,我们建议除了 forward() 和
setup_context() 之外的所有其他静态方法都必须是可变换的:也就是说,它们必须仅由 PyTorch 操作符或调用其他 torch.autograd.Function(这些操作可能会调用 C++/CUDA 等)组成。
让我们来看一些常见用例的例子。
示例1:autograd.Function调用另一个系统¶
一个常见的案例是 torch.autograd.Function 同时在 forward() 和 backward() 中调用
另一个系统(如 C++、CUDA、numpy、triton)。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
现在,为了更方便地使用NumpySort(隐藏我们作为输出返回的中间结果,并允许默认参数和关键字参数),我们创建了一个新的函数来调用它:
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
这是一个合理性检查:
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd.Function 指定自定义梯度规则¶
另一种常见的情况是使用PyTorch操作实现的torch.autograd.Function。PyTorch能够自动为PyTorch操作计算梯度,
但也许我们希望自定义梯度的计算方式。我们可能想要一个与PyTorch提供的不同的自定义反向传播的原因有:
提高数值稳定性
改变反向传播的性能特性
改变边缘情况的处理方式(例如:nans、inf)
修改梯度(例如,梯度裁剪)
这是一个函数 y = x ** 3 的示例,其中我们更改了性能特征(一些通常在反向传播过程中计算 dx 的计算会在前向传播过程中进行)。
torch.autograd.Function 表示该示例中的某个部分。
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了更方便地使用NumpySort(并隐藏我们作为输出返回的中间结果),我们创建了一个新函数来调用它:
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这是一个计算二阶梯度的合理性检查:
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和注意事项¶
警告
请仔细阅读这些关于 torch.autograd.Function 与 torch.func 转换的限制。我们无法捕获许多这些情况并优雅地报错,因此它们将导致未定义的行为。
请勿将正在转换的张量、requires_grad=True 的张量或双张量捕获到 torch.autograd.Function 的方法中。确保完全安全的方法是,torch.autograd.Function 的任何方法中使用的唯一张量必须直接作为输入传递(或通过 ctx 对象),而不是从 torch.autograd.Function 外部获取。
torch.autograd.Function 不处理 pytrees 中的张量(可能包含或不包含张量的任意嵌套 Python 数据结构)。为了使这些张量被 autograd 跟踪,它们必须直接作为参数传递给 torch.autograd.Function。这与 jax.{custom_vjp, custom_jvp} 形成对比,后者确实接受 pytrees。
请仅使用 save_for_backward() 或
save_for_forward() 来保存张量。
请不要直接将张量或张量集合赋值到 ctx 对象上 -
这些张量将不会被追踪
torch.vmap() 支持¶
要使用 torch.autograd.Function 与 torch.vmap(),您必须:
提供一个
vmap()类方法,告诉我们torch.autograd.Function在torch.vmap()下的行为请将
generate_vmap_rule=True设置为我们自动生成它。
自动生成vmap规则¶
如果你的 torch.autograd.Function 满足以下附加约束条件,那么我们可以为其生成一个vmap规则。如果它不满足这些约束条件,或者你希望在vmap下有自定义行为,请手动定义一个vmap静态方法(参见下一节)。
警告
我们无法轻易地检查以下约束并优雅地报错。违反这些约束可能会导致未定义的行为。
The
torch.autograd.Function的forward(),backward()(如果存在)和jvp()(如果存在)的静态方法必须可以通过torch.vmap()进行转换。也就是说, 它们必须仅由 PyTorch 操作组成(而不是例如 NumPy 或自定义 CUDA 内核)。
Example:
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义vmap静态方法¶
如果你的 torch.autograd.Function 调用另一个系统(如NumPy、C++、CUDA、triton),
那么为了使其与 torch.vmap() 或使用它的转换一起工作,你需要手动定义一个 vmap() 静态方法。
根据您想要使用的转换和您的使用情况,您可能不需要为所有的 vmap() 添加一个 torch.autograd.Function 的静态方法:
例如,
torch.func.jacrev()在反向传播过程中执行vmap()。 因此,如果你只对使用torch.func.jacrev()感兴趣,只需要 将backward()静态方法映射到虚拟机即可。
我们建议确保你的所有 torch.autograd.Function 都支持
torch.vmap(),特别是如果你正在编写第三方库,并希望你的
torch.autograd.Function 能与所有组合的 torch.func() 变换兼容。
概念上,vmap静态方法负责定义forward()
在torch.vmap()下应该如何表现。也就是说,它定义了如何将
forward()转换为在一个额外的维度上运行(该维度被vmapped)。这类似于
PyTorch操作中如何实现torch.vmap():对于每个操作,我们定义一个vmap规则(有时也称为“批处理规则”)。
以下是定义 vmap() 静态方法的方式:
签名是
vmap(info, in_dims: Tuple[Optional[int]], *args),其中*args与forward()的参数相同。vmap静态方法负责定义
forward()在torch.vmap()下应该如何表现。也就是说,给定具有额外维度的输入(由in_dims指定),我们如何计算forward()的批量版本?对于
args中的每个参数,in_dims都有一个对应的Optional[int]。 如果参数不是Tensor或者参数没有被vmapped处理,则为None, 否则,它是一个整数,指定Tensor的哪个维度正在被vmapped处理。info是一组可能有帮助的额外元数据:info.batch_size指定被 vmapped 的维度的大小,而info.randomness是传递给torch.vmap()的randomness选项。vmap静态方法的返回值是一个包含
(output, out_dims)个元素的元组。类似于in_dims,out_dims应该与output具有相同的结构,并且每个输出包含一个out_dim,用于指定输出是否具有vmapped维度以及其索引位置。
Example:
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap静态方法应旨在保留整个Function的语义。也就是说,(伪代码)grad(vmap(MyFunc))应该可以替换为grad(map(MyFunc))。
如果你的autograd.Function在反向传播过程中有任何自定义行为,请记住这一点。
注意
这是为Function编写自定义 vmap 静态方法的合法用例,PyTorch 能够通过generate_vmap_rule=True生成 vmap 规则。如果生成的 vmap 规则不符合您所需的语义,您可能希望这样做。
torch.func.jvp() 支持¶
要支持前向模式自动微分,torch.autograd.Function 必须有一个 jvp() 静态方法。
详情请参阅 前向模式自动微分。