目录

torch.fx

概述

FX 是开发人员用于转换实例的工具包。FX 由三个主要组件组成:符号跟踪器、中间表示Python 代码生成。一个 这些组件的实际演示:nn.Module

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号跟踪器执行 Python 的 “符号执行” 法典。它通过代码提供称为 Proxies 的假值。操作 在这些代理上被记录下来。有关符号跟踪的更多信息 可以在 文档中找到。

中间表示是操作的容器 ,这些 ID 的 S 是在符号跟踪期间记录的。它由一个 表示函数输入、调用点(函数、方法、 或实例)和返回值。更多信息 有关 IR 的信息,请参阅 的文档。这 IR 是应用转换的格式。

Python 代码生成使 FX 成为 Python 到 Python(或 Module-to-Module) 转换工具包。对于每个 Graph IR,我们可以 创建与 Graph 的语义匹配的有效 Python 代码。这 功能封装在 中,它是一个实例,其中包含从 Graph 生成的 a 和方法。forward

总而言之,这个组件管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成) 构成了 FX 的 Python 到 Python 转换管道。在 此外,这些组件可以单独使用。例如 符号跟踪可以单独用于捕获某种形式的 用于分析 (而不是转换) 目的的代码。法典 generation 可用于以编程方式生成模型,对于 example 来自配置文件。FX 有很多用途!

可以在 examples 存储库中找到几个示例转换。

编写转换

什么是 FX 转换?从本质上讲,它是一个看起来像这样的函数。

import torch
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

您的转换将接收 ,从中获取 a,进行一些修改,并返回一个新的 .你应该想想你的 FX transform 返回与常规相同的 - 您可以将其传递给另一个 FX 转换,你可以将其传递给 TorchScript,或者你可以 运行它。确保 FX 转换的输入和输出是 a 将允许可组合性。

注意

也可以修改 existing 而不是 创建一个新的,如下所示:

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

请注意,您必须调用以使 上生成的方法与修改后的 .forward()GraphModule

鉴于您传入的 a 已被追踪到 a 中,现在您可以采用两种主要方法来构建新的 .

图形快速入门

文档中找到对图语义的完整处理,但我们将在这里介绍基础知识。A 是 表示 .这 这需要的信息是:

  • 该方法的输入是什么?

  • 方法内部运行哪些操作?

  • 该方法的输出(即 return)值是多少?

所有这三个概念都用实例表示。 让我们通过一个简短的例子来了解我们的意思:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

这里我们定义了一个用于演示目的的模块,实例化它, 符号式跟踪它,然后调用该方法以打印 出一个表格,显示 this 的节点:MyModule

操作码

名字

目标

参数

夸格斯

占 位 符

x

x

()

{}

get_attr

linear_weight

线性权重

()

{}

call_function

add_1

<内置函数 add>

(x, linear_weight)

{}

call_module

linear_1

线性

(add_1,)

{}

call_method

relu_1

RELU

(linear_1,)

{}

call_function

sum_1

<内置方法 sum ...>

(relu_1,)

{'dim': -1}

call_function

topk_1

<内置方法 topk ...>

(sum_1,3)

{}

输出

输出

输出

(topk_1,)

{}

我们可以使用这些信息来回答我们上面提出的问题。

  • 该方法的输入是什么?在 FX 中,指定方法输入 通过特殊节点。在这种情况下,我们有一个 a 为 的节点,这意味着我们有 名为 x 的单个(非 self)参数。placeholderplaceholdertargetx

  • 方法中有哪些操作?、 、 、 和 节点 表示方法中的操作。完整的处理 所有这些的语义都可以在文档中找到get_attrcall_functioncall_modulecall_method

  • 该方法的返回值是多少?a 中的返回值由特殊节点指定。output

鉴于我们现在了解了代码如何表示的基础知识 FX 的 API 中,我们现在可以探索如何编辑 .

图形操作

Direct Graph Manipulation

构建新方法的一种方法是直接操纵旧的 一。为了帮助解决这个问题,我们可以简单地从 symbolic 获取 跟踪并对其进行修改。例如,假设我们希望将 calls 替换为 calls。

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)

我们还可以做更多涉及的重写,例如 删除或追加节点。为了帮助进行这些转换, FX 具有用于转换图形的实用函数,这些函数可以 在文档中找到。一 使用这些 API 追加调用的示例 可以在下面找到。torch.relu()

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

对于仅包含替换的简单转换,您还可以 使用 Subgraph Rewriter。

使用 replace_pattern() 重写子图

FX 还在直接图形操作的基础上提供了另一个级别的自动化。 API 本质上是一个用于编辑的 “查找/替换” 工具。它允许您指定 and 函数 它将跟踪这些函数,找到操作组的实例 ,然后将这些实例替换为图表的副本。这有助于大大自动化繁琐的图形操作代码,从而 随着转换变得越来越复杂,它会变得笨拙。patternreplacementpatternreplacement

代理/回溯

另一种操作 s 的方法是重用符号跟踪中使用的机制。例如,让我们 想象一下,我们想编写一个分解的 PyTorch 函数转换为较小的操作。它会将每个调用转换为 .一种可能性是 执行必要的图形重写以插入比较,然后 乘法,然后清理原始 。但是,我们可以通过使用对象将操作自动记录到 .F.relu(x)(x > 0) * xF.reluF.relu

要使用此方法,我们编写要作为常规插入的操作 PyTorch 代码,并使用对象作为参数调用该代码。 这些对象将捕获执行的操作 并将它们附加到 .

# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently,this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式的图形操作外,使用 s 还允许您将重写规则指定为本机 Python 代码。 对于需要大量重写规则的转换 (例如 vmap 或 grad)中,这通常可以提高可读性,并且 规则的可维护性。请注意,在调用 传递指向底层变量 Graph 的跟踪器。这样做 if if graph 中的操作是 n-ary (例如 add 是二元运算符) 对 的调用不会创建图形的多个实例 tracer 的 Tracer,这可能会导致意外的运行时错误。我们推荐这种方法 的 using 特别是当底层运算符不能 安全地假设为 1 元。

使用 s 进行操作的工作示例 可以在这里找到。

解释器模式

FX 中一个有用的代码组织模式是遍历所有 s 并执行它们。这可用于多种用途,包括 流经代码的图形或转换的值的运行时分析 通过使用 s.例如,假设我们要运行 a 并记录 shape 和 dtype 属性。这可能看起来像:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

如您所见,FX 的完整解释器并不复杂 但它可能非常有用。为了简化此模式的使用,我们提供了 类,它包含上述逻辑 以解释器执行的某些方面可以 通过 method overrides 覆盖。

除了执行操作之外,我们还可以通过解释器提供值来生成新的 Graph。 同样,我们提供了 encompass 类 这个模式。的行为类似于 ,但不是调用 从 Module 中获取具体的输出值,您将调用该方法以返回一个受任何转换规则约束的 new 您作为 overridden methods 安装。run

解释器模式的示例

调试

介绍

通常在创作转换的过程中,我们的代码不会完全正确。 在这种情况下,我们可能需要做一些调试。关键是工作 backwards:首先,检查调用生成的 Module 的结果来证明或 反驳正确性。然后,检查并调试生成的代码。然后,调试 导致生成代码的转换过程。

如果您不熟悉 Debuggers,请参阅辅助部分 Available Debuggers

