目录

torch.fx

概述

FX 是一个开发人员使用的工具包,用于转换 nn.Module 实例。FX 包含三个主要组件:一个 符号追踪器, 一个 中间表示法Python 代码生成。这些组件的实际演示:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪器执行“符号执行”的Python代码。它通过代码传递假值,称为代理。对这些代理的操作会被记录下来。有关符号追踪的更多信息可以在symbolic_trace()Tracer文档中找到。

中间表示是符号跟踪期间记录的操作容器。它由一组节点组成,这些节点代表函数输入、调用点(到函数、方法或torch.nn.Module实例)和返回值。有关IR的更多信息,请参阅Graph的文档。IR是应用变换的格式。

Python代码生成 是使得FX成为一个Python到Python(或模块到模块)转换工具包的原因。对于每个图中间表示(Graph IR),我们可以创建与其语义匹配的有效Python代码。此功能被封装在GraphModule中,它是一个torch.nn.Module实例,其中包含一个Graph以及由图生成的forward方法。

这些组件的组合(符号追踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件可以单独使用。例如,符号追踪可以在隔离状态下使用以捕获代码的一种形式,用于分析(而非转换)。代码生成可用于程序化地生成模型,例如从配置文件中生成。FX 有众多用途!

可以在示例库中找到一些转换示例。

写入转换

什么是FX变换?本质上,它就像下面这个函数一样。

import torch
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

你的转换将接收一个 torch.nn.Module,从中获取一个 Graph,进行一些修改,然后返回一个新的 torch.nn.Module。你应该认为你的 FX 转换返回的 torch.nn.Module 与普通的 torch.nn.Module 是相同的 – 你可以将其传递给另一个 FX 转换,可以将其传递给 TorchScript,或者你可以运行它。确保你的 FX 转换的输入和输出是 torch.nn.Module 将允许其组合性。

注意

也可以修改现有的 GraphModule 而不是创建一个新的,例如:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

请注意,您必须调用 GraphModule.recompile() 以使生成的 forward() 方法与修改后的 Graph 同步。

鉴于你已经传入了一个 torch.nn.Module 并将其追踪为一个 Graph,现在有两种主要方法可以用来构建一个新的 Graph

图论入门简介

有关图语义的完整处理可以在Graph文档中找到,但我们在这里会介绍基础知识。一个Graph是一种数据结构,用于表示在GraphModule上的方法。所需的信息是:

  • 该方法的输入是什么?

  • 该方法内部运行哪些操作?

  • 该方法的输出(即返回)值是什么?

这三个概念都用Node实例表示。 让我们通过一个简短的例子来看看我们指的是什么:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

这里我们定义一个模块 MyModule 用于演示目的,实例化它, 符号性地追踪它,然后调用Graph.print_tabular()方法来打印 出一个表格,显示此Graph的节点:

opcode

name

target

args

kwargs

placeholder

x

x

()

{}

get_attr

linear_weight

linear.weight

()

{}

call_function

add_1

<built-in function add>

(x, linear_weight)

{}

call_module

linear_1

linear

(add_1,)

{}

call_method

relu_1

relu

(linear_1,)

{}

call_function

sum_1

<built-in method sum …>

(relu_1,)

{‘dim’: -1}

call_function

topk_1

<built-in method topk …>

(sum_1, 3)

{}

output

output

output

(topk_1,)

{}

我们可以使用这些信息来回答我们上面提出的问题。

  • 该方法的输入是什么?在FX中,方法输入通过特殊的placeholder节点指定。在这种情况下,我们有一个单一的placeholder节点,其targetx,这意味着我们有一个名为x的单个(非自身)参数。

  • 方法中的操作有哪些?get_attrcall_functioncall_modulecall_method 节点 代表方法中的操作。有关这些操作语义的完整说明,请参阅Node 文档。

  • 该方法的返回值是什么?在Graph中,返回值由一个特殊的output节点指定。

鉴于我们现在了解了FX中代码的基本表示方法,我们可以探索如何编辑一个 Graph

图操作

直接图 manipulation

构建这个新Graph的一种方法是直接操作旧的 一个。为了帮助实现这一点,我们可以简单地获取从符号 追踪中得到的Graph并进行修改。例如,假设我们希望用 torch.mul()调用来替换 torch.add()调用。

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)

我们还可以进行更复杂的Graph重写,例如删除或添加节点。为了帮助这些转换, FX 提供了一些用于转换图的工具函数,可以在Graph文档中找到。下面是一个使用这些 API 添加一个torch.relu()调用的示例。

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

对于仅包含替换的简单转换,您也可以使用子图重写器。

子图重写与replace_pattern()

FX 还在直接图操作的基础上提供了另一个级别的自动化。 replace_pattern() API 实质上是一个用于编辑 Graph 的“查找/替换”工具。它允许您指定一个 patternreplacement 函数,并且它会遍历这些函数,在 pattern 图中找到操作实例,并将其替换为 replacement 图的副本。这可以帮助大大自动化繁琐的图操作代码,当转换变得更为复杂时,这种自动化尤为重要。

图操作示例

Proxy/Retracing

另一种操作 Graph 的方式是重用符号跟踪中使用的 Proxy 机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能性是在 F.relu 后插入比较和乘法,然后清理原始的 F.relu。然而,我们可以通过使用 Proxy 对象自动记录操作到 Graph 中来自动化这个过程。

要使用此方法,我们将想要插入的操作编写为常规的 PyTorch 代码,并使用Proxy 对象作为参数调用该代码。 这些Proxy 对象将捕获对其执行的操作并将它们附加到Graph 上。

# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently,this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式图操作外,使用Proxys 还可以让你以原生Python代码的形式指定重写规则。 对于需要大量重写规则的转换(例如vmap或grad),这通常可以提高规则的可读性和可维护性。 需要注意的是,在调用Proxy时,我们也传递了一个指向底层变量graph的追踪器。 这样做是为了在图中的操作是n元(例如加法是一个二元运算符)的情况下, 调用Proxy不会创建多个图追踪器实例,从而导致意外的运行时错误。 我们特别推荐在这种方法使用Proxy,尤其是在不能安全假设底层运算是单元的情况下。

使用Proxy进行Graph操作的一个示例可以在这里找到。

解释器模式

FX中一个有用的代码组织模式是对Node进行循环,并在Graph中执行它们。这可以用于多种用途,包括对流经图的值进行运行时分析或通过使用Proxy进行代码转换。例如,假设我们希望运行一个GraphModule并记录我们在运行时看到的节点的torch.Tensor形状和数据类型属性。那可能看起来像这样:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

正如你所见,一个完整的FX解释器并没有那么复杂, 但它可以非常有用。为了方便使用这种模式,我们提供了 Interpreter 类,该类以一种可以通过方法重写来覆盖解释器执行的某些方面的形式封装了上述逻辑。

除了执行操作外,我们还可以通过解释器传入 Proxy 个值来生成一个新的 Graph。 同样地,我们提供了 Transformer 类来涵盖这种模式。 Transformer 的行为类似于 Interpreter, 但与其调用 run 方法从模块中获取具体的输出值,您会调用 Transformer.transform() 方法返回一个新的 GraphModule,该对象会遵循您安装的任何转换规则作为重写方法。

