目录

torch.func

torch.func,之前称为“functorch”,是用于PyTorch的 类似于JAX的可组合函数转换。

注意

这个库目前处于测试版。 这意味着功能通常可以正常工作(除非另有文档说明), 并且我们(PyTorch 团队)致力于推进这个库的发展。然而,根据用户反馈,API 可能会发生变化, 并且我们还没有完全覆盖 PyTorch 的所有操作。

如果您有关于 API 或您希望涵盖的用例的建议,请打开 GitHub 问题或联系我们。我们很想知道您是如何使用该库的。

什么是可组合函数变换?

  • 一个“函数变换”是一个高阶函数,它接受一个数值函数并返回一个新的函数,用于计算不同的量。

  • torch.func 具有自动微分变换(grad(f) 返回一个计算f梯度的函数),向量化/批量变换(vmap(f) 返回一个在输入批次上计算f的函数),以及其他功能。

  • 这些函数变换可以任意组合。例如,组合vmap(grad(f))计算一个称为每样本梯度的量, 而今天的标准PyTorch无法高效地计算这个量。

为什么使用可组合函数变换?

目前有一些用例在 PyTorch 中处理起来比较棘手:

  • 计算每个样本的梯度(或其他每个样本的量)

  • 在单台机器上运行模型集合

  • 在 MAML 内循环中高效地批量处理任务

  • 高效计算雅可比矩阵和黑塞矩阵

  • 高效计算批处理雅可比矩阵和黑塞矩阵

组合vmap()grad()vjp()变换,使我们能够在不为每个子系统单独设计的情况下表达上述内容。 这种可组合函数变换的思想来自JAX框架

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源