转换创作中的常见陷阱

  • 非确定性迭代顺序。在 Python 中,数据类型为 无序。用于包含对象的集合,如 s, 例如,可能会导致意外的不确定性。一个例子是迭代 在一组 S 上,将它们插入到 .由于数据类型是无序的,因此输出中操作的顺序 程序将是不确定的,并且可以在程序调用之间更改。 推荐的替代方法是使用数据类型,该数据类型从 Python 3.7(和 cPython 3.6)开始按插入顺序排列。A 可以等效使用 通过将要删除重复数据的值存储在 .setsetsetNodeNodeGraphsetdictdictdict

检查模块的正确性

因为大多数深度学习模块的输出都是由浮点 Point 实例, 检查 两个的结果就没有那么简单了 就像执行简单的相等性检查一样。为了实现这一点,让我们使用 例:

import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在这里,我们尝试检查两个深度学习的值是否相等 模型。然而,这并不好—— 由于该运算符返回张量的问题,因此定义了两者 不是 bool,还因为浮点值的比较 应使用误差幅度(或 epsilon)来说明 浮点运算的非交换性(有关更多信息,请参阅此处 详细信息)。我们可以改用,这将得到 us 的近似比较,考虑了相对 和 绝对容差阈值:==

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们工具箱中第一个检查转换后的模块是否为 与参考实现相比,其行为符合我们的预期。

调试生成的代码

由于 FX 在 s 上生成函数,因此使用 传统的调试技术,如 Statements 或 IS 没有那么简单。幸运的是,我们可以使用几种技术 调试生成的代码。forward()printpdb

pdb

调用 以单步执行正在运行的程序。尽管 表示 is 不在任何源文件中,我们仍然可以单步执行 到其中。pdbpdb

import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用函数 fromto_folderGraphModule

是一个方法,它允许 u 将生成的 FX 代码转储到一个文件夹中。尽管将 forward 传递到代码中通常就足够了,如 Print the Generated Code, 使用 检查模块和参数可能更容易。GraphModuleto_folder

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

运行上述示例后,我们可以查看其中的代码并根据需要对其进行修改(例如添加语句或使用 )以调试生成的代码。foo/module.pyprintpdb

调试转换

现在,我们已经确定转换正在创建不正确的 code,那么是时候调试转换本身了。首先,我们将检查 文档中的 Limitations of Symbolic Tracing 部分。 一旦我们验证跟踪是否按预期工作,目标 变成了弄清楚我们的转型过程中出了什么问题。编写转换中可能有一个快速的答案,但如果没有,有几种方法可以 检查我们的 traced 模块:GraphModule

# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上面的实用函数,我们可以比较我们跟踪的 Module 之前和之后我们应用了转换。有时,一个 简单的视觉比较就足以追踪 bug。如果它仍然是 不清楚出了什么问题,像这样的调试器可能是一个不错的 下一步。pdb

从上面的示例开始,请考虑以下代码:

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

使用上面的示例,假设 to 调用向我们显示转换中存在错误。我们想找到 使用 Debugger 时会出现什么问题。我们开始一个会话。我们可以看到 转换期间发生的情况,方法是 中断 ,然后按 以“单步执行”调用 自。print(traced)pdbtransform_graph(traced)stransform_graph(traced)

我们也可能通过编辑方法打印来获得好运 Graph 中 Nodes 的不同属性。(例如,我们可能 想要查看 Node 的 和 。print_tabularinput_nodesusers

可用的调试器

最常见的 Python 调试器是 pdb。您可以开始 在命令行中键入 “debug mode” 的程序,其中 是要调试的文件的名称。之后,您可以使用 debugger 命令逐步浏览正在运行的程序。通常将 breakpoint () ,然后调用 运行程序直到该点。这样,您就不必踏步 通过每行执行 (使用 或 ) 到达部件 。或者,您可以在要换行的行之前写下。 如果添加 ,您的程序将自动启动 在 DEBUG 模式下运行它。(换句话说,您只需在命令行中键入内容,而不是 。在 中运行文件后 debug 模式下,您可以单步调试代码并检查程序的 internal 状态。有很多优秀的 在线教程,包括 RealPython 的 “Python Debugging With Pdb”。pdbpython -m pdb FILENAME.pyFILENAMEpdbb LINE-NUMBERpdbcsnimport pdb; pdb.set_trace()pdb.set_trace()python FILENAME.pypython -m pdb FILENAME.pypdb

PyCharm 或 VSCode 等 IDE 通常内置了调试器。在您的 IDE 中,您可以选择 a) 通过拉起终端来使用 窗口中(例如 VSCode 中的 View → Terminal),或 b) 使用 内置调试器(通常是 周围的图形包装器)。pdbpdb

符号跟踪的限制

FX 使用符号跟踪系统(也称为符号 执行) 以可转换/可分析的形式捕获程序的语义。 系统正在进行跟踪,因为它执行程序(实际上是 a 或 function)来记录操作。它是象征性的,因为在此期间流经程序的数据 执行不是真实数据,而是品种(用 FX 的话说)。

尽管符号跟踪适用于大多数神经网络代码,但它有一些 局限性。

动态控制流

符号跟踪的主要限制是它目前不支持动态控制流。也就是说,其中 condition 可能取决于程序的 input 值。if

例如,让我们检查一下以下程序:

def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

语句的条件依赖于 的值 , 它依赖于 的值 ,一个函数输入。由于可以更改(即,如果您将新的输入张量传递给跟踪的 功能),这是动态控制流。回溯回溯 通过您的代码向您展示这种情况发生的位置。ifx.sum()xx

静态控制流

另一方面,支持所谓的静态控制流。静态的 控制流是其值无法更改的循环或语句 跨调用。通常,在 PyTorch 程序中,此控制流 对于代码做出有关模型架构的决策时出现 hyper-parameters 的举个具体的例子:if

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if 语句不依赖于任何 function inputs,因此它是静态的。 可以考虑 设置为超参数,并且该参数具有不同值的不同实例的跟踪具有不同的 法典。这是符号跟踪支持的有效模式。if self.do_activationdo_activationMyModule

动态控制流的许多实例在语义上是静态控制 流。这些实例可以通过以下方式支持符号跟踪 删除对输入值的数据依赖性,例如通过移动 values 绑定到 attributes 或将具体值绑定到参数 在符号跟踪期间:Module

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

在真正动态控制流的情况下,程序的各个部分 可以跟踪为对 Method (请参阅 使用 Tracer 类自定义跟踪) 或函数 (请参阅 )的调用,而不是通过它们进行跟踪。

非函数torch

FX 用作拦截 调用(请参阅技术 overview 以了解有关此内容的更多信息)。一些函数,例如内置的 Python 函数或模块中的函数未被 覆盖,但我们仍希望在 符号跟踪。例如:__torch_function__math__torch_function__

import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

该错误告诉我们内置函数不受支持。 我们可以将这样的函数记录在跟踪中,作为 使用 API 的直接调用:len

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用类自定义跟踪Tracer

类是 的实现。跟踪的行为可以是 通过子类化 Tracer 进行自定义,如下所示:symbolic_trace

class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶模块

叶模块是在符号跟踪中显示为调用的模块 而不是被追踪。默认的叶模块集是 一组标准模块实例。例如:torch.nn

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

叶模块集可以通过覆盖 来自定义。

