目录

具有自定义功能的双后退

创建时间: Aug 13, 2021 |上次更新时间:2021 年 8 月 13 日 |上次验证: Nov 05, 2024

有时,通过向后图形向后运行两次是有用的,因为 example 来计算高阶梯度。它需要理解 但是,autograd 和一些人关心支持双向后。功能 不一定支持向后执行一次 配备支持双后退。在本教程中,我们将展示如何 编写一个支持 double backward 的自定义 autograd 函数,以及 指出一些需要注意的事项。

当编写自定义 autograd 函数以向后返回两次时, 了解何时在自定义函数中执行操作非常重要 被 Autograd 记录,当它们没有被记录时,最重要的是,save_for_backward 如何处理所有这些。

自定义函数以两种方式隐式影响 grad 模式:

  • 在转发期间,autograd 不会记录任何 在 forward 函数中执行的操作。转发时 completes,自定义函数的 backward 函数 成为每个 forward 输出的 grad_fn

  • 在向后过程中,autograd 会记录用于 如果指定了 create_graph,则计算向后传递

接下来,要了解 save_for_backward 如何与上述内容交互, 我们可以探索几个例子:

保存输入

考虑这个简单的平方函数。它保存了一个输入张量 为向后。double backward 在 autograd 时自动工作 能够在 backward pass 中记录操作,因此通常有 当我们为 backward AS 保存 input 时,无需担心 如果 input 是任何 Tensor 的函数,则 input 应具有 grad_fn 这需要 grad.这允许正确传播渐变。

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Because we are saving one of the inputs use `save_for_backward`
        # Save non-tensors and non-inputs/non-outputs directly on ctx
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_out):
        # A function support double backward automatically if autograd
        # is able to record the computations performed in backward
        x, = ctx.saved_tensors
        return grad_out * 2 * x

# Use double precision because finite differencing method magnifies errors
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)
# Use gradcheck to verify second-order derivatives
torch.autograd.gradgradcheck(Square.apply, x)

我们可以使用 torchviz 来可视化图形,看看为什么这样

import torchviz

x = torch.tensor(1., requires_grad=True).clone()
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

我们可以看到 wrt 到 x 的梯度本身就是 x 的函数 (dout/dx = 2x) 并且这个函数的图已经被正确构建

https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png

保存输出

与前面的示例略有不同,将 output 而不是 input。机制是相似的,因为 output 也是 与grad_fn关联。

class Exp(torch.autograd.Function):
    # Simple case where everything goes well
    @staticmethod
    def forward(ctx, x):
        # This time we save the output
        result = torch.exp(x)
        # Note that we should use `save_for_backward` here when
        # the tensor saved is an ouptut (or an input).
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, = ctx.saved_tensors
        return result * grad_out

x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
# Validate our gradients using gradcheck
torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)

使用 torchviz 可视化图形:

out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})
https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png

保存中间结果

更棘手的情况是当我们需要保存一个中间结果时。 我们通过实施来演示此案例:

\[sinh(x) := \frac{e^x - e^{-x}}{2} \]

由于 sinh 的导数是 cosh,因此重用 exp(x)exp(-x) 可能很有用,这两个中间结果在 forward 中 在反向计算中。

不过,中间结果不应该直接保存和向后使用。 因为 forward 是在 no-grad 模式下执行的,所以如果中间结果 用于计算 backward pass 中的梯度 梯度的反向图将不包括运算 ,它计算了中间结果。这会导致不正确的梯度。

class Sinh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.save_for_backward(expx, expnegx)
        # In order to be able to save the intermediate results, a trick is to
        # include them as our outputs, so that the backward graph is constructed
        return (expx - expnegx) / 2, expx, expnegx

    @staticmethod
    def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
        expx, expnegx = ctx.saved_tensors
        grad_input = grad_out * (expx + expnegx) / 2
        # We cannot skip accumulating these even though we won't use the outputs
        # directly. They will be used later in the second backward.
        grad_input += _grad_out_exp * expx
        grad_input -= _grad_out_negexp * expnegx
        return grad_input

def sinh(x):
    # Create a wrapper that only returns the first output
    return Sinh.apply(x)[0]

x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(sinh, x)
torch.autograd.gradgradcheck(sinh, x)

使用 torchviz 可视化图形:

out = sinh(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
https://user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png

保存中间结果:不该做什么

现在我们展示当我们不返回中间 results 作为输出:grad_x 甚至不会有反向图 因为它纯粹是一个函数 expexpnegx,而它们没有 需要 grad。

class SinhBad(torch.autograd.Function):
    # This is an example of what NOT to do!
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.expx = expx
        ctx.expnegx = expnegx
        return (expx - expnegx) / 2

    @staticmethod
    def backward(ctx, grad_out):
        expx = ctx.expx
        expnegx = ctx.expnegx
        grad_input = grad_out * (expx + expnegx) / 2
        return grad_input

使用 torchviz 可视化图形。请注意,grad_x 不是 图表的一部分!

out = SinhBad.apply(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
https://user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png

当 Backward 未被跟踪时

最后,让我们考虑一个可能无法使用 autograd 来向后跟踪函数的梯度。 我们可以想象cube_backward是一个可能需要 非 PyTorch 库(如 SciPy 或 NumPy),或编写为 C++ 扩展。此处演示的解决方法是创建另一个 自定义函数 CubeBackward,其中您还可以手动指定 倒退cube_backward!

def cube_forward(x):
    return x**3

def cube_backward(grad_out, x):
    return grad_out * 3 * x**2

def cube_backward_backward(grad_out, sav_grad_out, x):
    return grad_out * sav_grad_out * 6 * x

def cube_backward_backward_grad_out(grad_out, x):
    return grad_out * 3 * x**2

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return cube_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return CubeBackward.apply(grad_out, x)

class CubeBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_out, x):
        ctx.save_for_backward(x, grad_out)
        return cube_backward(grad_out, x)

    @staticmethod
    def backward(ctx, grad_out):
        x, sav_grad_out = ctx.saved_tensors
        dx = cube_backward_backward(grad_out, sav_grad_out, x)
        dgrad_out = cube_backward_backward_grad_out(grad_out, x)
        return dgrad_out, dx

x = torch.tensor(2., requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)

使用 torchviz 可视化图形:

out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})
https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png

总而言之,double backward 是否适用于您的自定义函数 这仅仅取决于 Autograd 是否可以跟踪向后传递。 通过前两个示例,我们展示了 double backward 的情况 开箱即用。通过第三个和第四个示例,我们进行了演示 允许跟踪 backward 函数的技术,当它们 否则不会。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源