解释器模式示例

调试

介绍

在编写转换代码的过程中,我们的代码往往不会完全正确。 在这种情况下,我们需要进行一些调试。关键是倒推:首先,检查调用生成模块的结果以证明或反驳其正确性。然后,检查并调试生成的代码。最后,调试导致生成代码的转换过程。

如果您不熟悉调试器,请参阅辅助部分 可用的调试器

转换编写中的常见陷阱

  • 非确定性 set 迭代顺序。在 Python 中,set 数据类型是无序的。使用 set 来包含对象集合(例如 Node),可能会导致意外的非确定性。一个例子是遍历一组 Node 并将它们插入到 Graph 中。由于 set 数据类型是无序的,输出程序中的操作顺序将是非确定性的,并且在程序调用之间可能会发生变化。推荐的替代方法是使用 dict 数据类型,该类型从 Python 3.7(以及 cPython 3.6)开始是插入有序的。dict 可以等效地用于集合,通过将需要去重的值存储在 dict 的键中。

检查模块的正确性

因为大多数深度学习模块的输出由浮点数torch.Tensor实例组成,检查两个torch.nn.Module结果之间的等价性并不像简单的相等性检查那样直观。为了说明这一点,让我们来看一个例子:

import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

这里,我们尝试使用==等于运算符检查两个深度学习模型的值是否相等。然而,这并不明确,原因有两点:一是该运算符返回的是一个张量而不是布尔值;二是浮点数值的比较应当使用误差范围(或epsilon)来考虑浮点数操作的非交换性(详见这里)。我们可以改用torch.allclose(),它会根据相对和绝对容差阈值给出近似的比较结果:

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们在工具箱中的第一个工具,用于检查转换后的模块是否如我们预期的那样工作,与参考实现进行比较。

调试生成的代码

因为FX在GraphModule上生成了forward()函数,使用传统的调试技术如print语句或pdb并不那么直接。幸运的是,我们有几种可以用于调试生成代码的技术。

使用 pdb

调用 pdb 以进入正在运行的程序。尽管表示 Graph 的代码不在任何源文件中,我们仍然可以在前向传播被调用时使用 pdb 手动进入它。

import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用to_folder函数来自GraphModule

GraphModule.to_folder()GraphModule 中的一种方法,允许你将生成的FX代码转储到一个文件夹中。虽然将前向传递复制到代码中通常就足够了(如 打印生成的代码 所示),但使用 to_folder 检查模块和参数可能会更容易。

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

运行上述示例后,我们可以查看foo/module.py中的代码并根据需要进行修改(例如添加print语句或使用pdb)以调试生成的代码。

调试转换

现在我们已经确定转换过程生成了错误的代码,是时候调试转换本身了。首先,我们将检查文档中的符号跟踪的局限性部分。一旦我们确认跟踪按预期工作,目标就变成了找出在我们的GraphModule转换过程中出了什么问题。可能在编写转换中有快速的答案,但如果没有,有几种方法可以检查我们的跟踪模块:

# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %y : [#users=1] = placeholder[target=y]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上述实用函数,我们可以在应用转换之前和之后比较我们的跟踪模块。有时候,简单的视觉比较就足以追踪到错误。如果仍然不清楚问题出在哪里,像pdb这样的调试器可能是下一步的好选择。

继续以上面的例子为基础,考虑以下代码:

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

使用上面的例子,假设对print(traced)的调用显示我们的转换中存在错误。我们想通过调试器找出问题所在。我们开始一个pdb会话。我们可以通过在transform_graph(traced)处设置断点,然后按s“进入”对transform_graph(traced)的调用来查看转换过程中发生了什么。

我们也可以通过编辑print_tabular方法来打印图中节点的不同属性而获得好运。(例如,我们可能想要查看节点的input_nodesusers。)

可用的调试器

最常见的Python调试器是 pdb。你可以通过在命令行中输入 pdb 来启动你的程序的“调试模式”,输入 python -m pdb FILENAME.py,其中 FILENAME 是你想要调试的文件名。之后,你可以使用 pdb 调试命令 逐步运行你的程序。通常的做法是在启动 pdb 时设置一个断点(b LINE-NUMBER),然后调用 c 来运行程序直到该断点。这可以防止你必须逐行执行(使用 sn)才能到达你想检查的代码部分。或者,你可以在想要中断的行前写上 import pdb; pdb.set_trace()。如果你添加了 pdb.set_trace(),你的程序将在运行时自动进入调试模式。(换句话说,你只需在命令行中输入 python FILENAME.py 而不是 python -m pdb FILENAME.py)。一旦你在调试模式下运行文件,你可以逐步执行代码并使用某些命令检查程序的内部状态。网上有很多优秀的 pdb 教程,包括 RealPython 的 “Python 调试与 Pdb”

像PyCharm或VSCode这样的IDE通常内置有调试器。在你的IDE中,你可以选择:a) 使用pdb,通过在IDE中打开终端窗口(例如,在VSCode中,视图 → 终端),或者 b) 使用内置的调试器(通常是pdb的一个图形化封装)。

符号追踪的局限性

FX 使用一种符号跟踪系统(也称为 符号执行)来捕获程序的语义,以便以可转换/可分析的形式表示。该系统是跟踪的,因为它会执行程序(实际上是一个torch.nn.Module 或函数)以记录操作。它是符号化的,因为在执行过程中流经程序的数据不是真实数据,而是符号(在 FX 术语中为 Proxy)。

尽管符号追踪适用于大多数神经网络代码,但它也有一些限制。

动态控制流

符号式追踪的主要限制是它目前不支持动态控制流。也就是说,循环或if语句中的条件可能依赖于程序的输入值。

例如,让我们来分析以下程序:

def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

条件依赖于if语句的值,该值依赖于x.sum()的值, 而x.sum()又依赖于x的值,这是一个函数输入。由于 x可以改变(例如,如果你传递一个新的输入张量给跟踪的函数),这就是动态控制流。回溯会沿着你的代码向上追溯,向你展示这种情况发生的位置。

静态控制流

另一方面,所谓的静态控制流是受支持的。静态控制流是指值在调用之间不能改变的循环或if语句。通常,在PyTorch程序中,这种控制流出现在根据超参数对模型架构进行决策的代码中。具体例子如下:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if语句 if self.do_activation 不依赖于任何函数输入,因此它是静态的。do_activation 可以被视为一个超参数,具有不同参数值的MyModule的不同实例的跟踪代码各不相同。这是符号跟踪支持的有效模式。

许多动态控制流的情况在语义上是静态控制流。这些情况可以通过移除对输入值的数据依赖来支持符号追踪,例如将值移动到Module属性中,或在符号追踪期间将具体值绑定到参数:

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

在真正动态控制流的情况下,包含此代码的程序部分可以被追踪为对方法的调用(参见 使用 Tracer 类自定义追踪) 或函数(参见 wrap()),而不是通过它们进行追踪。