杂项

  • 张量构造函数(例如 、 ) 目前无法追踪。torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor

    • 可以使用确定性构造函数 (, ) 它们生成的值将作为 不断。只有当这些 constructors 是指动态输入大小。在这种情况下,或者可能是一个可行的替代品。zerosonesones_likezeros_like

    • 非确定性构造函数 (, ) 将具有 跟踪中嵌入的单个随机值。这可能不是 预期行为。一种解决方法是包装一个函数并改为调用它。randrandntorch.randntorch.fx.wrap

    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 此行为可能会在将来的版本中修复。

  • 类型注释

    • 支持 Python 3 样式的类型注释(例如 ) 并将通过符号跟踪进行保留。func(x : torch.Tensor, y : int) -> torch.Tensor

    • Python 2 样式的注释类型注释目前不可用 支持。# type: (torch.Tensor, int) -> torch.Tensor

    • 函数中本地名称的注释当前不是 支持。

  • flag 和子模块的陷阱training

    • 当使用像 这样的函数时,训练参数通常以 的形式传入。在 FX 跟踪期间,这可能会作为常量值烘焙。torch.nn.functional.dropoutself.training

    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
      dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
      return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    """
    AssertionError: Tensor-likes are not close!
    
    Mismatched elements: 15 / 15 (100.0%)
    Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
    """
    
    • 但是,当使用标准子模块时,训练标志被封装,并且由于保留了对象模型,可以更改。nn.Dropout()nn.Module

    class DropoutRepro2(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
      def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
      drop = self.drop(x);  x = None
      return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    
  • 由于这种差异,请考虑将与标志动态交互的模块标记为叶模块。training

API 参考

torch.fx 中。symbolic_tracerootconcrete_args=None[来源]

符号跟踪 API

给定一个 or 函数实例,此函数将返回一个通过记录在跟踪时看到的操作构建的 。nn.ModulerootGraphModuleroot

concrete_args允许您部分专用化您的函数,无论是删除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于存在控制,FX 通常无法跟踪此情况 流。但是,我们可以使用 concrete_args 来专门化 b 的值,以便通过以下方式进行跟踪:

f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False)  == 6

请注意,尽管你仍然可以传入不同的 b 值,但它们将被忽略。

我们还可以使用 concrete_args 来消除 我们的函数。这将使用 pytree 来展平您的输入。为避免 过度专业化,传入 FX。PH 值 专业。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
参数
  • rootUnion[torch.nn.ModuleCallable]) – 要跟踪和转换的模块或函数 转换为 Graph 表示形式。

  • concrete_argsOptional[Dict[strany]]) – 要部分专用的输入

返回

根据 中记录的操作创建的模块 。root

返回类型

GraphModule

注意

保证此 API 的向后兼容性。

torch.fx 中。wrapfn_or_name[来源]

可以在模块级范围内调用此函数,以将fn_or_name注册为“叶函数”。 “叶函数”将保留为 FX 跟踪中的 CallFunction 节点,而不是 追踪方式:

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

此函数也可以等效地用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

包装函数可以被认为是“叶函数”,类似于 “叶模块”,即,它们是在 FX 跟踪中作为调用保留的函数 而不是被追踪。

参数

fn_or_nameUnion[strCallable]) – 要插入到 graph 时调用

注意

保证此 API 的向后兼容性。

torch.fx 中。GraphModule*args**kwargs[来源]

GraphModule 是一个 nn.从 fx.图。Graphmodule 具有 attribute 以及生成的 和 attribute 从那个 .graphcodeforwardgraph

警告

when 重新分配,并将自动 再生。但是,如果您编辑 的内容而不重新分配 属性本身,您必须调用该属性来更新生成的 法典。graphcodeforwardgraphgraphrecompile()

注意

保证此 API 的向后兼容性。

__init__class_name='GraphModule'[来源]

构造一个 GraphModule。

参数
  • rootUnion[torch.nn.ModuleDict[strAny]) —— 可以是 nn.Module 实例或 Dict 映射字符串到任何属性类型。 在是 Module 的情况下,对基于 Module 的对象的任何引用(通过限定的 name) 将从相应的位置复制 的 Module 层次结构中放入 GraphModule 的 module 层次结构中。 如果是 dict,则在 Node 中找到的限定名称将是 直接在 dict 的 keys 中查找。Dict 映射到的对象将被复制 移动到 GraphModule 的模块层次结构中的适当位置。rootroottargetrootroottarget

  • graphGraph) —— 包含此 GraphModule 应该用于代码生成的节点graph

  • class_namestr) – 表示此 GraphModule 的名称,用于调试目的。如果未设置,则所有 错误消息将报告为 源自 。设置此项可能会有所帮助 更改为 的原始名称或在转换上下文中有意义的名称。nameGraphModuleroot

注意

保证此 API 的向后兼容性。

add_submodule目标m[来源]

将给定的子模块添加到 。self

这将安装尚不存在的空模块 的子路径。target

参数
  • targetstr) – 新子模块的完全限定字符串名称 (请参阅 中的示例了解如何 指定完全限定的字符串。nn.Module.get_submodule

  • mModule) - 子模块本身;我们想要的实际对象 install 在当前 Module 中

返回

是否可以插入 submodule。为

this 方法返回 True,链中的每个对象 表示必须 a) 尚不存在, 或 b) 引用 an(不是参数或 other 属性)targetnn.Module

返回类型

布尔

注意

保证此 API 的向后兼容性。

属性代码 str

返回从底层 this 生成的 Python 代码。GraphGraphModule

delete_all_unused_submodules)[来源]

从 中删除所有未使用的子模块 。self

如果满足以下任何一项条件,则认为 Module 是 “used” 的 真: 1. 它有被使用的子项 2. 它的 forward 是通过节点直接调用的 3. 它具有从节点使用的非 Module 属性call_moduleget_attr

可以调用此方法来清理 without 手动调用每个未使用的 submodule。nn.Moduledelete_submodule

注意

保证此 API 的向后兼容性。

delete_submodule目标[来源]

从 中删除给定的子模块 。self

如果 无效,则不会删除该模块 目标。target

参数

targetstr) – 新子模块的完全限定字符串名称 (请参阅 中的示例了解如何 指定完全限定的字符串。nn.Module.get_submodule

返回

目标字符串是否引用了

sub模块。返回值 表示 不是对 一个子模块。Falsetarget

返回类型

布尔

注意

保证此 API 的向后兼容性。

property graph 图形

返回底层的 thisGraphGraphModule

print_readableprint_output=include_stride=include_device=彩色=[来源]

返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码

警告

此 API 是实验性的,向后兼容。

recompile)[来源]

从其 attribute 中重新编译此 GraphModule。这应该是 在编辑包含的 的代码将过时。graphgraphGraphModule

注意

保证此 API 的向后兼容性。

返回类型

Python代码

to_folder文件夹module_name='FxModule'[来源]
将模块转储到 with,以便它可以foldermodule_name

导入方式from <folder> import <module_name>

参数:

文件夹 (Union[str, os.PathLike]):要将代码写出到的文件夹

module_name (str):用于 while 的顶级名称Module

写出代码

警告

此 API 是实验性的,向后兼容。

torch.fx 中。Graphowning_module=tracer_cls=tracer_extras=[来源]

