目录

自定义 Python 运算符

创建时间: Jun 18, 2024 |上次更新时间: 2025-1-02 |上次验证: Nov 05, 2024

您将学到什么
  • 如何将用 Python 编写的自定义运算符与 PyTorch 集成

  • 如何使用torch.library.opcheck

先决条件
  • PyTorch 2.4 或更高版本

PyTorch 提供了一个大型运算符库,这些运算符适用于张量(例如 、 、 等)。但是,您可能希望使用新的自定义 运算符,可能是由第三方库编写的。本教程 演示如何包装 Python 函数,使其行为类似于 PyTorch 本机 运维。您可能希望在 PyTorch 中创建自定义运算符的原因包括:torch.addtorch.sum

  • 将任意 Python 函数视为不透明的可调用对象 to (即 prevent 跟踪 到函数中)。torch.compiletorch.compile

  • 向任意 Python 函数添加训练支持

请注意,如果您的操作可以表示为 现有的 PyTorch 运算符,那么通常不需要使用自定义运算符 API – 所有内容(例如,培训支持)都应该 只是工作。torch.compile

示例:将 PIL 的裁剪包装到自定义运算符中

假设我们正在使用 PIL 的操作。crop

import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt

def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.

def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)
Python 自定义操作
cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)
Python 自定义操作

crop未通过以下方式进行开箱即用的有效处理:在它无法处理的函数上引发“图形中断”,并且图形中断对性能不利。 下面的代码通过引发错误来演示这一点 ( with 如果 发生 Graph Break 发生)。torch.compiletorch.compiletorch.compilefullgraph=True

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)

为了实现 黑盒 ,我们需要 做两件事:croptorch.compile

  1. 将函数包装到 PyTorch 自定义运算符中。

  2. 在 Operator 中添加一个 “ kernel” (又名 “meta kernel”)。 给定一些输入(没有存储空间的虚拟 Tensors), 此函数应返回您选择的虚拟 Tensor 和正确的 张量元数据 (shape/strides//device)。FakeTensorFakeTensorsdtype

from typing import Sequence

# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    return pic.new_empty(channels, y1 - y0, x1 - x0)

在此之后,现在可以在没有图形中断的情况下工作:crop

@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)
Python 自定义操作
/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:820: FutureWarning:

'onnxscript.values.Op.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.

/usr/local/lib/python3.10/dist-packages/onnxscript/converter.py:820: FutureWarning:

'onnxscript.values.OnnxFunction.param_schemas' is deprecated in version 0.1 and will be removed in the future. Please use '.op_signature' instead.
display(cropped_img)
Python 自定义操作

添加对 crop 的训练支持

用于为操作员添加训练支持。 更喜欢这样做而不是直接使用 ;使用 PyTorch 操作员注册 API 的一些组合可能会导致(以及 导致了)与 合成时出现无声错误。torch.library.register_autogradtorch.autograd.Functionautograd.Functiontorch.compile

如果您不需要培训支持,则无需使用 。 如果您最终使用没有 autograd 的 registration,我们将引发错误消息。torch.library.register_autogradcustom_op

的梯度公式本质上是(我们将 作为对读者的练习)。我们首先包装成一个 自定义运算符:cropPIL.pastepaste

@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    im1_pil = to_pil_image(im1.cpu())
    im2_pil = to_pil_image(im2.cpu())
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

@paste.register_fake
def _(im1, im2, coord):
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    return torch.empty_like(im1)

现在让我们用来指定 的梯度公式 :register_autogradcrop

def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None

def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)

请注意,backward 必须是 PyTorch 理解的运算符的组合, 这就是为什么我们将 paste 包装到自定义运算符中,而不是直接使用 PIL 的糊状物。

img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)
Python 自定义操作

这是正确的渐变,裁剪区域为 1(白色),0 (黑色)。

测试 Python 自定义运算符

用于测试自定义运算符是否已注册 正确。这不会测试梯度在数学上是否正确; 请为此编写单独的测试(手动或 )。torch.library.opchecktorch.autograd.gradcheck

要使用 ,请向其传递一组要测试的示例输入。如果您的 operator 支持训练,则示例应包含满足以下条件的 Tensor 需要 grad。如果您的操作员支持多个设备,则示例 应包含来自每个设备的 Tensor。opcheck

examples = [
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]

for example in examples:
    torch.library.opcheck(crop, example)

可变 Python 自定义运算符

您还可以包装一个 Python 函数,该函数将其输入更改为自定义 算子。 改变输入的函数很常见,因为那是低级 内核被写入;例如,计算的内核可能会接收 输入和输出张量,并写入输出张量。sininput.sin()

我们将用它来演示一个可变 Python 的示例 custom 运算符。numpy.sin

import numpy as np

@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)

因为 operator 不返回任何内容,所以不需要注册 一个内核(元内核)来使其与 .FakeTensortorch.compile

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())

这是一个运行,告诉我们我们确实正确注册了 operator。 例如,如果我们忘记将输出添加到 ,则会出错。opcheckopcheckmutates_args

example_inputs = [
    [torch.randn(3), torch.empty(3)],
    [torch.randn(0, 3), torch.empty(0, 3)],
    [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]

for example in example_inputs:
    torch.library.opcheck(numpy_sin, example)

结论

在本教程中,我们学习了如何使用 在 Python 中创建与 PyTorch 子系统配合使用的自定义运算符 例如 和 autograd。torch.library.custom_optorch.compile

本教程提供了自定义运算符的基本介绍。 有关更多详细信息,请参阅:

脚本总运行时间:(0 分 4.595 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源