torch函数

FX 使用 __torch_function__ 作为拦截调用的机制(有关更多信息,请参阅 技术概述)。一些函数,例如内置的 Python 函数或 math 模块中的函数,不在 __torch_function__ 的覆盖范围内,但我们仍希望在符号跟踪中捕获它们。例如:

import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

错误告诉我们内置函数 len 不被支持。 我们可以通过使用 wrap() API,将此类函数记录在跟踪中作为直接调用:

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用Tracer类自定义跟踪

Tracer 类是实现 symbolic_trace 的基础类。通过子类化 Tracer,可以自定义跟踪行为,如下所示:

class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶模块

叶模块是在符号跟踪中作为调用出现的模块,而不是被跟踪通过的模块。默认的叶模块集合是标准的torch.nn模块实例的集合。例如:

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

可以通过覆盖Tracer.is_leaf_module()来定制叶模块集。

杂项

  • 张量构造函数(例如 torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor) 目前不可追踪。

    • 确定性构造函数(zerosones)可以使用, 它们生成的值将作为常量嵌入到跟踪中。只有在这些构造函数的参数引用动态输入大小时,才会出现问题。 在这种情况下,ones_likezeros_like 可能是可行的替代方案。

    • 非确定性构造函数(rand, randn)将在跟踪中嵌入一个随机值。这可能不是预期的行为。一种解决方法是将 torch.randn 包装在一个 torch.fx.wrap 函数中并调用该函数。

    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 此行为可能会在未来的版本中得到修复。

  • 类型注解

    • Python 3 风格的类型注解(例如 func(x : torch.Tensor, y : int) -> torch.Tensor) 被支持 并且将通过符号跟踪被保留。

    • Python 2风格的注释类型注解 # type: (torch.Tensor, int) -> torch.Tensor 目前不支持。

    • 函数中局部名称的注释目前不支持。

  • training 个标志和子模块周围抓取

    • 在使用像torch.nn.functional.dropout这样的函数时,通常会将训练参数作为self.training传递进来。在FX跟踪过程中,这很可能会被固化为一个常量值。

    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
      dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
      return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_allclose(traced(x), x)
    """
    AssertionError: Tensor-likes are not close!
    
    Mismatched elements: 15 / 15 (100.0%)
    Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
    """
    
    • 然而,当使用标准的nn.Dropout()子模块时,训练标志被封装起来——由于nn.Module对象模型的保留——可以被更改。

    class DropoutRepro2(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
      def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
      drop = self.drop(x);  x = None
      return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_allclose(traced(x), x)
    
  • Because of this difference, consider marking modules that interact with the training flag dynamically as leaf modules.

API 参考

torch.fx.symbolic_trace(root, concrete_args=None)[source]

符号追踪 API

给定一个 nn.Module 或函数实例 root,此函数将返回一个通过记录在跟踪 root 时看到的操作而构建的 GraphModule

concrete_args 允许你部分特化你的函数,无论是为了去除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

FX 通常无法追踪到这一点,因为存在控制流。但是,我们可以使用 concrete_args 来针对 b 的值进行专业化处理以追踪到这一点。

f = fx.symbolic_trace(f, concrete_args={‘b’: False}) assert f(3, False) == 6

请注意,尽管您仍然可以传入不同的b值,但它们将被忽略。

我们也可以使用 concrete_args 来消除函数中的数据结构处理。 这将使用 pytrees 来展平您的输入。为了避免过度特化,对于不应特化的值,请传入 fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
Parameters:
  • (联合[torch.nn.Module, 可调用对象]) – 要跟踪并转换为图形表示的模块或函数。

  • concrete_args可选[字典[字符串, 任意类型]]) – 部分专门化的输入

Returns:

root记录的操作创建的模块。

Return type:

GraphModule

注意

此 API 的向后兼容性得到保证。

torch.fx.wrap(fn_or_name)[source]

此函数可以在模块级别范围内调用,以将 fn_or_name 注册为“叶子函数”。 “叶子函数”在 FX 跟踪中将被保留为 CallFunction 节点,而不会被跟踪通过:

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

此函数也可以用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

封装的函数可以被认为是一个“叶函数”,类似于“叶模块”的概念,也就是说,它们是在FX跟踪中被保留为调用的函数,而不是被跟踪通过的函数。

Parameters:

fn_or_name (Union[str, Callable]) – 当调用时要插入到图中的全局函数的名称或函数

注意

此 API 的向后兼容性得到保证。

class torch.fx.GraphModule(*args, **kwargs)[source]

GraphModule 是从 fx.Graph 生成的 nn.Module。Graphmodule 具有 graph 属性,以及从该 graph 生成的 codeforward 属性。

警告

graph 被重新赋值时,codeforward 将自动生成。 然而,如果你在不重新赋值 graph 属性本身的情况下编辑 graph 的内容, 你必须调用 recompile() 来更新生成的代码。

注意

此 API 的向后兼容性得到保证。

__init__(root, graph, class_name='GraphModule')[source]

构造一个 GraphModule。

Parameters:
  • (联合[torch.nn.Module, 字典[字符串, 任意类型]) – root 可以是 nn.Module 实例或映射字符串到任意属性类型的字典。 如果 root 是一个 Module,那么在 Graph 的 Nodes 的 target 字段中通过限定名称引用的基于 Module 的对象将从 root 的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。 如果 root 是一个字典,在 Node 的 target 中找到的限定名称将直接在字典的键中查找。由字典映射的对象将被复制到 GraphModule 的模块层次结构中的适当位置。

  • () – graph 包含此 GraphModule 应用于代码生成的节点

  • 类名 (字符串) – name 表示此 GraphModule 的名称,用于调试目的。如果未设置,则所有错误消息都将报告为源自 GraphModule。将此设置为 root 的原始名称或在转换上下文中具有意义的名称可能会有所帮助。

注意

此 API 的向后兼容性得到保证。

add_submodule(target, m)[source]

将给定的子模块添加到 self

这将安装空模块(如果它们是 target 的子路径且尚不存在)。

Parameters:
  • 目标 (字符串) – 新子模块的完全限定字符串名称 (请参阅 nn.Module.get_submodule 中的示例以了解如何指定完全限定字符串。)

  • m (模块) – 子模块本身;我们希望在当前模块中安装的实际对象

Returns:

Whether or not the submodule could be inserted. For

此方法返回 True,链中的每个对象 由 target 表示的必须要么 a) 尚不存在, 要么 b) 引用一个 nn.Module (不是参数或其他属性)

Return type:

布尔

注意

此 API 的向后兼容性得到保证。

property code: str

返回由Graph生成的Python代码 GraphModule

delete_all_unused_submodules()[source]

self 中删除所有未使用的子模块。

如果以下任何一个条件为真,则模块被视为“已使用”: 1. 它有被使用的子模块 2. 它的前向传播直接通过call_module节点调用 3. 它有一个非模块属性从get_attr节点被使用

此方法可以调用以清理一个nn.Module,而无需手动调用每个未使用的子模块上的delete_submodule