Graph是 FX 中间表示中使用的主要数据结构。 它由一系列 s 组成,每个 s 代表调用站点(或其他 syntactic constructs) 的 Alpha Constructs)。这些列表加在一起构成了一个 有效的 Python 函数。NodeNode

例如,以下代码

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

将生成以下 Graph:

print(gm.graph)
graph(x):
    %linear_weight : [num_users=1] = self.linear.weight
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

有关 中表示的操作的语义,请参阅 Graph

注意

保证此 API 的向后兼容性。

__init__owning_module=tracer_cls=tracer_extras=[来源]

构造一个空 Graph。

注意

保证此 API 的向后兼容性。

call_functionthe_functionargs=kwargs=type_expr=[来源]

将 插入 中。一个节点 表示对 Python 可调用对象的调用,由 .call_functionNodeGraphcall_functionthe_function

参数
  • the_functionCallable[...Any]) – 要调用的函数。可以是任何 PyTorch operator、Python 函数或 or 命名空间的成员。builtinsoperator

  • argsOptional[Tuple[Argument...]]) – 要传递的位置参数 添加到被调用的函数中。

  • kwargsOptional[Dict[strArgument]]) – 要传递的关键字参数 到被调用的函数

  • type_exprOptional[Any]) – 一个可选的类型注释,表示 Python 类型,则此节点的输出将具有。

返回

新创建并插入的节点。call_function

返回类型

节点

注意

相同的插入点和类型表达式规则适用于此方法 作为 .

注意

保证此 API 的向后兼容性。

call_methodmethod_nameargs=kwargs=type_expr=[来源]

将 插入 中。一个节点 表示对 的第 0 个元素上的给定方法的调用 。call_methodNodeGraphcall_methodargs

参数
  • method_namestr) – 要应用于 self 参数的方法的名称。 例如,如果 args[0] 是表示 , 然后调用 该 ,将 传递给 。NodeTensorrelu()Tensorrelumethod_name

  • argsOptional[Tuple[Argument...]]) – 要传递的位置参数 添加到被调用的方法中。请注意,这应该包括一个参数。self

  • kwargsOptional[Dict[strArgument]]) – 要传递的关键字参数 添加到被调用的方法

  • type_exprOptional[Any]) – 一个可选的类型注释,表示 Python 类型,则此节点的输出将具有。

返回

新创建并插入的节点。call_method

返回类型

节点

注意

相同的插入点和类型表达式规则适用于此方法 作为 .

注意

保证此 API 的向后兼容性。

call_modulemodule_nameargs=kwargs=type_expr=[来源]

将 插入 中。一个节点 表示对层次结构中 a 的 forward() 函数的调用。call_moduleNodeGraphcall_moduleModuleModule

参数
  • module_namestr) – 要调用的层次结构中 的限定名称。例如,如果被跟踪的 子模块名为 ,该子模块具有一个名为 的子模块,该 限定名称应传递给 调用该模块。ModuleModuleModulefoobarfoo.barmodule_name

  • argsOptional[Tuple[Argument...]]) – 要传递的位置参数 添加到被调用的方法中。请注意,这不应包含参数。self

  • kwargsOptional[Dict[strArgument]]) – 要传递的关键字参数 添加到被调用的方法

  • type_exprOptional[Any]) – 一个可选的类型注释,表示 Python 类型,则此节点的输出将具有。

返回

新创建并插入的节点。call_module

返回类型

节点

注意

相同的插入点和类型表达式规则适用于此方法 作为 .

注意

保证此 API 的向后兼容性。

create_nodeoptargetargs=kwargs=name=type_expr=没有[来源]

创建一个并将其添加到当前插入点。 请注意,当前的插入点可以通过 和 来设置。NodeGraph

参数
  • opstr) – 此 Node 的操作码。'call_function'、'call_method'、'get_attr' 之一, 'call_module'、'placeholder' 或 'output'。这些操作码的语义是 在文档字符串中描述。Graph

  • argsOptional[Tuple[Argument...]]) – 是此节点的参数元组。

  • kwargsOptional[Dict[strArgument]]) – 此节点的 kwargs

  • nameOptional[str]) – . 这将影响在 Python 生成的代码。Node

  • type_exprOptional[Any]) – 一个可选的类型注释,表示 Python 类型,则此节点的输出将具有。

返回

新创建并插入的节点。

返回类型

节点

注意

保证此 API 的向后兼容性。

eliminate_dead_codeis_impure_node=[来源]

根据每个节点的 users,以及节点是否有任何副作用。图形必须为 在调用之前进行拓扑排序。

参数
  • is_impure_nodeOptional[Callable[[Node]bool]]) – 返回

  • None节点是否不纯。如果是)–

  • to则默认行为为) –

  • Node.is_impure。使用) –

返回

图形是否因传递而更改。

返回类型

布尔

例:

在消除死代码之前,下面 a = x + 1 中的 a 没有用户 因此可以从图形中消除而不会产生影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

消除死代码后,a = x + 1 已被删除,其余的 的前进遗骸。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除有一些启发式方法可以避免删除 副作用节点(参见 Node.is_impure),但一般覆盖 很糟糕,所以你应该假设这个方法不健全 调用,除非您知道您的 FX 图表完全由 的功能性操作,或者您提供自己的自定义 检测副作用节点的功能。

注意

保证此 API 的向后兼容性。

erase_nodeto_erase[来源]

从 中擦除 。如果出现 在 .NodeGraphGraph

参数

to_erase (Node) (Node) ( (Node)) – 要从 中擦除的 .NodeGraph

注意

保证此 API 的向后兼容性。

find_nodes*optarget=Nonesort=True[来源]

允许快速查询节点

参数
  • opstr) – 操作的名称

  • targetOptional[Target]) – 节点的目标。对于call_function, 目标是必需的。对于其他操作,目标是可选的。

  • sortbool) – 是否按节点出现的顺序返回节点 在图表上。

返回

具有请求的 op 和 target 的节点的可迭代。

警告

此 API 是实验性的,向后兼容。

get_attrqualified_nametype_expr=[来源]

在 Graph 中插入节点。A 表示 从层次结构中获取属性。get_attrget_attrNodeModule

参数
  • qualified_namestr) – 要检索的属性的完全限定名称。 例如,如果跟踪的 Module 有一个名为 的子模块,该子模块有一个 子模块,该子模块具有一个名为 、 的合格 name 应作为 .foobarbazfoo.bar.bazqualified_name

  • type_exprOptional[Any]) – 一个可选的类型注释,表示 Python 类型,则此节点的输出将具有。

返回

新创建并插入的节点。get_attr

返回类型

节点

注意

相同的插入点和类型表达式规则适用于此方法 如。Graph.create_node

注意

保证此 API 的向后兼容性。

graph_copygval_mapreturn_output_node=False[来源]

将给定图形中的所有节点复制到 中。self

参数
  • gGraph) (图形) – 要从中复制节点的源图形。

  • val_mapDict[NodeNode]) – 将填充映射的字典 从 中的节点到 中的 节点。请注意,可以传递 in 中,以覆盖某些值的复制。gselfval_map

返回

其中的值现在等效于 中的 输出值 。 如果有一个节点。 否则。selfggoutputNone

返回类型

可选[Union[Tuple[Any, ...], List[Any], Dict[strAny], 切片范围NodestrintfloatboolcomplexdtypeTensor设备memory_format布局OpOverloadSymIntSymBoolSymFloat]]

