目录

torch.library

PyTorch 操作符注册 API 提供了将用户自定义操作符扩展到 PyTorch 核心操作符库中的能力。目前,这可以通过两种方式实现:

  1. 创建新的库

    • 允许您通过指定适当的分发密钥,为各种后端和功能注册新操作符和内核。例如,

      • Consider registering a new operator add in your newly created namespace foo. You can access this operator using the torch.ops API and calling into by calling torch.ops.foo.add. You can also access specific registered overloads by calling torch.ops.foo.add.{overload_name}.

      • If you registered a new kernel for the CUDA dispatch key for this operator, then your custom defined function will be called for CUDA tensor inputs.

    • 这可以通过创建"DEF"种类型的库类对象来实现。

  2. 扩展现有的 C++ 库(例如 aten)

    • 允许你通过指定适当的调度键,为各种后端和功能对应的现有操作符注册内核。

    • 这在通过调度键实现的功能需要填补操作符支持不完整时可能会派上用场。例如,

      • You can add operator support for Meta Tensors (by registering function to the Meta dispatch key).

    • 这可以通过创建"IMPL"种类型的库类对象来实现。

Google Colab上提供了一个教程,通过一些示例向您展示如何使用此API。

警告

调度器是PyTorch中的一个复杂概念,对调度器有深入的理解对于能够使用此API进行高级操作至关重要。 这篇博客文章 是一个学习调度器的好起点。

class torch.library.Library(ns, kind, dispatch_key='')[source]

一个类,用于创建可以从 Python 中使用的库,以注册新操作符或覆盖现有库中的操作符。 用户可以选择性地传入一个调度键名,以便仅注册与特定调度键相对应的内核。

要创建一个用于重载现有库(名称为 ns)操作符的库,请将类型设置为“IMPL”。 要创建一个新的库(名称为 ns)以注册新的操作符,请将类型设置为“DEF”。 要创建一个可能已存在的库片段以注册操作符(并绕过给定命名空间只能有一个库的限制),请将类型设置为“FRAGMENT”。

Parameters
  • ns – 库名称

  • 类型 – “DEF”,“IMPL”(默认:“IMPL”),“FRAGMENT”

  • dispatch_key – PyTorch 调度键 (默认值:”“)

define(schema, alias_analysis='')[source]

在命名空间 ns 中定义一个新的操作符及其语义。

Parameters
  • 模式 – 定义新算子的功能模式。

  • alias分析 (可选) – 表示操作数的别名属性是否可以从模式(默认行为)中推断出来,或者不能(“保守”)。

Returns

根据模式推断的操作符名称。

Example::
>>> my_lib = Library("foo", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
impl(op_name, fn, dispatch_key='')[source]

为库中定义的操作符注册函数实现。

Parameters
  • op_name – 操作符名称(包括重载)或 OpOverload 对象。

  • fn – 输入分发键的操作实现函数,或fallthrough_kernel() 以注册一个备用操作。

  • dispatch_key – 输入函数应注册的分派键。默认情况下,它使用库创建时的分派键。

Example::
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
>>>     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
torch.library.fallthrough_kernel()[source]

一个虚拟函数,用于传递给Library.impl以注册一个贯穿。

我们还添加了一些函数装饰器,以便于为操作符注册函数:

  • torch.library.impl()

  • torch.library.define()

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源