注意

此 API 的向后兼容性得到保证。

delete_submodule(target)[source]

self中删除给定的子模块。

如果target不是一个有效的目标,模块将不会被删除。

Parameters:

目标 (字符串) – 新子模块的完全限定字符串名称 (请参阅 nn.Module.get_submodule 中的示例以了解如何指定完全限定字符串。)

Returns:

Whether or not the target string referenced a

要删除的子模块。返回值False表示target不是对子模块的有效引用。

Return type:

布尔

注意

此 API 的向后兼容性得到保证。

property graph: Graph

返回这个 Graph 底层的 GraphModule

print_readable()[source]

返回为当前 GraphModule 及其子 GraphModules 生成的 Python 代码

警告

此API为实验性功能,且向后兼容。

recompile()[source]

从其graph属性重新编译此GraphModule。应在编辑包含的graph之后调用此方法,否则此GraphModule生成的代码将过时。

注意

此 API 的向后兼容性得到保证。

Return type:

PythonCode

to_folder(folder, module_name='FxModule')[source]
Dumps out module to folder with module_name so that it can be

使用 from <folder> import <module_name> 导入

Args:

folder (Union[str, os.PathLike]): The folder to write the code out to

module_name (str): Top-level name to use for the Module while

writing out the code

警告

此API为实验性功能,且向后兼容。

class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]

Graph 是 FX 中间表示中使用的主要数据结构。 它由一系列 Node 组成,每个代表调用点(或其他语法结构)。将这些 Node 组合在一起构成一个有效的 Python 函数。

例如,以下代码

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

将生成以下图形:

print(gm.graph)
graph(x):
    %linear_weight : [#users=1] = self.linear.weight
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

对于在Graph中表示的操作语义,请参阅Node

注意

此 API 的向后兼容性得到保证。

__init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]

构造一个空的图。

注意

此 API 的向后兼容性得到保证。

call_function(the_function, args=None, kwargs=None, type_expr=None)[source]

call_function Node 插入到 Graph 中。一个 call_function 节点 表示对由 the_function 指定的 Python 可调用对象的调用。

Parameters:
  • the_function (Callable[..., Any]) – 要调用的函数。可以是任何PyTorch操作符、Python函数,或者是builtinsoperator命名空间中的成员。

  • args (可选[元组[参数, ...]]) – 要传递给被调用函数的位置参数。

  • kwargs (可选[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数

  • type_expr可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。

Returns:

新创建并插入的call_function节点。

Return type:

节点

注意

此方法的插入点和类型表达式规则与 Graph.create_node() 相同。

注意

此 API 的向后兼容性得到保证。

call_method(method_name, args=None, kwargs=None, type_expr=None)[source]

call_method Node 插入到 Graph 中。一个 call_method 节点 表示对 args 的第 0 个元素调用给定的方法。

Parameters:
  • 方法名称 (字符串) – 要应用于self参数的方法的名称。 例如,如果args[0]是一个表示TensorNode, 那么要对该Tensor调用relu(),请将relu传递给method_name

  • args (可选[元组[参数, ...]]) – 要传递给调用方法的位置参数。请注意,此参数 包含一个 self 参数。

  • kwargs (可选[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数

  • type_expr可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。

Returns:

新创建并插入的call_method节点。

Return type:

节点

注意

此方法的插入点和类型表达式规则与 Graph.create_node() 相同。

注意

此 API 的向后兼容性得到保证。

call_module(module_name, args=None, kwargs=None, type_expr=None)[source]

将一个 call_module Node 插入到 Graph 中。一个 call_module 节点 表示在 Module 层次结构中调用 Module 的 forward() 函数。

Parameters:
  • 模块名称 (字符串) – 在Module层次结构中要调用的Module的限定名称。例如,如果跟踪的Module有一个名为foo的子模块,该子模块又有一个名为bar的子模块,则应将限定名称foo.bar作为module_name传递以调用该模块。

  • args (可选[元组[参数, ...]]) – 要传递给调用方法的位置参数。请注意,这不应包括 self 参数。

  • kwargs (可选[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数

  • type_expr可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。

Returns:

新创建并插入的call_module节点。

Return type:

节点

注意

此方法的插入点和类型表达式规则与 Graph.create_node() 相同。

注意

此 API 的向后兼容性得到保证。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source]

创建一个Node并将其添加到当前插入点的Graph中。 请注意,当前插入点可以通过Graph.inserting_before()Graph.inserting_after()设置。

Parameters:
  • 操作 (字符串) – 此节点的操作码。可以是 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’ 中的一个。这些操作码的语义在 Graph 的文档字符串中有所描述。

  • args (可选[元组[参数, ...]]) – 是传递给此节点的参数元组。

  • kwargs (可选[Dict[str, Argument]]) – 此节点的kwargs

  • 名称 (可选[字符串]) – 一个可选的字符串名称,用于Node。 这将影响在生成的Python代码中赋值的变量名称。

  • type_expr可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。

Returns:

新创建并插入的节点。

Return type:

节点

注意

此 API 的向后兼容性得到保证。

eliminate_dead_code()[source]

根据每个节点的用户数量以及节点是否有任何副作用,从图中删除所有无效代码。在调用之前,图必须进行拓扑排序。

Returns:

图是否由于传递而发生了变化。

Return type:

布尔

Example:

在消除死代码之前,a 来自 a = x + 1 下面没有使用者, 因此可以从图中消除而不会产生影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

消除无效代码后,a = x + 1 已被移除,forward 的其余部分仍然保留。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除有一些启发式方法来避免删除具有副作用的节点(参见 Node.is_impure),但总体覆盖范围非常差,因此除非你知道你的 FX 图完全由函数操作组成,否则不应假设调用此方法是可靠的。

注意

此 API 的向后兼容性得到保证。

erase_node(to_erase)[source]

Graph中删除一个Node。如果该节点在Graph中仍有使用者,则抛出异常。

Parameters:

要擦除的 (节点) – 从 Graph 中擦除 Node

注意

此 API 的向后兼容性得到保证。

get_attr(qualified_name, type_expr=None)[source]

将一个get_attr节点插入到图中。一个get_attr Node表示从Module层次结构中获取属性。

Parameters:
  • 限定名称 (字符串) – 要检索的属性的完全限定名称。 例如,如果追踪的模块有一个名为 foo 的子模块,该子模块有一个名为 bar 的子模块,该子模块有一个名为 baz 的属性,则应将限定名称 foo.bar.baz 作为 qualified_name 传递。

  • type_expr可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。

Returns:

新创建并插入的get_attr节点。

Return type:

节点

注意

此方法的插入点和类型表达式规则与Graph.create_node相同。

注意

此 API 的向后兼容性得到保证。

graph_copy(g, val_map, return_output_node=False)[source]

将给定图中的所有节点复制到 self

Parameters:
  • g () – 复制节点的源图。

  • val_map (Dict[节点, 节点]) – 将填充一个从g中的节点到self中的节点的映射的字典。请注意,可以传入已经包含某些值的val_map以覆盖某些值的复制。

Returns:

The value in self 的值现在等同于 g 中的输出值, 如果 g 有一个 output 节点。否则为 None

Return type:

可选[联合[元组[任意, …], 列表[任意], 字典[str, 任意], 切片, 节点, str, int, float, bool, 复数, 数据类型, 张量, 设备, 内存格式, 布局]]

注意

此 API 的向后兼容性得到保证。

inserting_after(n=None)[source]
Set the point at which create_node and companion methods will insert into the graph.

当在 'with' 语句中使用时,这将临时设置插入点,然后在 with 语句退出时恢复它:

with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) #  set the insert point permanently

Args:

n (Optional[Node]): The node before which to insert. If None this will insert after

the beginning of the entire graph.

Returns:

一个资源管理器,它将在__exit__上恢复插入点。

注意

此 API 的向后兼容性得到保证。

inserting_before(n=None)[source]
Set the point at which create_node and companion methods will insert into the graph.

当在 'with' 语句中使用时,这将临时设置插入点,然后在 with 语句退出时恢复它:

with g.inserting_before(n):
    ... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) #  set the insert point permanently

Args:

n (Optional[Node]): The node before which to insert. If None this will insert before

the beginning of the entire graph.

Returns:

一个资源管理器,它将在__exit__上恢复插入点。

注意

此 API 的向后兼容性得到保证。

lint()[source]

对这个图运行各种检查,以确保其结构良好。特别是: - 检查节点是否有正确的所有权(属于此图) - 检查节点是否按拓扑顺序出现 - 如果此图有 owning GraphModule,检查 targets 是否存在于该 GraphModule 中

注意

此 API 的向后兼容性得到保证。

node_copy(node, arg_transform=<function Graph.<lambda>>)[source]

将一个节点从一个图复制到另一个图。arg_transform 需要将参数从节点的图转换为自身的图。示例:

# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
Parameters:
  • 节点 (节点) – 将节点复制到 self

  • arg_transform (可调用[[节点], 参数]) – 一个将节点中的 Node 参数转换为 等效的 self 参数的函数。在最简单的情况下,这应该 从映射原始图中节点到 self 的表中检索值。

Return type:

节点

注意

此 API 的向后兼容性得到保证。

property nodes: _node_list

获取构成此图的节点列表。

请注意,此 Node 列表表示形式是一个双向链表。迭代期间的更改(例如删除节点、添加节点)是安全的。

Returns:

双向链表的节点。请注意,可以在此列表上调用reversed以切换迭代顺序。

on_generate_code(make_transformer)[source]

注册一个转换函数,当生成Python代码时

Args:
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):

a function that returns a code transformer to be registered. This function is called by on_generate_code to obtain the code transformer.

This function is also given as its input the currently registered code transformer (or None if nothing is registered), in case it is not desirable to overwrite it. This is useful to chain code transformers together.

Returns:

a context manager that when used in a with statement, to automatically restore the previously registered code transformer.

Example:

gm: fx.GraphModule = ...

# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]

# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(
    lambda _: insert_pdb
)

# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(
            current_trans(body) if current_trans
            else body
        )
    )
)