注意

保证此 API 的向后兼容性。

inserting_aftern=[来源]
设置 create_node 和 Companion 方法将插入到图形中的点。

当在 'with' 语句中使用时,这将临时设置插入点和 然后在 with 语句退出时恢复它:

with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) #  set the insert point permanently

参数:

n (Optional[Node]):要插入的节点。如果为 None,则将在

整个图形的开头。

返回:

将恢复 上的插入点的资源管理器。__exit__

注意

保证此 API 的向后兼容性。

inserting_beforen=[来源]
设置 create_node 和 Companion 方法将插入到图形中的点。

当在 'with' 语句中使用时,这将临时设置插入点和 然后在 with 语句退出时恢复它:

with g.inserting_before(n):
    ... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) #  set the insert point permanently

参数:

n (Optional[Node]):要插入的节点。如果为 None,则将在

整个图形的开头。

返回:

将恢复 上的插入点的资源管理器。__exit__

注意

保证此 API 的向后兼容性。

lint)[来源]

对此 Graph 运行各种检查,以确保其格式正确。在 特定: - 检查节点是否具有正确的所有权(由此图表拥有) - 检查 节点按拓扑顺序显示 - 如果此 Graph 具有拥有的 GraphModule,则检查目标 存在于该 GraphModule 中

注意

保证此 API 的向后兼容性。

node_copynodearg_transform=<function Graph.<lambda>>[来源]

将节点从一个图形复制到另一个图形中。 需要将参数从 Node 的 Graph 到 self 的 Graph。例:arg_transform

# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
参数
  • nodeNode) – 要复制到的节点。self

  • arg_transformCallable[[node]Argument]) – 一个函数,用于将 node 中的参数转换为 等效参数。在最简单的情况下,这应该 从表映射原始节点中检索值 graph 设置为 .Nodeargskwargsselfself

返回类型

节点

注意

保证此 API 的向后兼容性。

属性节点_node_list

获取构成此 Graph 的 Node 的列表。

请注意,此列表表示形式是一个双向链表。突变 在迭代期间(例如,删除 Node、添加 Node)是安全的。Node

返回

节点的双向链表。请注意,可以调用 此列表用于切换迭代顺序。reversed

on_generate_codemake_transformer[来源]

生成 python 代码时注册 transformer 函数

参数:
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):

返回要注册的代码转换器的函数。 on_generate_code 调用此函数以获取 代码转换器。

此函数也作为其当前 注册代码转换器(如果未注册任何内容,则为 None), 以防不希望覆盖它。这对于 将 CODE 转换器链接在一起。

返回:

一个上下文管理器,当在 with 语句中使用时,它会自动 恢复以前注册的 Code Transformer。

例:

gm: fx.GraphModule = ...

# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]

# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(
    lambda _: insert_pdb
)

# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(
            current_trans(body) if current_trans
            else body
        )
    )
)

gm.recompile()
gm(*inputs)  # drops into pdb

此功能还可以用作上下文管理器,其优点是 自动恢复以前注册的代码转换器:

# ... continue from previous example

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # do more stuff with `gm`...
    gm.recompile()
    gm(*inputs)  # drops into pdb

# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you
# run next `recompile()`).

警告

此 API 是实验性的,向后兼容。

outputresulttype_expr=None[来源]

将 插入 中。一个节点表示 Python 代码中的语句。 的值应该 被退回。outputNodeGraphoutputreturnresult

参数
  • resultArgument) – 要返回的值。

  • type_exprOptional[Any]) – 一个可选的类型注释,表示 Python 类型,则此节点的输出将具有。

注意

相同的插入点和类型表达式规则适用于此方法 如。Graph.create_node

注意

保证此 API 的向后兼容性。

placeholdernametype_expr=Nonedefault_value[来源]

在 Graph 中插入节点。A 表示 函数输入。placeholderplaceholder

参数
  • namestr) – 输入值的名称。这与名称 的 position 参数传递给 this 表示的函数。Graph

  • type_exprOptional[Any]) – 一个可选的类型注释,表示 Python 类型,则此节点的输出将具有。这在一些 正确生成代码的情况(例如,当使用函数时 随后在 TorchScript 编译中)。

  • default_valueAny) – 此函数参数应采用的默认值 上。注意:要允许 None 作为默认值,请检查。Signature.empty 应作为此参数传递,以指定该参数 _not_ 具有默认值。

返回类型

节点

注意

相同的插入点和类型表达式规则适用于此方法 如。Graph.create_node

注意

保证此 API 的向后兼容性。

print_tabular[来源]

以表格形式打印图形的中间表示 格式。请注意,此 API 要求模块为 安装。tabulate

注意

保证此 API 的向后兼容性。

process_inputs*args[来源]

处理 args,以便它们可以传递到 FX 图。

警告

此 API 是实验性的,向后兼容。

process_outputsout[来源]

警告

此 API 是实验性的,向后兼容。

python_coderoot_module*verbose=Falseinclude_stride=Falseinclude_device=False彩色=False[来源]

将此代码转换为有效的 Python 代码。Graph

参数

root_modulestr) – 要查找的根模块的名称 限定名称目标。这通常是 “self”。

返回

src:表示对象的 Python 源代码 globals:src 中的全局名称字典 ->它们引用的对象。

返回类型

一个 PythonCode 对象,由两个字段组成

注意

保证此 API 的向后兼容性。

set_codegencodegen[来源]

警告

此 API 是实验性的,向后兼容。

torch.fx 中。Nodegraphnameoptargetargskwargsreturn_type=None[来源]

Node是表示 一个。在大多数情况下,节点表示各种实体的调用点, 例如运算符、方法和模块(一些例外包括 指定函数输入和输出)。每个 API 都有一个指定的函数 通过其属性。每个值的语义如下:GraphNodeopNodeop

  • placeholder表示函数输入。该属性指定此值将采用的名称。 同样是参数的名称。 包含:1) 什么都没有,或 2) 单个参数 表示函数 input 的默认参数。 是 don't-care。占位符对应于 图形打印输出中的函数参数 (例如 )。nametargetargskwargsx

  • get_attr从 Module 层次结构中检索参数。 同样,名称是 fetch 被分配给。 是参数在模块层次结构中的位置的完全限定名称。 他们不在乎nametargetargskwargs

  • call_function将 free 函数应用于某些值。 同样是要赋值的值的名称 自。 是要应用的函数。 并表示函数的参数, 遵循 Python 调用约定nametargetargskwargs

  • call_module将 Module 层次结构的方法中的 Module 应用于给定的参数。 是 如前所述。 是模块层次结构中要调用的模块的完全限定名称。 并表示要调用模块的参数,不包括 self 参数forward()nametargetargskwargs

  • call_method对值调用方法。 也一样。 是方法的字符串名称 以应用于参数。 并表示要调用模块的参数,包括 self 参数nametargetselfargskwargs

  • output在其属性中包含 traced 函数的输出。这对应于 “return” 语句 在 Graph 打印输出中。args[0]

注意

保证此 API 的向后兼容性。

属性 all_input_nodes:List[Node]

返回作为此节点输入的所有节点。这相当于 迭代 并且仅收集 是 Node 的节点。argskwargs

返回

按此顺序出现在 和 of this 中的列表。NodesargskwargsNode

