torch.func¶
torch.func,以前称为“functorch”,是 PyTorch 的类似 JAX 的可组合函数转换。
注意
此库目前处于测试阶段。 这意味着这些功能通常有效(除非另有说明) 我们(PyTorch 团队)致力于推动此库的发展。但是,API 可能会根据用户反馈进行更改,并且我们没有完全覆盖 PyTorch 操作。
如果您对 API 或希望涵盖的用例有任何建议,请 打开 GitHub 问题或联系我们。我们很想听听您如何使用该库。
什么是可组合函数转换?¶
为什么使用可组合函数转换?¶
目前,在 PyTorch 中有许多棘手的使用案例:
计算每个样品的梯度(或其他每个样品的数量)
在单台计算机上运行模型集成
在 MAML 的内循环中有效地将任务批处理在一起
高效计算雅可比矩阵和 Hessian 矩阵
高效计算批处理的雅可比矩阵和 Hessian 矩阵
组合 、
和
转换 允许我们表达上述内容,而无需为每个 转换设计单独的子系统。
这种可组合函数转换的理念来自 JAX 框架。