gm.recompile()
gm(*inputs)  # drops into pdb

This function can also be used as a context manager, with the benefit to automatically restores the previously registered code transformer:

# ... continue from previous example

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # do more stuff with `gm`...
    gm.recompile()
    gm(*inputs)  # drops into pdb

# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you
# run next `recompile()`).

警告

此API为实验性功能,且向后兼容。

output(result, type_expr=None)[source]

output Node 中插入一个 Graph。一个 output 节点表示 Python 代码中的 return 语句。result 是应该返回的值。

Parameters:
  • 结果 (参数) – 要返回的值。

  • type_expr可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。

注意

此方法的插入点和类型表达式规则与Graph.create_node相同。

注意

此 API 的向后兼容性得到保证。

property owning_module

返回拥有此 GraphModule 的模块,如果存在的话, None 如果没有拥有者模块或有多个拥有者模块。

placeholder(name, type_expr=None, default_value)[source]

将一个 placeholder 节点插入图中。一个 placeholder 表示一个函数输入。

Parameters:
  • 名称 (字符串) – 输入值的名称。这对应于该函数的位置参数名称,此 Graph 表示。

  • type_expr可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。在某些情况下,为了正确生成代码(例如,当函数在TorchScript编译中后续使用时),这是必需的。

  • 默认值 (任意类型) – 该函数参数应采用的默认值。注意:为了允许将 None 作为默认值,应该传递 inspect.Signature.empty 作为此参数,以指定该参数 _没有_ 默认值。

Return type:

节点

注意

此方法的插入点和类型表达式规则与Graph.create_node相同。

注意

此 API 的向后兼容性得到保证。

print_tabular()[source]

以表格格式打印图的中间表示。请注意,此 API 需要安装 tabulate 模块。

注意

此 API 的向后兼容性得到保证。

process_inputs(*args)[source]

处理参数以便它们可以传递给FX图。

警告

此API为实验性功能,且向后兼容。

process_outputs(out)[source]

警告

此API为实验性功能,且向后兼容。

python_code(root_module, *, verbose=False)[source]

将此 Graph 转换为有效的 Python 代码。

Parameters:

根模块 (str) – 用于查找限定名称目标的根模块的名称。这通常是‘self’。

Returns:

src: 代表对象的Python源代码 globals: 字典中的全局名称 src -> 它们所引用的对象。

Return type:

一个 PythonCode 对象,包含两个字段

注意

此 API 的向后兼容性得到保证。

set_codegen(codegen)[source]

警告

此API为实验性功能,且向后兼容。

class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source]

Node 是表示深度学习框架中各个操作的数据结构。大多数情况下,节点代表对各种实体的调用点,例如操作符、方法和模块(一些例外包括指定函数输入和输出的节点)。每个 Node 都有一个由其 op 属性指定的函数。每个 op 值的 Node 语义如下:

  • placeholder 表示函数输入。 name 属性指定该值将采用的名称。 target 是参数的名称。 args 包含:1)什么都没有,或 2)表示函数输入默认参数的单个参数 kwargs 是无关紧要的。占位符对应于图输出中的函数参数(例如 x)。

  • get_attr 从模块层次结构中检索参数。 name 类似地是获取结果被赋值的名称。 target 是参数在模块层次结构中的完全限定名称。 argskwargs 是无关紧要的

  • call_function 应用一个自由函数到某些值。name 类似地是赋值的目标名称。target 是要应用的函数。argskwargs 表示函数的参数,遵循 Python 调用约定

  • call_module 在模块层次结构的 forward() 方法中应用一个模块到给定的参数。 name 与之前相同。 target 是在模块层次结构中要调用的模块的完全限定名称。 argskwargs 表示用于调用模块的参数,包括 self 参数

  • call_method 调用值上的方法。 name 类似。 target 是要应用于 self 参数的方法的字符串名称。 argskwargs 表示调用模块时的参数, 包括 self 参数

  • output 包含其 args[0] 属性中的跟踪函数的输出。这对应于图形打印输出中的“返回”语句。