appendx[来源]

在图形中的节点列表中的此节点后插入。 相当于xself.next.prepend(x)

参数

xNode) (节点) – 要放在此节点之后的节点。必须是同一图形的成员。

注意

保证此 API 的向后兼容性。

属性参数 Tuple[Optional[Union[Tuple[Any ...]List[Any]Dict[str Any]切片 范围 Node str int float bool complex dtypeTensordevicememory_formatlayoutOpOverloadSymIntSymBoolSymFloat]]...]

this 的参数元组。论点的解释 取决于节点的操作码。有关更多信息,请参阅文档字符串 信息。Node

允许分配此属性。所有使用和用户的会计 在分配时自动更新。

format_nodeplaceholder_names=maybe_return_typename=[来源]

返回 的描述性字符串表示形式。self

此方法可以在没有参数的情况下用作调试 效用。

此函数也在方法内部使用 之。这些字符串一起构成了 autogenerated 函数 GraphModule 的 不应以其他方式使用。__str__Graphplaceholder_namesmaybe_return_typenameforwardplaceholder_namesmaybe_return_typename

参数
  • placeholder_namesOptional[List[str]]) – 将存储格式化字符串的列表 表示生成的函数中的占位符。仅供内部使用。forward

  • maybe_return_typenameOptional[List[str]]) – 将存储 一个格式化字符串,表示 generated 函数。仅供内部使用。forward

返回

如果 1) 我们用作内部帮助程序format_node

在 的方法中,并且 2) 是一个占位符 Node,返回 。否则 返回 当前节点。__str__GraphselfNone

返回类型

str

注意

保证此 API 的向后兼容性。

insert_argidxarg[来源]

将位置参数插入到具有给定索引的参数列表中。

参数
  • idxint) - 要插入的元素 in 的索引。self.args

  • argArgument) – 要插入的新参数值args

注意

保证此 API 的向后兼容性。

is_impure)[来源]

返回此 op 是否为 impure,即其 op 是否为占位符或 输出,或者如果call_function或不纯call_module。

返回

操作是否不纯。

返回类型

布尔

警告

此 API 是实验性的,向后兼容。

属性 kwargs Dict[str Optional[Union[Tuple[Any ...]List[Any]Dict[str Any]切片 范围 Node str int float bool complex dtypeTensordevicememory_format、布局OpOverloadSymIntSymBoolSymFloat]]]

此 的关键字参数的 dict 。论点的解释 取决于节点的操作码。有关更多信息,请参阅文档字符串 信息。Node

允许分配此属性。所有使用和用户的会计 在分配时自动更新。

property next 节点

返回 Node 链表中的 next。Node

返回

节点链表中的下一个。Node

normalized_argumentsrootarg_types=kwarg_types=normalize_to_only_use_kwargs=False[来源]

将规范化参数返回给 Python 目标。这意味着 args/kwargs 将与 module/functional 的 signature 并按位置顺序独占返回 kwargs 如果 normalize_to_only_use_kwargs 为 true。 同时填充默认值。不支持仅位置 parameters 或 varargs 参数。

支持模块调用。

可能需要 arg_typeskwarg_types 以消除重载的歧义。

参数
  • roottorch.nn.Module) – 解析模块目标所依据的模块。

  • arg_typesOptional[Tuple[Any]]) – args 的 arg 类型元组

  • kwarg_typesOptional[Dict[strAny]]) – kwargs 的 arg 类型字典

  • normalize_to_only_use_kwargsbool) – 是否标准化为仅使用 kwargs。

返回

返回 NamedTuple ArgsKwargsPair,如果不成功,则返回 None

返回类型

可选[ArgsKwargsPair]

警告

此 API 是实验性的,向后兼容。

prependx[来源]

在图形的节点列表中的此节点之前插入 x。例:

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
参数

xNode) – 要放在此节点之前的节点。必须是同一图形的成员。

注意

保证此 API 的向后兼容性。

property prev 节点

返回 Node 链表中的上一个。Node

返回

节点链表中的上一个。Node

replace_all_uses_withreplace_withdelete_user_cb=<function Node.<lambda>>*propagate_meta=False[来源]

将 Graph 中的所有 使用 替换为 Node .selfreplace_with

参数
  • replace_with (Node) (节点) – 要将 的所有使用替换为的节点。self

  • delete_user_cbCallable) – 为确定而调用的回调 是否应删除 self 节点的给定用户。

  • propagate_metabool) – 是否复制所有属性 在原始节点的 .meta 字段上。 为了安全起见,这仅在替换节点 还没有现有的 .meta 字段。

返回

进行了此更改的 Node 的列表。

返回类型

列表[Node]

注意

保证此 API 的向后兼容性。

replace_input_withold_inputnew_input[来源]

遍历 的输入节点 ,并将 的所有实例替换为 。selfold_inputnew_input

参数
  • old_input (Node) ( (Node) (节点)) – 要替换的旧输入节点。

  • new_input (Node) ( (Node) (节点)) – 要替换的新输入节点。old_input

注意

保证此 API 的向后兼容性。

属性stack_trace:Optional[str]

返回在跟踪期间记录的 Python 堆栈跟踪(如果有)。 使用 fx.Tracer 中,此属性通常由 Tracer.create_proxy 填充。要在跟踪期间记录堆栈跟踪以进行调试,请执行以下操作: 在 Tracer 实例上设置 record_stack_traces = True。 使用 dynamo 进行追踪时,默认情况下,此属性将由 OutputGraph.create_proxy 填充。

stack_trace 将在字符串的末尾具有最内层的帧。

update_argidxarg[来源]

更新现有位置参数以包含新值 。调用后,.argself.args[idx] == arg

参数
  • idxint) – 要更新的元素的索引self.args

  • argArgument) – 要写入的新参数值args

注意

保证此 API 的向后兼容性。

update_kwargkeyarg[来源]

更新现有关键字参数以包含新值 。调用后,.argself.kwargs[key] == arg

参数
  • keystr) – 要更新的元素的 key inself.kwargs

  • argArgument) – 要写入的新参数值kwargs

注意

保证此 API 的向后兼容性。

torch.fx 中。Tracerautowrap_modules=(math,)autowrap_functions=()[来源]

Tracer是实现符号跟踪功能的类 之。对 的调用是等效的 自。torch.fx.symbolic_tracesymbolic_trace(m)Tracer().trace(m)

Tracer 可以被子类化以覆盖跟踪的各种行为 过程。描述了可以覆盖的不同行为 在此类上方法的文档字符串中。

注意

保证此 API 的向后兼容性。

call_modulemforwardargskwargs[来源]

指定 this 在遇到 对实例的调用。Tracernn.Module

默认情况下,行为是检查被调用的模块是否为叶模块 通过。如果是,则发出一个引用 .否则,请 normally 调用 tracethrough 其函数中的操作。is_leaf_modulecall_modulemGraphModuleforward

例如,可以覆盖此方法以创建嵌套跟踪 GraphModules 或跨边界跟踪时所需的任何其他行为。Module

参数
  • mModule) – 要发出调用的模块

  • forwardCallable) - 要调用的 forward() 方法Module

  • argsTuple) – 模块 callsite 的 args

  • kwargsDict) – 模块 callsite 的 kwargs

返回

Module 调用的返回值。如果发出了节点,则这是一个值。否则,它就是任何东西 值。call_moduleProxyModule

