(beta) Building a Convolution/Batch Norm fuser in FX

Created On: Mar 04, 2021 | Last Updated: Jan 16, 2024 | Last Verified: Nov 05, 2024

Author: Horace He

In this tutorial, we are going to use FX, a toolkit for composable function transformations, to do the following:

  1. Find patterns of conv/batch norm in the data dependencies.

  2. For the patterns found in 1), fold the batch norm statistics into the convolution weights.

Note that this optimization only works for models in inference mode (i.e. model.eval()).

We will be building the fuser that exists here: https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py

First, let's get some imports out of the way (we will be using all of these later in the code).

from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn

For this tutorial, we are going to create a model consisting of convolutions and batch norms. Note that this model has some tricky components - some of the conv/batch norm patterns are hidden within Sequentials, and one of the BatchNorms is wrapped in another Module.

class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)
    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M()

model.eval()

Symbolically Tracing the Model

One of the primary challenges with trying to automatically fuse convolution and batch norm in PyTorch is that PyTorch does not provide an easy way to access the computational graph. FX resolves this problem by symbolically tracing the actual operations called, so that we can track the computations through the forward call, nested within Sequential modules, or wrapped in a user-defined module.

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)

This gives us a graph representation of our model. Note that both the modules hidden within the Sequential as well as the wrapped Module have been inlined into the graph. This is the default level of abstraction, but it can be configured by the pass writer. More information can be found at the FX overview: https://pytorch.org/docs/master/fx.html#module-torch.fx
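
To make the inlining concrete (an illustrative aside, not part of the original tutorial), we can walk the traced nodes and print each opcode and target; the Sequential's children show up under qualified names such as nested.0, and the wrapped batch norm under wrapped.mod:

for node in traced_model.graph.nodes:
    # Each FX node records an opcode (e.g. 'placeholder', 'call_module',
    # 'output') and a target (e.g. the qualified module name 'nested.0').
    print(node.op, node.target)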

Fusing Convolution with Batch Norm

Unlike some other fusions, fusing convolution with batch norm does not require any new operators. Instead, since batch norm during inference consists of a pointwise add and multiply, these operations can be "baked" into the weights of the preceding convolution. This allows us to remove the batch norm entirely from our model! Read https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The code here is copied from https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py for clarity purposes.
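
Concretely (a brief recap of the math from the reference above, written in terms of the parameter names used in the code below): in eval mode, batch norm applies a per-channel affine transform using its running statistics, which can be folded into the convolution's weight and bias:

    bn(y)   = bn_w * (y - bn_rm) / sqrt(bn_rv + bn_eps) + bn_b
    conv_w' = conv_w * bn_w / sqrt(bn_rv + bn_eps)
    conv_b' = (conv_b - bn_rm) * bn_w / sqrt(bn_rv + bn_eps) + bn_b

where the weight scaling is broadcast per output channel. This is exactly what fuse_conv_bn_weights below computes.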

def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
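
As a quick sanity check (an illustrative snippet, not part of the original tutorial), we can verify the folding math on a standalone conv/batch norm pair:

conv = nn.Conv2d(3, 8, 3).eval()
bn = nn.BatchNorm2d(8).eval()
# Give the batch norm non-trivial running statistics so the check is meaningful.
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(2, 3, 16, 16)
# The fused conv should match bn(conv(x)) up to floating-point error.
torch.testing.assert_allclose(fused(x), bn(conv(x)))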

FX Fusion Pass

Now that we have our computational graph as well as a method for fusing convolution and batch norm, all that remains is to iterate over the FX graph and apply the desired fusions.

def _parent_name(target : str) -> Tuple[str, str]:
    """
    Splits a ``qualname`` into parent path and last atom.
    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
    """
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)
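
To make these helpers concrete (an illustrative aside): for a node whose target is the qualified name 'nested.1', _parent_name returns ('nested', '1'), so replace_node_module performs the equivalent of setattr(model.nested, '1', new_module). Note that the empty-string parent maps to the root module in the dict produced by named_modules().

# Illustrative only: how qualified names split for nested submodules.
assert _parent_name('nested.1') == ('nested', '1')
assert _parent_name('conv1') == ('', 'conv1')  # '' is the root module's key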


def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    # The first step of most FX passes is to symbolically trace our model to
    # obtain a `GraphModule`. This is a representation of our original model
    # that is functionally identical to our original model, except that we now
    # also have a graph representation of our forward pass.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    # The primary representation for working with FX are the `Graph` and the
    # `Node`. Each `GraphModule` has a `Graph` associated with it - this
    # `Graph` is also what generates `GraphModule.code`.
    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
    # iterate through all of the operations in our graph, we iterate over each
    # `Node` in our `Graph`.
    for node in fx_model.graph.nodes:
        # The FX IR contains several types of nodes, which generally represent
        # call sites to modules, functions, or methods. The type of node is
        # determined by `Node.op`.
        if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
            continue
        # For call sites, `Node.target` represents the module/function/method
        # that's being called. Here, we check `Node.target` to see if it's a
        # batch norm module, and then check `Node.args[0].target` to see if the
        # input `Node` is a convolution.
        if (type(modules[node.target]) is nn.BatchNorm2d
                and node.args[0].op == 'call_module'  # the input must itself be a module call
                and type(modules[node.args[0].target]) is nn.Conv2d):
            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(node.args[0], modules, fused_conv)
            # As we've folded the batch norm into the conv, we need to replace all uses
            # of the batch norm with the conv.
            node.replace_all_uses_with(node.args[0])
            # Now that all uses of the batch norm have been replaced, we can
            # safely remove the batch norm.
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # After we've modified our graph, we need to recompile our graph in order
    # to keep the generated code in sync.
    fx_model.recompile()
    return fx_model

Note

We make some simplifications here for demonstration purposes, such as only matching 2D convolutions. See https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py for a more usable pass.
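
For instance (a sketch of one possible generalization, not code from the linked fuser), the type check could be extended to 1D/3D convolutions, since fuse_conv_bn_weights already reshapes the scale factor based on the weight's dimensionality:

# Hypothetical pattern table pairing each conv type with its batch norm type.
CONV_BN_PATTERNS = {
    nn.Conv1d: nn.BatchNorm1d,
    nn.Conv2d: nn.BatchNorm2d,
    nn.Conv3d: nn.BatchNorm3d,
}

def matches_conv_bn(conv_mod: nn.Module, bn_mod: nn.Module) -> bool:
    return CONV_BN_PATTERNS.get(type(conv_mod)) is type(bn_mod)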

Testing our Fusion Pass

We can now run this fusion pass on our initial toy model and verify that our results are identical. In addition, we can print out the code of our fused model and verify that there are no more batch norms.

fused_model = fuse(model)
print(fused_model.code)
inp = torch.randn(5, 1, 1, 1)
torch.testing.assert_allclose(fused_model(inp), model(inp))
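
We can also check programmatically (an illustrative addition, assuming only 2D batch norms were present) that no batch norm modules remain in the fused model:

assert not any(isinstance(m, nn.BatchNorm2d) for m in fused_model.modules())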

Benchmarking our Fusion on ResNet18

We can test our fusion pass on a larger model like ResNet18 and see how much this pass improves inference performance.

import torchvision.models as models
import time

rn18 = models.resnet18()
rn18.eval()

inp = torch.randn(10, 3, 224, 224)
output = rn18(inp)

def benchmark(model, iters=20):
    for _ in range(10):
        model(inp)
    begin = time.time()
    for _ in range(iters):
        model(inp)
    return str(time.time()-begin)

fused_rn18 = fuse(rn18)
print("Unfused time: ", benchmark(rn18))
print("Fused time: ", benchmark(fused_rn18))

As we previously saw, the output of our FX transformation is ("torchscriptable") PyTorch code, so we can easily jit.script the output to try and improve our performance even further. In this way, our FX model transformation composes with TorchScript with no issues.

jit_rn18 = torch.jit.script(fused_rn18)
print("jit time: ", benchmark(jit_rn18))


Conclusion

As we can see, using FX we can easily write static graph transformations on PyTorch code.

Since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.