注意

此 API 的向后兼容性得到保证。

property all_input_nodes: List[Node]

返回作为此节点输入的所有节点。这相当于迭代 argskwargs 并仅收集那些是节点的值。

Returns:

列表中的 Nodes 出现在此 Nodeargskwargs 中,按此顺序。

append(x)[source]

在图的节点列表中,此节点后插入 x。 等同于 self.next.prepend(x)

Parameters:

x (节点) – 将此节点之后放置的节点。必须是同一图的成员。

注意

此 API 的向后兼容性得到保证。

property args: Tuple[Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout]], ...]

Node 的参数元组。参数的解释取决于节点的操作码。请参阅 Node 的文档字符串以获取更多信息。

允许对这一属性进行赋值。所有使用情况和用户的记录将在赋值时自动更新。

format_node(placeholder_names=None, maybe_return_typename=None)[source]

返回 self 的描述性字符串表示。

此方法可以用作调试工具,无需参数调用。

此函数也在__str__方法中内部使用 于Graph。一起,placeholder_namesmaybe_return_typename中的字符串构成了这个图的周围 GraphModule中自动生成的forward函数的签名。placeholder_namesmaybe_return_typename 不应在其他情况下使用。

Parameters:
  • 占位符名称 (可选[列表[字符串]]) – 一个将存储生成的 forward 函数中占位符的格式化字符串的列表。仅限内部使用。

  • maybe_return_typename (可选[列表[字符串]]) – 单元素列表,将存储 一个表示生成的 forward 函数输出的格式化字符串。仅限内部使用。

Returns:

If 1) we’re using format_node as an internal helper

__str__ 方法的 Graph 中,以及 2) self 是一个占位符节点,返回 None。否则, 返回当前节点的描述性字符串表示。

Return type:

字符串

注意

此 API 的向后兼容性得到保证。

is_impure()[source]

返回此操作是否为不纯的,即如果其操作是占位符或输出,或者如果调用的是不纯的 call_function 或 call_module。

Returns:

如果操作是不纯的或不是。

Return type:

布尔

警告

此API为实验性功能,且向后兼容。

property kwargs: Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout]]]

Node 的关键字参数字典。参数的解释取决于节点的 opcode。更多信息请参见 Node 的文档字符串。

允许对这一属性进行赋值。所有使用情况和用户的记录将在赋值时自动更新。

property next: Node

返回节点链表中的下一个Node

Returns:

链表中的下一个Node节点。

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source]

返回标准化的参数给Python目标。这意味着 args/kwargs 将与模块/功能的签名匹配,并按位置顺序返回仅关键字参数 如果 normalize_to_only_use_kwargs 为真。 同时填充默认值。不支持仅位置参数或可变参数。

支持模块调用。

可能需要arg_typeskwarg_types来消除重载的歧义。

Parameters:
  • (torch.nn.Module) – 用于解析模块目标的模块。

  • arg_types可选[Tuple[任何类型]]) – 参数类型的元组

  • kwarg_types (可选[字典[字符串, 任意类型]]) – 关键字参数类型的字典

  • normalize_to_only_use_kwargs (布尔值) – 是否仅使用关键字参数进行标准化。

Returns:

返回 NamedTuple ArgsKwargsPair,如果未成功则返回 None

Return type:

可选[ArgsKwargsPair]

警告

此API为实验性功能,且向后兼容。

prepend(x)[source]

在图的节点列表中此节点之前插入 x。示例:

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
Parameters:

x (节点) – 将此节点之前的节点。必须是同一图的成员。

注意

此 API 的向后兼容性得到保证。

property prev: Node

返回节点链表中的前一个 Node

Returns:

链表中节点的前一个 Node

replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>)[source]

将图中的所有 self 替换为节点 replace_with

Parameters:
  • 替换为 (节点) – 将所有使用 self 的地方替换为该节点。

  • delete_user_cb (可调用对象) – 用于确定是否应删除自节点的给定用户的回调函数。

Returns:

对此更改进行的节点列表。

Return type:

列表[节点]

注意

此 API 的向后兼容性得到保证。

replace_input_with(old_input, new_input)[source]

循环遍历输入节点 self,并将所有实例中的 old_input 替换为 new_input

Parameters:
  • 旧输入 (节点) – 要替换的旧输入节点。

  • 新输入 (节点) – 用于替换 old_input 的新输入节点。

注意

此 API 的向后兼容性得到保证。

property stack_trace: Optional[str]

返回在追踪期间记录的Python堆栈跟踪,如果有的话。 此属性通常由Tracer.create_proxy填充。为了在追踪期间记录 堆栈跟踪以供调试使用,请将record_stack_traces = True设置为Tracer实例上的值。

update_arg(idx, arg)[source]

更新现有位置参数以包含新值 arg。调用后,self.args[idx] == arg

Parameters:
  • 索引 (整数) – 要更新的元素在 self.args 中的索引

  • 参数 (参数) – 要写入 args 的新参数值

注意

此 API 的向后兼容性得到保证。

update_kwarg(key, arg)[source]

更新现有关键字参数以包含新值 arg。调用后,self.kwargs[key] == arg

Parameters:
  • (字符串) – 要更新元素在 self.kwargs 中的键

  • 参数 (参数) – 要写入 kwargs 的新参数值

注意

此 API 的向后兼容性得到保证。

class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source]

Tracer is the class that implements the symbolic tracing functionality of torch.fx.symbolic_trace. A call to symbolic_trace(m) is equivalent to Tracer().trace(m).

Tracer can be subclassed to override various behaviors of the tracing process. The different behaviors that can be overridden are described in the docstrings of the methods on this class.

注意

此 API 的向后兼容性得到保证。

call_module(m, forward, args, kwargs)[source]

指定此 Tracer 在遇到对 nn.Module 实例的调用时的行为的方法。

默认情况下,行为是检查被调用的模块是否为叶模块 通过is_leaf_module。如果是,则发出一个call_module节点,引用 mGraph中。否则,正常调用Module,跟踪其 forward函数中的操作。

此方法可以被重写以实现例如创建嵌套的追踪GraphModules,或在跨越 Module 边界时实现任何你希望的行为。

Parameters:
  • m (模块) – 正在发出调用的模块

  • forward (可调用对象) – 要调用的 Module 的 forward() 方法

  • 参数 (元组) – 模块调用站点的参数

  • kwargs (字典) – 模块调用站点的kwargs

Returns:

模块调用的返回值。如果发出了一个call_module节点,则这是一个Proxy值。否则,它是从Module调用返回的任何值。

Return type:

任何

注意

此 API 的向后兼容性得到保证。

create_arg(a)[source]

在准备将值用作Graph中的节点参数时,指定跟踪行为的方法。

默认行为包括:

  1. 遍历集合类型(例如元组、列表、字典)并递归地对元素调用create_args

  2. 给定一个代理对象,返回对底层IR的引用Node

  3. 给定一个非代理张量对象,为各种情况发出IR:

    • For a Parameter, emit a get_attr node referring to that Parameter

    • For a non-Parameter Tensor, store the Tensor away in a special attribute referring to that attribute.