返回类型

任何

注意

保证此 API 的向后兼容性。

create_arga[来源]

在将值准备为 用作 .Graph

默认情况下,行为包括:

  1. 迭代集合类型(例如 tuple、list、dict)和递归 调用元素。create_args

  2. 给定一个 Proxy 对象,返回对基础 IR 的引用Node

  3. 给定一个非 Proxy Tensor 对象,针对各种情况发出 IR:

    • 对于 Parameter(参数),发出引用该 Parameter 的节点get_attr

    • 对于非 Parameter Tensor,请将 Tensor 存储在特殊的 属性引用该属性。

可以重写此方法以支持更多类型。

参数

aAny) – 要在 .ArgumentGraph

返回

该值转换为适当的aArgument

返回类型

论点

注意

保证此 API 的向后兼容性。

create_args_for_rootroot_fnis_moduleconcrete_args=[来源]

创建与 Module 的签名对应的节点。此方法内省 root 的签名并发出 节点,也支持 和 。placeholderroot*args**kwargs

警告

此 API 是实验性的,向后兼容。

create_nodekindtargetargskwargsname=Nonetype_expr=无)

插入给定 target、args、kwargs 和 name 的图形节点。

可以重写此方法以执行额外的检查、验证或 修改节点创建中使用的值。例如,一个人可能会 想要禁止记录就地操作。

注意

保证此 API 的向后兼容性。

返回类型

节点

create_proxykindtargetargskwargsname=type_expr=proxy_factory_fn=没有)

从给定的参数创建一个 Node,然后返回 Node 包装在 Proxy 对象中。

如果 kind = 'placeholder',那么我们将创建一个 Node,该 Node 表示函数的参数。如果我们需要编码 一个默认参数,我们使用 Tuples。 是 否则,对于 Nodes 为空。argsargsplaceholder

注意

保证此 API 的向后兼容性。

get_fresh_qualname前缀[来源]

获取前缀的新名称并返回它。此功能可确保 它不会与图形上的现有属性冲突。

注意

保证此 API 的向后兼容性。

返回类型

str

getattrattrattr_valparameter_proxy_cache[来源]

指定调用 getattr 时 this 行为的方法 在调用实例时。Tracernn.Module

默认情况下,该行为是返回属性的代理值。它 还将 proxy 值存储在 中,以便将来 calls 将重用代理,而不是创建新代理。parameter_proxy_cache

例如,在以下情况下,可以覆盖此方法以不返回代理 查询参数。

参数
  • attrstr) – 正在查询的属性的名称

  • attr_valAny) – 属性的值

  • parameter_proxy_cacheDict[strAny]) – 代理的 attr 名称缓存

返回

getattr 调用的返回值。

警告

此 API 是实验性的,向后兼容。

is_leaf_modulemmodule_qualified_name[来源]

指定 given 是否为 “leaf” 模块的方法。nn.Module

叶模块是出现在 IR,由调用引用。默认情况下, PyTorch 标准库命名空间 (torch.nn) 中的模块 是叶模块。所有其他模块都通过 和 进行跟踪 除非另有说明,否则将记录其组成操作 通过此参数。call_module

参数
  • mModule) – 被查询的模块

  • module_qualified_namestr) – 此模块的根路径。例如 如果你的模块层次结构中 submodule 包含 submodule ,其中包含 submodule ,该模块将 在此处显示 qualified name 。foobarbazfoo.bar.baz

返回类型

布尔

注意

保证此 API 的向后兼容性。

iterOBJ)
在迭代代理对象时调用,例如

在 Control Flow 中使用时。通常我们不知道该怎么做,因为 我们不知道代理的值,但自定义跟踪器可以附加更多 信息添加到图形节点中create_node,并且可以选择返回迭代器。

注意

保证此 API 的向后兼容性。

返回类型

迭 代

keys对象)
当代理对象被调用时调用了 keys() 方法。

这就是在代理上调用 ** 时发生的情况。这应该会返回一个 iterator it ** 应该在您的自定义跟踪器中工作。

注意

保证此 API 的向后兼容性。

返回类型

任何

path_of_modulemod[来源]

在 Module 层次结构中查找 的限定名称的 Helper 方法 之。例如,如果有一个名为 的子模块,该子模块具有 传入此函数的名为 的 子模块将返回 字符串 “foo.bar”。modrootrootfoobarbar

参数

modstr) – 要检索其限定名称的 。Module

返回类型

str

注意

保证此 API 的向后兼容性。

proxy节点)

注意

保证此 API 的向后兼容性。

返回类型

代理

to_boolOBJ))
当代理对象转换为布尔值(如

在 Control Flow 中使用时。通常我们不知道该怎么做,因为 我们不知道代理的值,但自定义跟踪器可以附加更多 信息添加到图形节点中create_node,并且可以选择返回一个值。

注意

保证此 API 的向后兼容性。

返回类型

布尔

tracerootconcrete_args=None[来源]

跟踪并返回相应的 FX 表示形式。 可以是实例或 Python 可调用对象。rootGraphrootnn.Module

注意,此调用后,可能与传递的 在这里。例如,当 free 函数传递给 时,我们将 创建用作根的实例并添加嵌入的常量 自。self.rootroottrace()nn.Module

参数
  • rootUnion[ModuleCallable]) – 一个或一个函数 追踪。此参数的向后兼容性为 保证。Module

  • concrete_argsOptional[Dict[strany]]) – 应该 不被视为代理。此参数是实验性的,并且 不保证其向后兼容性。

返回

A 表示传入的 .Graphroot

返回类型

注意

保证此 API 的向后兼容性。

torch.fx 中。代理nodetracer=None[来源]

Proxy对象是流经 编程并记录所有操作 ( 函数调用、方法调用、运算符) 进入不断增长的 FX Graph。Nodetorch

如果你正在执行图形转换,则可以将自己的方法包装在 raw 周围,以便可以使用重载的 运算符向 .ProxyNodeGraph

Proxy对象无法迭代。换句话说,象征性的 如果在循环中使用 a 或 AS / 函数参数。Proxy*args**kwargs

有两种主要方法可以解决这个问题: 1. 将不可追踪的 logic 分解为顶级函数,然后 使用。 2. 如果控制流是静态的(即循环跳闸计数为 基于某些超参数),代码可以保留在其原始 position 并重构为如下内容:fx.wrap

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

有关 Proxy 内部结构的更详细说明,请查看 torch/fx/README.md 中的 “Proxy” 部分

注意

保证此 API 的向后兼容性。

torch.fx 中。Interpretermodulegarbage_collect_values=Truegraph=None[来源]

解释器逐个节点执行 FX 图。此模式 可用于许多事情,包括编写代码 转换以及分析通道。

可以重写 Interpreter 类中的方法以自定义 执行行为。可覆盖方法的映射 在调用层次结构方面:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

假设我们想要交换 with 的所有实例,反之亦然(包括它们的方法等价物)。我们可以像这样子类化 Interpreter:torch.negtorch.sigmoidTensor

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数
  • moduletorch.nn.Module) – 要执行的模块

  • garbage_collect_valuesbool) – 是否删除最后一个值 use 在 Module 的执行中。这可确保在 执行。例如,可以禁用此选项以检查所有中间 值。Interpreter.env

  • graphOptional[Graph]) – 如果通过,解释器将执行此 graph 而不是 module.graph,使用提供的 module 参数来满足对 state 的任何请求。

