使用 autograd.Function torch.func。功能¶
因此,您希望使用torch.autograd.Function
使用torch.func
转换,如torch.vmap()
,torch.func.grad()
等。
有两个主要用例:
您希望调用不包含 PyTorch作的代码,并且 让它与函数转换一起工作。也就是说,
torch.autograd.Function
的 forward/backward/etc 调用来自其他系统(如 C++、CUDA、numpy)的函数。您希望指定自定义渐变规则,例如 JAX 的 custom_vjp/custom_jvp
PyTorch 将这两个概念组合成torch.autograd.Function
.
基本用法¶
本指南假设你熟悉扩展 torch.autograd、
,它解释了如何使用torch.autograd.Function
.
torch.autograd.Function
可以具有forward()
接受 ctx 对象
或者它可以有单独的forward()
(不接受)和修改对象的 staticMethod。ctx
setup_context()
ctx
函数转换仅支持后者:
forward()
是执行作的代码,它不应接受 一个对象。ctx
setup_context(ctx, inputs, output)
是您可以 在 上调用方法 。这是您应该保存 Tensor 以供 backward 的地方 (通过调用 )或保存非 Tensor (通过将它们分配给对象)。ctx
ctx.save_for_backward(*tensors)
ctx
因为只接受 和 ,
唯一可以保存的数量是
输入或输出或从它们派生的数量 (如 )。
如果您希望从setup_context()
inputs
output
Tensor.shape
Function.forward()
对于 backward,则需要将其作为
输出forward()
,以便将其传递给 。setup_context()
根据转换,
要支持反向模式 AD (
torch.func.grad()
,torch.func.vjp()
), 这torch.autograd.Function
需要一个backward()
static方法。支持
torch.vmap()
这torch.autograd.Function
需要一个vmap()
static方法。支持
torch.func.jvp()
这torch.autograd.Function
需要一个jvp()
static方法。以支持转换的组合(如
torch.func.jacrev()
,torch.func.jacfwd()
,torch.func.hessian()
) – 您可能需要多个 上述的。
为了让torch.autograd.Function
可任意组合与 function
transforms,我们建议将forward()
并且必须是可转换的:也就是说,它们必须仅包含 PyTorch
运算符或调用其他setup_context()
torch.autograd.Function
(可以调用 C++/CUDA/etc)。
让我们回顾一些常见用例的示例。
示例 1:autograd。对另一个系统的函数调用¶
一个常见的情况是torch.autograd.Function
使用 forward() 和 backward() 调用
到另一个系统(如 C++、CUDA、numpy、triton)中。
import torch
import numpy as np
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
# Note that forward does not take ctx
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
# Any intermediates to be saved in backward must be returned as
# outputs.
return (
# The desired output
torch.tensor(result, device=device),
# intermediate to save for backward
torch.tensor(ind, device=device),
# intermediate to save for backward
torch.tensor(ind_inv, device=device),
)
# setup_context is responsible for calling methods and/or assigning to
# the ctx object. Please do not do additional compute (e.g. add
# Tensors together) in setup_context.
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
# Note that output is whatever you returned from forward.
# If you returned multiple values, then output is a Tuple of multiple values.
# If you returned a single Tensor, then output is a Tensor.
# If you returned a Tuple with a single Tensor, then output is a
# Tuple with a single Tensor.
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
# Tensors must be saved via ctx.save_for_backward. Please do not
# assign them directly onto the ctx object.
ctx.save_for_backward(ind, ind_inv)
# Non-tensors may be saved by assigning them as attributes on the ctx object.
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
# For the autograd.Function to be arbitrarily composable with function
# transforms, all staticmethod other than forward and setup_context
# must be implemented in a "transformable" way; that is, they must
# only consist of PyTorch operations or autograd.Function.
#
# For example, this allows us to do double backwards and/or compute
# second order gradients.
#
# We've written the backward pass of NumpySort in terms of another
# autograd.Function, NumpyTake.
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
现在,为了使其更易于使用(为了隐藏中间体,我们
作为输出返回,并允许默认 args 和 kwargs),我们会创建一个新的
调用它的函数:NumpySort
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
下面是一个健全性检查:
x = torch.randn(2, 3)
grad_x = torch.func.grad(lambda x: numpy_sort(x).sum())(x)
assert torch.allclose(grad_x, torch.ones_like(x))
示例 2:autograd。函数指定自定义渐变规则¶
另一种常见情况是torch.autograd.Function
使用 PyTorch 实现的
操作。PyTorch 能够自动计算 PyTorch作的梯度,
但也许我们希望自定义梯度的计算方式。一些原因
我们可能需要一个与 PyTorch 提供给我们的自定义不同的向后自定义是:
提高数值稳定性
更改 backward 的性能特征
更改边缘情况的处理方式(例如 Nans、InF)
修改渐变(例如渐变裁剪)
下面是一个torch.autograd.Function
对于函数,其中我们
更改性能特征(通常会发生的一些计算
在向后传递期间,计算 dx 发生在向前传递中)。y = x ** 3
class MyCube(torch.autograd.Function):
@staticmethod
def forward(x):
result = x ** 3
# In regular PyTorch, if we had just run y = x ** 3, then the backward
# pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
# that computation here in the forward pass instead.
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
# In order for the autograd.Function to work with higher-order
# gradients, we must add the gradient contribution of `dx`.
result = grad_output * dx + grad_dx * 6 * x
return result
现在,为了使其更易于使用(并隐藏中间体,我们
返回为 outputs),我们创建一个调用它的新函数:NumpySort
def my_cube(x):
result, _ = MyCube.apply(x)
return result
这是计算二阶梯度的健全性检查:
x = torch.randn([])
ggx = torch.func.grad(torch.func.grad(my_cube))(x)
assert torch.allclose(ggx, 6 * x)
限制和陷阱¶
警告
请阅读torch.autograd.Function
使用 torch.func 转换
仔细。我们无法捕捉到其中的许多情况和错误
优雅地,因此它们将导致未定义的行为。
请不要捕获正在转换的 Tensor,请
requires_grad=True 或双张量,则转换为torch.autograd.Function
.完全安全的方法是确保唯一的
在torch.autograd.Function
必须直接
作为输入传递(或通过 ctx 对象传递),而不是来自外部
这torch.autograd.Function
.
torch.autograd.Function
不处理 pytree 中的 Tensor(任意嵌套
可能包含也可能不包含 Tensor 的 Python 数据结构)。为
那些要被 autograd 跟踪的 Tensor 时,它们必须直接作为
一个参数torch.autograd.Function
.这与
贾克斯。{custom_vjp, custom_jvp},它们确实接受 pytree。
请仅使用save_for_backward()
或保存 Tensor。
请不要将 Tensor 或 Tensor 集合直接分配给 ctx 对象 -
这些 Tensor 不会被跟踪save_for_forward()
torch.vmap()
支持¶
要使用torch.autograd.Function
跟torch.vmap()
,您必须:
提供
vmap()
static方法,它告诉我们torch.autograd.Function
下torch.vmap()
要求我们通过设置 来自动生成它。
generate_vmap_rule=True
自动生成 vmap 规则¶
如果您的torch.autograd.Function
满足以下附加约束,则
能够为其生成 vmap 规则。如果它不满足约束,或者您
想要在 vmap 下自定义行为,请手动定义一个 vmap staticmethod(参见下一节)。
警告
我们不能轻易检查以下约束和错误 优雅地出来。违反约束可能会导致 undefined 行为。
这
torch.autograd.Function
的forward()
,backward()
(如果存在)和jvp()
(如果存在)staticMethods 必须可以通过torch.vmap()
.那 是,它们必须仅包含 PyTorch作(而不是 NumPy 或自定义 CUDA 内核)。
例:
class MyCube(torch.autograd.Function):
# Set generate_vmap_rule to True to ask PyTorch to automatically generate
# a vmap rule.
generate_vmap_rule = True
@staticmethod
def forward(x):
result = x ** 3
dx = 3 * x ** 2
return result, dx
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs
result, dx = output
ctx.save_for_backward(x, dx)
@staticmethod
def backward(ctx, grad_output, grad_dx):
x, dx = ctx.saved_tensors
result = grad_output * dx + grad_dx * 6 * x
return result
def my_cube(x):
result, dx = MyCube.apply(x)
return result
x = torch.randn(3)
result = torch.vmap(my_cube)(x)
assert torch.allclose(result, x ** 3)
定义 vmap staticmethod¶
如果您的torch.autograd.Function
调用另一个系统(如 NumPy、C++、CUDA、triton),
然后让它工作torch.vmap()
或使用它的 transforms 的 Transform 中,你将
需要手动定义一个vmap()
static方法。
根据您想要使用的转换和您的用例,您可能不需要
要添加vmap()
static方法添加到所有torch.autograd.Function
:
例如
torch.func.jacrev()
执行vmap()
越过 backward pass。 因此,如果您只对使用torch.func.jacrev()
只 这backward()
static方法必须是 VMAppable 的。
我们建议您确保所有torch.autograd.Function
支持torch.vmap()
不过,特别是如果您正在编写第三方库并且希望您的torch.autograd.Function
可使用torch.func()
变换。
从概念上讲,vmap static方法负责定义forward()
应该在torch.vmap()
.也就是说,它定义了如何转换
这forward()
要运行具有附加维度的输入(维度
正在 vmap 覆盖)。这类似于torch.vmap()
在
PyTorch作:对于每个作,我们定义一个 vmap 规则(有时也
称为 “批处理规则”)。
下面介绍如何定义vmap()
static方法:
签名为 ,其中 与 args 相同
vmap(info, in_dims: Tuple[Optional[int]], *args)
*args
forward()
.vmap static方法负责定义
forward()
应该表现 下torch.vmap()
.也就是说,给定具有附加维度的输入 (由 指定),我们如何计算 的批处理版本in_dims
forward()
?对于 中的每个 arg ,都有一个对应的 。 如果 arg 不是 Tensor 或者 arg 没有被 vmap 覆盖, 否则,它是一个整数,指定要进行 vmap 的 Tensor 的维度 多。
args
in_dims
Optional[int]
None
info
是可能有用的其他元数据的集合:指定要 vmap 的维度的大小,而 是传递给info.batch_size
info.randomness
randomness
torch.vmap()
.vmap staticmethod 的返回值是 的元组。类似 to 应与 和 包含 每个输出一个,用于指定输出是否具有 vmapped 维度及其所在的索引。
(output, out_dims)
in_dims
out_dims
output
out_dim
例:
def to_numpy(tensor):
return tensor.cpu().numpy()
class NumpySort(torch.autograd.Function):
@staticmethod
def forward(x, dim):
device = x.device
x = to_numpy(x)
ind = np.argsort(x, axis=dim)
ind_inv = np.argsort(ind, axis=dim)
result = np.take_along_axis(x, ind, axis=dim)
return (
torch.tensor(result, device=device),
torch.tensor(ind, device=device),
torch.tensor(ind_inv, device=device),
)
@staticmethod
def setup_context(ctx, inputs, output):
x, dim = inputs
_, ind, ind_inv = output
ctx.mark_non_differentiable(ind, ind_inv)
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output, _0, _1):
ind, ind_inv = ctx.saved_tensors
return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None
# The signature of the vmap staticmethod is:
# vmap(info, in_dims: Tuple[Optional[int]], *args)
# where *args is the same as the arguments to `forward`.
@staticmethod
def vmap(info, in_dims, x, dim):
# For every input (x and dim), in_dims stores an Optional[int]
# that is:
# - None if the input is not being vmapped over or if the input
# is not a Tensor
# - an integer if the input is being vmapped over that represents
# the index of the dimension being vmapped over.
x_bdim, _ = in_dims
# A "vmap rule" is the logic of how to perform the operation given
# inputs with one additional dimension. In NumpySort, x has an
# additional dimension (x_bdim). The vmap rule is simply
# to call NumpySort again but pass it a different `dim`.
x = x.movedim(x_bdim, 0)
# Handle negative dims correctly
dim = dim if dim >= 0 else dim + x.dim() - 1
result = NumpySort.apply(x, dim + 1)
# The vmap rule must return a tuple of two things
# 1. the output. Should be the same amount of things
# as returned by the forward().
# 2. one Optional[int] for each output specifying if each output
# is being vmapped over, and if so, the index of the
# dimension being vmapped over.
#
# NumpySort.forward returns a Tuple of 3 Tensors. Since we moved the
# dimension being vmapped over to the front of `x`, that appears at
# dimension 0 of all outputs.
# The return is (output, out_dims) -- output is a tuple of 3 Tensors
# and out_dims is a Tuple of 3 Optional[int]
return NumpySort.apply(x, dim + 1), (0, 0, 0)
class NumpyTake(torch.autograd.Function):
@staticmethod
def forward(x, ind, ind_inv, dim):
device = x.device
x = to_numpy(x)
ind = to_numpy(ind)
return torch.tensor(np.take_along_axis(x, ind, dim), device=device)
@staticmethod
def setup_context(ctx, inputs, output):
x, ind, ind_inv, dim = inputs
ctx.save_for_backward(ind, ind_inv)
ctx.dim = dim
@staticmethod
def backward(ctx, grad_output):
ind, ind_inv = ctx.saved_tensors
result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
return result, None, None, None
@staticmethod
def vmap(info, in_dims, x, ind, ind_inv, dim):
x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims
# The strategy is: expand {x, ind, ind_inv} to all have the dimension
# being vmapped over.
# Then, call back into NumpyTake(expanded_x, expanded_ind, expanded_ind_inv, new_dim).
# Handle negative dims by wrapping them to be positive
logical_dim = x.dim() if x_bdim is None else x_bdim - 1
dim = dim if dim >= 0 else dim + logical_dim
def maybe_expand_bdim_at_front(x, x_bdim):
if x_bdim is None:
return x.expand(info.batch_size, *x.shape)
return x.movedim(x_bdim, 0)
# If the Tensor doesn't have the dimension being vmapped over,
# expand it out. Otherwise, move it to the front of the Tensor
x = maybe_expand_bdim_at_front(x, x_bdim)
ind = maybe_expand_bdim_at_front(ind, ind_bdim)
ind_inv = maybe_expand_bdim_at_front(ind_inv, ind_inv_bdim)
# The return is a tuple (output, out_dims). Since output is a Tensor,
# then out_dims is an Optional[int] (instead of being a Tuple).
return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0
def numpy_sort(x, dim=-1):
result, _, _ = NumpySort.apply(x, dim)
return result
x = torch.randn(2, 3)
result = torch.vmap(numpy_sort)(x)
assert torch.allclose(result, numpy_sort(result, 1))
注意
vmap static方法应该旨在保留
整个Function
.也就是说,(伪代码)应该可以用 .grad(vmap(MyFunc))
grad(map(MyFunc))
如果你的 autograd.函数在向后传递中有任何自定义行为,请 请记住这一点。
注意
为Function
PyTorch 能够生成 vmap
通过 的规则 。如果
生成的 vmap 规则没有您要查找的语义。generate_vmap_rule=True
torch.func.jvp()
支持¶
为了支持正向模式 AD,一个torch.autograd.Function
必须具有jvp()
static方法。
有关详细信息,请参阅转发模式 AD。