此方法可以被重写以支持更多类型。

Parameters:

a (任意) – 要作为 ArgumentGraph 中发出的值。

Returns:

该值 a 转换为适当的 Argument

Return type:

可选[联合[元组[任意, …], 列表[任意], 字典[str, 任意], 切片, 节点, str, int, float, bool, 复数, 数据类型, 张量, 设备, 内存格式, 布局]]

注意

此 API 的向后兼容性得到保证。

create_args_for_root(root_fn, is_module, concrete_args=None)[source]

创建 placeholder 个节点,对应于 root 模块的签名。此方法检查根签名并相应地生成这些节点,同时支持 *args**kwargs

警告

此API为实验性功能,且向后兼容。

create_node(kind, target, args, kwargs, name=None, type_expr=None)

根据目标、参数、关键字参数和名称插入一个图节点。

此方法可以被重写以进行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止将就地操作记录下来。

注意

此 API 的向后兼容性得到保证。

Return type:

节点

create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)

从给定的参数创建一个节点,然后返回包装在代理对象中的节点。

如果 kind = 'placeholder',那么我们正在创建一个表示函数参数的节点。如果我们需要编码默认参数,我们使用args元组。args对于placeholder节点来说是空的。

注意

此 API 的向后兼容性得到保证。

getattr(attr, attr_val, parameter_proxy_cache)[source]

当我们在调用 Tracer 的 getattr 方法时,指定此方法的行为 在调用 nn.Module 实例时。

默认情况下,行为是为属性返回一个代理值。它还把代理值存储在parameter_proxy_cache中,以便将来调用时重用该代理而不是创建一个新的代理。

此方法可以被重写,例如,在查询参数时不返回代理。

Parameters:
  • 属性 (字符串) – 正在查询的属性名称

  • attr_val (任何类型) – 属性的值

  • parametr_proxy_cache (Dict[str, Any]) – 一个缓存,用于存储属性名称到代理的映射

Returns:

getattr 调用的返回值。

警告

此API为实验性功能,且向后兼容。

is_leaf_module(m, module_qualified_name)[source]

指定给定 nn.Module 是否为“叶子”模块的方法。

叶模块是出现在 IR 中的原子单元,由 call_module 调用引用。默认情况下, PyTorch 标准库命名空间(torch.nn)中的模块 是叶模块。所有其他模块都会被追踪并通过 记录其组成操作,除非通过此参数另行指定。

Parameters:
  • m (模块) – 正在查询的模块

  • 模块限定名称 (字符串) – 该模块的根路径。例如, 如果你有一个模块层次结构,其中子模块 foo 包含 子模块 bar,而该子模块又包含子模块 baz,那么该模块将 在此处以限定名称 foo.bar.baz 显示。

Return type:

布尔

注意

此 API 的向后兼容性得到保证。

iter(obj)
Called when a proxy object is being iterated over, such as

当在控制流中使用时。通常我们不知道该做什么,因为我们不知道代理的值,但是自定义跟踪器可以使用 create_node 将更多信息附加到图节点上,并可以选择返回一个迭代器。

注意

此 API 的向后兼容性得到保证。

Return type:

迭代器

keys(obj)
Called when a proxy object is has the keys() method called.

这是当在代理上调用**时发生的情况。这应该返回一个迭代器,**应该在你的自定义追踪器中工作。

注意

此 API 的向后兼容性得到保证。

Return type:

任何

path_of_module(mod)[source]

辅助方法,用于在mod的模块层次结构中找到合格名称 root。例如,如果root有一个名为foo的子模块, 该子模块又有一个名为bar的子模块,将bar传递给此函数将返回字符串“foo.bar”。

Parameters:

模块 (字符串) – 要检索其限定名称的Module

Return type:

字符串

注意

此 API 的向后兼容性得到保证。

proxy(node)

注意

此 API 的向后兼容性得到保证。

Return type:

代理

to_bool(obj)
Called when a proxy object is being converted to a boolean, such as

当在控制流中使用时。通常我们不知道该做什么,因为我们不知道代理的值,但是自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。

注意

此 API 的向后兼容性得到保证。

Return type:

布尔

trace(root, concrete_args=None)[source]

跟踪 root 并返回相应的 FX Graph 表示。 root 可以是 nn.Module 实例或 Python 可调用对象。

请注意,在此调用之后,self.root 可能与这里传递的 root 不同。例如,当将一个自由函数传递给 trace() 时,我们将创建一个 nn.Module 实例作为根节点,并添加嵌入的常量。

Parameters:
  • (联合[模块, 可调用对象]) – 既可以是 Module,也可以是要追踪的函数。此参数向后兼容性得到保证。

  • concrete_args (可选[字典[字符串, 任意类型]]) – 应被视为具体参数而不是代理的参数。此参数是实验性的,且其向后兼容性不保证。

Returns:

A Graph 表示传入的 root 的语义。

Return type:

注意

此 API 的向后兼容性得到保证。

class torch.fx.Proxy(node, tracer=None)[source]

Proxy 个对象是 Node 个包装器,在符号跟踪过程中流经程序并记录它们所接触的所有操作(torch 函数调用、方法调用、操作符)到不断增长的 FX 图中。

如果你正在进行图变换,你可以围绕原始的Proxy方法编写自己的方法, 这样你可以使用重载的操作符向Graph中添加额外的内容。

