torch.func¶
torch.func,之前称为“functorch”,是用于PyTorch的 类似于JAX的可组合函数转换。
注意
这个库目前处于测试版。 这意味着功能通常可以正常工作(除非另有文档说明), 并且我们(PyTorch 团队)致力于推进这个库的发展。然而,根据用户反馈,API 可能会发生变化, 并且我们还没有完全覆盖 PyTorch 的所有操作。
如果您有关于 API 或您希望涵盖的用例的建议,请打开 GitHub 问题或联系我们。我们很想知道您是如何使用该库的。
什么是可组合函数变换?¶
一个“函数变换”是一个高阶函数,它接受一个数值函数并返回一个新的函数,用于计算不同的量。
torch.func具有自动微分变换(grad(f)返回一个计算f梯度的函数),向量化/批量变换(vmap(f)返回一个在输入批次上计算f的函数),以及其他功能。这些函数变换可以任意组合。例如,组合
vmap(grad(f))计算一个称为每样本梯度的量, 而今天的标准PyTorch无法高效地计算这个量。