目录

torch.func

torch.func,以前称为“functorch”,是 PyTorch 的类似 JAX 的可组合函数转换。

注意

此库目前处于测试阶段。 这意味着这些功能通常有效(除非另有说明) 我们(PyTorch 团队)致力于推动此库的发展。但是,API 可能会根据用户反馈进行更改,并且我们没有完全覆盖 PyTorch 操作。

如果您对 API 或希望涵盖的用例有任何建议,请 打开 GitHub 问题或联系我们。我们很想听听您如何使用该库。

什么是可组合函数转换?

  • “函数转换”是接受数值函数的高阶函数 并返回一个计算不同数量的新函数。

  • 具有自动微分变换 ( 返回一个函数,该函数 计算 的梯度)、矢量化/批处理转换(返回计算输入批次的函数)等。grad(f)fvmap(f)f

  • 这些函数转换可以任意地相互组合。例如 Composing 计算一个称为 per-sample-gradients 的量,该量 stock PyTorch 目前无法高效计算。vmap(grad(f))

为什么使用可组合函数转换?

目前,在 PyTorch 中有许多棘手的使用案例:

  • 计算每个样品的梯度(或其他每个样品的数量)

  • 在单台计算机上运行模型集成

  • 在 MAML 的内循环中有效地将任务批处理在一起

  • 高效计算雅可比矩阵和 Hessian 矩阵

  • 高效计算批处理的雅可比矩阵和 Hessian 矩阵

组合 转换 允许我们表达上述内容,而无需为每个 转换设计单独的子系统。 这种可组合函数转换的理念来自 JAX 框架

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源