Proxy 对象无法被迭代。换句话说,如果在循环中使用 Proxy 或将其作为 *args/**kwargs 函数参数,符号跟踪器将会抛出错误。

有两种主要的方法可以解决这个问题: 1. 将无法追踪的逻辑提取到顶级函数中,并对其使用fx.wrap。 2. 如果控制流是静态的(即循环次数基于某些超参数),则可以将代码保留在其原始位置,并重构为类似以下内容:

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

有关代理内部结构的更详细描述,请参阅torch/fx/OVERVIEW.md中的“代理”部分

注意

此 API 的向后兼容性得到保证。

class torch.fx.Interpreter(module, garbage_collect_values=True)[source]

解释器逐节点执行FX图。这种模式对于许多事情都很有用,包括编写代码转换以及分析过程。

Interpreter 类中的方法可以被重写以自定义执行行为。可重写方法的调用层次结构图:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们想要将所有 torch.neg 实例与 torch.sigmoid 互换,反之亦然(包括它们的 Tensor 方法等效项)。我们可以像这样继承 Interpreter:

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_allclose(result, torch.neg(input).sigmoid())
Parameters:
  • 模块 (图模块) – 要执行的模块

  • garbage_collect_values (布尔值) – 是否在模块执行过程中在其最后一次使用后删除值。这确保了执行期间的最佳内存使用。可以禁用此功能,例如,通过查看Interpreter.env属性来检查执行中的所有中间值。

注意

此 API 的向后兼容性得到保证。

call_function(target, args, kwargs)[source]

执行一个call_function节点并返回结果。

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Return type:

任何

Return

任意:函数调用返回的值

注意

此 API 的向后兼容性得到保证。

call_method(target, args, kwargs)[source]

执行一个call_method节点并返回结果。

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Return type:

任何

Return

任意:方法调用返回的值

注意

此 API 的向后兼容性得到保证。

call_module(target, args, kwargs)[source]

执行一个call_module节点并返回结果。

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Return type:

任何

Return

任意:模块调用返回的值

注意

此 API 的向后兼容性得到保证。

fetch_args_kwargs_from_env(n)[source]

从当前执行环境中获取节点nargskwargs的具体值。

Parameters:

n (节点) – 应获取 argskwargs 的节点。

Returns:

argskwargs 与具体的 n 值。

Return type:

元组[元组, 字典]

注意

此 API 的向后兼容性得到保证。

fetch_attr(target)[source]

Module层次结构中的self.module获取一个属性。

Parameters:

目标 (字符串) – 要获取的属性的完全限定名称

Returns:

属性的值。

Return type:

任何

注意

此 API 的向后兼容性得到保证。

get_attr(target, args, kwargs)[source]

执行一个 get_attr 节点。将从 Module 层次结构的 self.module 中检索属性值。

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Returns:

检索到的属性的值

Return type:

任何

注意

此 API 的向后兼容性得到保证。

map_nodes_to_values(args, n)[source]

递归地向下遍历 args 并在当前执行环境中查找每个 Node 的具体值。

Parameters:
  • 参数 (参数) – 用于查找具体值的数据结构

  • n (节点) – args 所属的节点。这仅用于错误报告。

Return type:

可选[联合[元组[任意, …], 列表[任意], 字典[str, 任意], 切片, 节点, str, int, float, bool, 复数, 数据类型, 张量, 设备, 内存格式, 布局]]

注意

此 API 的向后兼容性得到保证。

output(target, args, kwargs)[source]

执行一个output节点。这实际上只是检索由output节点引用的值并返回它。

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Returns:

输出节点所引用的返回值

Return type:

任何

注意

此 API 的向后兼容性得到保证。

placeholder(target, args, kwargs)[source]

执行一个 placeholder 节点。请注意这是有状态的: Interpreter 保持传递给 run 的参数的内部迭代器,并且此方法返回该迭代器上的 next()。

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Returns:

检索到的参数值。

Return type:

任何

注意

此 API 的向后兼容性得到保证。

run(*args, initial_env=None, enable_io_processing=True)[source]

通过解释运行 module 并返回结果。

Parameters:
  • *args – 传递给 Module 的参数,按位置顺序排列

  • initial_env (可选[Dict[节点, 任何类型]]) – 可选的执行起始环境。 这是一个将 Node 映射到任何值的字典。例如,可以用来预先填充某些 Nodes 的结果,以便在解释器中仅进行部分评估。

  • enable_io_processing (bool) – 如果为真,我们首先使用图的 process_inputs 和 process_outputs 函数处理输入和输出,然后再使用它们。

Returns:

执行 Module 后返回的值

Return type:

任何

注意

此 API 的向后兼容性得到保证。

run_node(n)[source]

运行特定节点 n 并返回结果。 调用占位符、get_attr、call_function、 call_method、call_module 或 output,具体取决于 node.op

Parameters:

n (节点) – 要执行的节点

Returns:

执行 n 的结果

Return type:

任何

注意

此 API 的向后兼容性得到保证。

class torch.fx.Transformer(module)[source]

Transformer 是一种特殊的解释器,它生成一个新的 Module。它暴露了一个 transform() 方法,该方法返回转换后的 ModuleTransformer 运行时不需要参数,而 Interpreter 需要。Transformer 完全以符号方式工作。

示例

假设我们想要将所有 torch.neg 的实例与 torch.sigmoid 互换(包括它们的 Tensor 方法等价形式)。我们可以像这样子类化 Transformer

class NegSigmSwapXformer(Transformer):
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_allclose(transformed(input), torch.neg(input).sigmoid())
Parameters:

模块 (图模块) – 要转换的 Module

注意

此 API 的向后兼容性得到保证。

call_function(target, args, kwargs)[source]

注意

此 API 的向后兼容性得到保证。

Return type:

任何

call_module(target, args, kwargs)[source]

注意

此 API 的向后兼容性得到保证。

Return type:

任何

get_attr(target, args, kwargs)[source]

执行一个 get_attr 节点。在 Transformer 中,这被重写以将一个新的 get_attr 节点插入输出图中。

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Return type:

代理

注意

此 API 的向后兼容性得到保证。

placeholder(target, args, kwargs)[source]

执行一个 placeholder 节点。在 Transformer 中,这被重写以在输出图中插入一个新的 placeholder

Parameters:
  • 目标 (目标) – 此节点的调用目标。详见 节点 的语义说明

  • args (元组) – 此调用的位置参数元组

  • kwargs (字典) – 此调用的关键字参数的字典

Return type:

代理

注意

此 API 的向后兼容性得到保证。

transform()[source]

转换 self.module 并返回转换后的 GraphModule

注意

此 API 的向后兼容性得到保证。

Return type:

GraphModule

torch.fx.replace_pattern(gm, pattern, replacement)[source]

匹配所有可能的非重叠运算符集及其数据依赖(pattern)在GraphModule的图中(gm),然后将每个匹配的子图替换为另一个子图(replacement)。

Parameters:
Returns:

一个包含 Match 个对象的列表,表示在原始图中与 pattern 匹配的位置。如果没有匹配,则列表为空。Match 定义为:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

Return type:

List[Match]

Examples:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码将首先匹配 patternforward 方法中的 traced_module。模式匹配是基于使用-定义关系,而不是节点名称。例如,如果你在 p = torch.cat([a, b]) 中有 pattern,你可以匹配原始 forward 函数中的 m = torch.cat([a, b]),尽管变量名称不同(p vs m)。

The return 语句在 pattern 中仅根据其值进行匹配;它可能与较大图中的 return 语句匹配,也可能不匹配。换句话说,模式不必扩展到较大图的末尾。

当匹配模式时,它将从较大的函数中移除并替换为replacement。如果在较大的函数中有多个pattern的匹配项,则每个非重叠匹配项都将被替换。在匹配项重叠的情况下,将在重叠匹配项集合中找到的第一个匹配项将被替换。(这里的“第一个”是指根据节点使用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的内容。)

值得注意的是,pattern Callable 的参数必须在 Callable 本身中使用, 并且 replacement Callable 的参数必须匹配模式。第一个规则解释了为什么在上面的代码块中, forward 函数有参数 x, w1, w2,而 pattern 函数只有参数 w1, w2pattern 不使用 x,因此不应将其指定为参数。 作为第二个规则的一个例子,考虑替换

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

def replacement(x, y):
    return torch.relu(x)

在这种情况下,replacement 需要与 pattern 相同数量的参数 (xy),尽管参数 yreplacement 中未被使用。

调用 subgraph_rewriter.replace_pattern 后,生成的 Python 代码如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

此 API 的向后兼容性得到保证。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源