注意

保证此 API 的向后兼容性。

boxed_runargs_list[来源]

通过解释运行 module 并返回结果。这将使用 “boxed” 调用约定,其中传递参数列表,该列表将被清除 由口译员。这可确保及时释放 input 张量。

注意

保证此 API 的向后兼容性。

call_functiontargetargskwargs[来源]

执行一个节点并返回结果。call_function

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回类型

任何

返回

Any:函数调用返回的值

注意

保证此 API 的向后兼容性。

call_methodtargetargskwargs[来源]

执行一个节点并返回结果。call_method

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回类型

任何

返回

Any:方法调用返回的值

注意

保证此 API 的向后兼容性。

call_moduletargetargskwargs[来源]

执行一个节点并返回结果。call_module

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回类型

任何

返回

Any:模块调用返回的值

注意

保证此 API 的向后兼容性。

fetch_args_kwargs_from_envn[来源]

从当前执行环境中获取 和 node 的具体值。argskwargsn

参数

nNode) (节点) – 应为其获取的节点。argskwargs

返回

args和 的具体值。kwargsn

返回类型

元组[元组, dict]

注意

保证此 API 的向后兼容性。

fetch_attr目标[来源]

从 的层次结构中获取属性。Moduleself.module

参数

targetstr) – 要获取的属性的完全限定名称

返回

属性的值。

返回类型

任何

注意

保证此 API 的向后兼容性。

get_attrtargetargskwargs[来源]

执行一个节点。将检索一个属性 值。get_attrModuleself.module

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回

检索到的属性的值

返回类型

任何

注意

保证此 API 的向后兼容性。

map_nodes_to_valuesargsn[来源]

递归 descend through 并查找具体值 对于每个。argsNode

参数
  • argsArgument) – 在其中查找具体值的数据结构

  • nNode) – 所属的节点。这仅用于错误报告。args

返回类型

可选[Union[Tuple[Any, ...], List[Any], Dict[strAny], 切片范围NodestrintfloatboolcomplexdtypeTensor设备memory_format布局OpOverloadSymIntSymBoolSymFloat]]

注意

保证此 API 的向后兼容性。

输出targetargskwargs[来源]

执行一个节点。这真的只是检索 Node 引用的值并返回该值。outputoutput

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回

输出节点引用的返回值

返回类型

任何

注意

保证此 API 的向后兼容性。

placeholdertargetargskwargs[来源]

执行一个节点。请注意,这是有状态的:在 参数传递给 ,并且此方法返回 next() 的 URL 中。placeholderInterpreterrun

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回

检索到的参数值。

返回类型

任何

注意

保证此 API 的向后兼容性。

run*argsinitial_env=enable_io_processing=True[来源]

通过解释运行 module 并返回结果。

参数
  • *args – 要运行的模块的参数,按位置顺序排列

  • initial_envOptional[Dict[NodeAny]]) – 用于执行的可选启动环境。 这是一个将 Node 映射到任何值的 dict。例如,这可用于 为某些 Node 预先填充结果,以便仅在 解释器。

  • enable_io_processingbool) – 如果为 true,则我们使用图形的 process_inputs 处理输入和输出,并且 process_outputs 函数后再使用它们。

返回

执行 Module 返回的值

返回类型

任何

注意

保证此 API 的向后兼容性。

run_noden[来源]

运行特定节点并返回结果。 调用 placeholder、get_attr、call_function、 call_method、call_module 或输出取决于 上nnode.op

参数

nNode) – 要执行的 Node

返回

执行的结果n

返回类型

任何

注意

保证此 API 的向后兼容性。

torch.fx 中。变压器模块[来源]

Transformer是一种特殊类型的解释器,它会生成 新增功能。它公开了一个方法,该方法返回 变换后的 . 不需要 参数运行。 工程 完全象征性地。Moduletransform()ModuleTransformerInterpreterTransformer

假设我们想要交换 with 的所有实例,反之亦然(包括它们的方法等价物)。我们可以像这样子类化:torch.negtorch.sigmoidTensorTransformer

class NegSigmSwapXformer(Transformer):
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数

moduleGraphModule) – 要转换的。Module

注意

保证此 API 的向后兼容性。

call_functiontargetargskwargs[来源]

注意

保证此 API 的向后兼容性。

返回类型

任何

call_moduletargetargskwargs[来源]

注意

保证此 API 的向后兼容性。

返回类型

任何

get_attrtargetargskwargs[来源]

执行一个节点。在 中,这是 overridden 将新节点插入到输出中 图。get_attrTransformerget_attr

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回类型

代理

注意

保证此 API 的向后兼容性。

placeholdertargetargskwargs[来源]

执行一个节点。在 中,这是 overridden 将 New 插入到输出中 图。placeholderTransformerplaceholder

参数
  • targetTarget) – 此节点的调用目标。请参阅 Node for 有关语义的详细信息

  • argsTuple) – 此调用的位置 args 元组

  • kwargsDict) – 此调用的关键字参数的 Dict

返回类型

代理

注意

保证此 API 的向后兼容性。

transform)[来源]

Transform 并返回转换后的 .self.moduleGraphModule

注意

保证此 API 的向后兼容性。

返回类型

GraphModule

torch.fx 中。replace_patternGMPatternReplacement[来源]

匹配所有可能的非重叠运算符集及其 graphModule 的 Graph 中的 data dependencies () (),然后将每个匹配的子图替换为另一个 子图 ()。patterngmreplacement

参数
返回

表示地点的对象列表 在匹配的原始图表中。列表 如果没有匹配项,则为空。 定义为:MatchpatternMatch

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回类型

列表[匹配]

例子:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上面的代码将首先在 .模式匹配基于 use-def 关系,而不是节点名称。例如,如果你有 in ,则可以在原始函数 尽管变量名称不同 ( vs )。patternforwardtraced_modulep = torch.cat([a, b])patternm = torch.cat([a, b])forwardpm

中的语句 in 根据其 仅值;它可能匹配也可能不匹配 更大的图表。换句话说,模式不必扩展 拖动到较大图表的末尾。returnpatternreturn

当模式匹配时,它将从较大的 函数,并替换为 .如果有多个 匹配项 for 在较大的函数中,每个 match 将被替换。在匹配重叠的情况下,第一个 将替换重叠匹配项集中的 Found Match。 (此处的“First”定义为拓扑排序中的第一个 节点的 use-def 关系。在大多数情况下,第一个节点 是紧跟在 之后的参数,而 last Node 是函数返回的任何值。replacementpatternself

需要注意的一件重要事情是,Callable 的参数必须在 Callable 本身中使用, 并且 Callable 的参数必须匹配 模式。第一条规则是为什么在上面的代码块中,函数有 parameters ,但函数只有 parameters 。 不使用 ,因此它不应指定为参数。 作为第二条规则的示例,请考虑将patternreplacementforwardx, w1, w2patternw1, w2patternxx

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

def replacement(x, y):
    return torch.relu(x)

在这种情况下, 需要相同数量的参数 as(和 ),即使该参数未在 中使用。replacementpatternxyyreplacement

调用 后,生成的 Python 代码如下所示:subgraph_rewriter.replace_pattern

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

保证此 API 的向后兼容性。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源