torch.library¶
PyTorch 操作符注册 API 提供了将用户自定义操作符扩展到 PyTorch 核心操作符库中的能力。目前,这可以通过两种方式实现:
创建新的库
允许您通过指定适当的分发密钥,为各种后端和功能注册新操作符和内核。例如,
Consider registering a new operator
addin your newly created namespacefoo. You can access this operator using thetorch.opsAPI and calling into by callingtorch.ops.foo.add. You can also access specific registered overloads by callingtorch.ops.foo.add.{overload_name}.If you registered a new kernel for the
CUDAdispatch key for this operator, then your custom defined function will be called for CUDA tensor inputs.
这可以通过创建
"DEF"种类型的库类对象来实现。
扩展现有的 C++ 库(例如 aten)
允许你通过指定适当的调度键,为各种后端和功能对应的现有操作符注册内核。
这在通过调度键实现的功能需要填补操作符支持不完整时可能会派上用场。例如,
You can add operator support for Meta Tensors (by registering function to the
Metadispatch key).
这可以通过创建
"IMPL"种类型的库类对象来实现。
在Google Colab上提供了一个教程,通过一些示例向您展示如何使用此API。
警告
调度器是PyTorch中的一个复杂概念,对调度器有深入的理解对于能够使用此API进行高级操作至关重要。 这篇博客文章 是一个学习调度器的好起点。
- class torch.library.Library(ns, kind, dispatch_key='')[source]¶
一个类,用于创建可以从 Python 中使用的库,以注册新操作符或覆盖现有库中的操作符。 用户可以选择性地传入一个调度键名,以便仅注册与特定调度键相对应的内核。
要创建一个用于重载现有库(名称为 ns)操作符的库,请将类型设置为“IMPL”。 要创建一个新的库(名称为 ns)以注册新的操作符,请将类型设置为“DEF”。 要创建一个可能已存在的库片段以注册操作符(并绕过给定命名空间只能有一个库的限制),请将类型设置为“FRAGMENT”。
- Parameters
ns – 库名称
类型 – “DEF”,“IMPL”(默认:“IMPL”),“FRAGMENT”
dispatch_key – PyTorch 调度键 (默认值:”“)
- define(schema, alias_analysis='')[source]¶
在命名空间 ns 中定义一个新的操作符及其语义。
- Parameters
模式 – 定义新算子的功能模式。
alias分析 (可选) – 表示操作数的别名属性是否可以从模式(默认行为)中推断出来,或者不能(“保守”)。
- Returns
根据模式推断的操作符名称。
- Example::
>>> my_lib = Library("foo", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor")
- impl(op_name, fn, dispatch_key='')[source]¶
为库中定义的操作符注册函数实现。
- Parameters
op_name – 操作符名称(包括重载)或 OpOverload 对象。
fn – 输入分发键的操作实现函数,或
fallthrough_kernel()以注册一个备用操作。dispatch_key – 输入函数应注册的分派键。默认情况下,它使用库创建时的分派键。
- Example::
>>> my_lib = Library("aten", "IMPL") >>> def div_cpu(self, other): >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
我们还添加了一些函数装饰器,以便于为操作符注册函数:
torch.library.impl()torch.library.define()