torch.library¶
PyTorch 操作符注册 API 提供了将用户自定义操作符扩展到 PyTorch 核心操作符库中的能力。目前,这可以通过两种方式实现:
创建新的库
允许您通过指定适当的分发密钥,为各种后端和功能注册新操作符和内核。例如,
Consider registering a new operator
addin your newly created namespacefoo. You can access this operator using thetorch.opsAPI and calling into by callingtorch.ops.foo.add. You can also access specific registered overloads by callingtorch.ops.foo.add.{overload_name}.If you registered a new kernel for the
CUDAdispatch key for this operator, then your custom defined function will be called for CUDA tensor inputs.
这可以通过创建
"DEF"种类型的库类对象来实现。
扩展现有的 C++ 库(例如 aten)
允许你通过指定适当的调度键,为各种后端和功能对应的现有操作符注册内核。
这在通过调度键实现的功能需要填补操作符支持不完整时可能会派上用场。例如,
You can add operator support for Meta Tensors (by registering function to the
Metadispatch key).
这可以通过创建
"IMPL"种类型的库类对象来实现。
在Google Colab上提供了一个教程,通过一些示例向您展示如何使用此API。
警告
调度器是PyTorch中的一个复杂概念,对调度器有深入的理解对于能够使用此API进行高级操作至关重要。 这篇博客文章 是一个学习调度器的好起点。
- class torch.library.Library(ns, kind, dispatch_key='')[source]¶
一个类,用于创建可以从 Python 中使用的库,以注册新操作符或覆盖现有库中的操作符。 用户可以选择性地传入一个调度键名,以便仅注册与特定调度键相对应的内核。
要创建一个库以覆盖现有库(名称为 ns)中的运算符,请将 kind 设置为 “IMPL”。 要创建一个新库(名称为 ns)以注册新的运算符,请将 kind 设置为 “DEF”。 :param ns: 库名称 :param kind: “DEF”, “IMPL”(默认值: “IMPL”) :param dispatch_key: PyTorch 调度键(默认值: “”)
- define(schema, alias_analysis='')[source]¶
在命名空间 ns 中定义一个新的操作符及其语义。
- Parameters:
模式 – 定义新算子的功能模式。
alias分析 (可选) – 表示操作数的别名属性是否可以从模式(默认行为)中推断出来,或者不能(“保守”)。
- Returns:
根据模式推断的操作符名称。
- Example::
>>> my_lib = Library("foo", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor")
- impl(op_name, fn, dispatch_key='')[source]¶
为库中定义的操作符注册函数实现。
- Parameters:
op_name – 操作符名称(包括重载)或 OpOverload 对象。
fn – 用于输入调度密钥的操作实现的函数。
dispatch_key – 输入函数应注册的分派键。默认情况下,它使用库创建时的分派键。
- Example::
>>> my_lib = Library("aten", "IMPL") >>> def div_cpu(self, other): >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", "CPU")
我们还添加了一些函数装饰器,以便于为操作符注册函数:
torch.library.impl()torch.library.define()