torch.fx¶
概述¶
FX 是一个开发人员使用的工具包,用于转换 nn.Module
实例。FX 包含三个主要组件:一个 符号追踪器,
一个 中间表示法 和 Python 代码生成。这些组件的实际演示:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号追踪器执行“符号执行”的Python代码。它通过代码传递假值,称为代理。对这些代理的操作会被记录下来。有关符号追踪的更多信息可以在symbolic_trace()和Tracer文档中找到。
中间表示是符号跟踪期间记录的操作容器。它由一组节点组成,这些节点代表函数输入、调用点(到函数、方法或torch.nn.Module实例)和返回值。有关IR的更多信息,请参阅Graph的文档。IR是应用变换的格式。
Python代码生成 是使得FX成为一个Python到Python(或模块到模块)转换工具包的原因。对于每个图中间表示(Graph IR),我们可以创建与其语义匹配的有效Python代码。此功能被封装在GraphModule中,它是一个torch.nn.Module实例,其中包含一个Graph以及由图生成的forward方法。
这些组件的组合(符号追踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件可以单独使用。例如,符号追踪可以在隔离状态下使用以捕获代码的一种形式,用于分析(而非转换)。代码生成可用于程序化地生成模型,例如从配置文件中生成。FX 有众多用途!
可以在示例库中找到一些转换示例。
写入转换¶
什么是FX变换?本质上,它就像下面这个函数一样。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
你的转换将接收一个 torch.nn.Module,从中获取一个 Graph,
进行一些修改,并返回一个新的 torch.nn.Module。你应该将你的FX转换返回的 torch.nn.Module 视为与普通的 torch.nn.Module 相同——你可以将其传递给另一个FX转换,可以传递给TorchScript,或者可以直接运行它。确保你的FX转换的输入和输出是一个 torch.nn.Module 将允许其组合。
注意
也可以修改现有的 GraphModule 而不是创建一个新的,例如:
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
请注意,您必须调用 GraphModule.recompile() 以使生成的
forward() 方法与修改后的 Graph 同步。
鉴于你已经传入了一个 torch.nn.Module 并将其追踪为一个
Graph,现在有两种主要方法可以用来构建一个新的
Graph。
图论入门简介¶
有关图语义的完整处理可以在Graph文档中找到,但我们在这里会介绍基础知识。一个Graph是一种数据结构,用于表示在GraphModule上的方法。所需的信息是:
该方法的输入是什么?
该方法内部运行哪些操作?
该方法的输出(即返回)值是什么?
这三个概念都用Node实例表示。
让我们通过一个简短的例子来看看我们指的是什么:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
这里我们定义一个模块 MyModule 用于演示目的,实例化它,
符号性地追踪它,然后调用Graph.print_tabular()方法来打印
出一个表格,显示此Graph的节点:
opcode
name
target
args
kwargs
placeholder
x
x
()
{}
get_attr
linear_weight
linear.weight
()
{}
call_function
add_1
<built-in function add>
(x, linear_weight)
{}
call_module
linear_1
linear
(add_1,)
{}
call_method
relu_1
relu
(linear_1,)
{}
call_function
sum_1
<built-in method sum …>
(relu_1,)
{‘dim’: -1}
call_function
topk_1
<built-in method topk …>
(sum_1, 3)
{}
output
output
output
(topk_1,)
{}
我们可以使用这些信息来回答我们上面提出的问题。
该方法的输入是什么?在FX中,方法输入通过特殊的
placeholder节点指定。在这种情况下,我们有一个单一的placeholder节点,其target为x,这意味着我们有一个名为x的单个(非自身)参数。方法中的操作有哪些?
get_attr,call_function,call_module和call_method节点 代表方法中的操作。有关这些操作语义的完整说明,请参阅Node文档。该方法的返回值是什么?在
Graph中,返回值由一个特殊的output节点指定。
鉴于我们现在了解了FX中代码的基本表示方法,我们可以探索如何编辑一个 Graph。
图操作¶
直接图 manipulation¶
构建这个新Graph的一种方法是直接操作旧的
一个。为了帮助实现这一点,我们可以简单地获取从符号
追踪中得到的Graph并进行修改。例如,假设我们希望用
torch.mul()调用来替换
torch.add()调用。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我们还可以进行更复杂的Graph重写,例如删除或添加节点。为了帮助这些转换,
FX 提供了一些用于转换图的工具函数,可以在Graph文档中找到。下面是一个使用这些 API 添加一个torch.relu()调用的示例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
对于仅包含替换的简单转换,您也可以使用子图重写器。
子图重写与replace_pattern()¶
FX 还在直接图操作的基础上提供了另一个级别的自动化。
replace_pattern() API 实质上是一个用于编辑
Graph 的“查找/替换”工具。它允许您指定一个
pattern 和 replacement 函数,并且它会遍历这些函数,在
pattern 图中找到操作实例,并将其替换为 replacement 图的副本。这可以帮助大大自动化繁琐的图操作代码,当转换变得更为复杂时,这种自动化尤为重要。
Proxy/Retracing¶
另一种操作 Graph 的方式是重用符号跟踪中使用的 Proxy 机制。例如,假设我们想编写一个将 PyTorch 函数分解为更小操作的转换。它会将每个 F.relu(x) 调用转换为 (x > 0) * x。一种可能性是在 F.relu 后插入比较和乘法,然后清理原始的 F.relu。然而,我们可以通过使用 Proxy 对象自动记录操作到 Graph 中来自动化这个过程。
要使用此方法,我们将想要插入的操作编写为常规的 PyTorch 代码,并使用Proxy 对象作为参数调用该代码。
这些Proxy 对象将捕获对其执行的操作并将它们附加到Graph 上。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式图操作外,使用Proxys
还可以让你以原生Python代码的形式指定重写规则。
对于需要大量重写规则的转换(例如vmap或grad),这通常可以提高规则的可读性和可维护性。
需要注意的是,在调用Proxy时,我们也传递了一个指向底层变量graph的追踪器。
这样做是为了在图中的操作是n元(例如加法是一个二元运算符)的情况下,
调用Proxy不会创建多个图追踪器实例,从而导致意外的运行时错误。
我们特别推荐在这种方法使用Proxy,尤其是在不能安全假设底层运算是单元的情况下。
解释器模式¶
FX中一个有用的代码组织模式是对Node进行循环,并在Graph中执行它们。这可以用于多种用途,包括对流经图的值进行运行时分析或通过使用Proxy进行代码转换。例如,假设我们希望运行一个GraphModule并记录我们在运行时看到的节点的torch.Tensor形状和数据类型属性。那可能看起来像这样:
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
正如你所见,一个完整的FX解释器并没有那么复杂,
但它可以非常有用。为了方便使用这种模式,我们提供了
Interpreter 类,该类以一种可以通过方法重写来覆盖解释器执行的某些方面的形式封装了上述逻辑。
除了执行操作外,我们还可以通过解释器传入 Proxy 个值来生成一个新的 Graph。
同样地,我们提供了 Transformer 类来涵盖这种模式。
Transformer 的行为类似于 Interpreter,
但与其调用 run 方法从模块中获取具体的输出值,您会调用 Transformer.transform() 方法返回一个新的 GraphModule,该对象会遵循您安装的任何转换规则作为重写方法。
调试¶
介绍¶
在编写转换代码的过程中,我们的代码往往不会完全正确。 在这种情况下,我们需要进行一些调试。关键是倒推:首先,检查调用生成模块的结果以证明或反驳其正确性。然后,检查并调试生成的代码。最后,调试导致生成代码的转换过程。
如果您不熟悉调试器,请参阅辅助部分 可用的调试器。
转换编写中的常见陷阱¶
非确定性
set迭代顺序。在 Python 中,set数据类型是无序的。使用set来包含对象集合(例如Node),可能会导致意外的非确定性。一个例子是遍历一组Node并将它们插入到Graph中。由于set数据类型是无序的,输出程序中的操作顺序将是非确定性的,并且在程序调用之间可能会发生变化。推荐的替代方法是使用dict数据类型,该类型从 Python 3.7(以及 cPython 3.6)开始是插入有序的。dict可以等效地用于集合,通过将需要去重的值存储在dict的键中。
检查模块的正确性¶
因为大多数深度学习模块的输出由浮点数torch.Tensor实例组成,检查两个torch.nn.Module结果之间的等价性并不像简单的相等性检查那样直观。为了说明这一点,让我们来看一个例子:
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
这里,我们尝试使用==等于运算符检查两个深度学习模型的值是否相等。然而,这并不明确,原因有两点:一是该运算符返回的是一个张量而不是布尔值;二是浮点数值的比较应当使用误差范围(或epsilon)来考虑浮点数操作的非交换性(详见这里)。我们可以改用torch.allclose(),它会根据相对和绝对容差阈值给出近似的比较结果:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们在工具箱中的第一个工具,用于检查转换后的模块是否如我们预期的那样工作,与参考实现进行比较。
调试生成的代码¶
因为FX在GraphModule上生成了forward()函数,使用传统的调试技术如print语句或pdb并不那么直接。幸运的是,我们有几种可以用于调试生成代码的技术。
使用 pdb¶
调用 pdb 以进入正在运行的程序。尽管表示 Graph 的代码不在任何源文件中,我们仍然可以在前向传播被调用时使用 pdb 手动进入它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码¶
如果你想多次运行相同的代码,那么使用pdb逐步找到正确的代码可能会有点繁琐。在这种情况下,一种方法是简单地复制粘贴生成的forward传递到你的代码中,并从那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用to_folder函数来自GraphModule¶
GraphModule.to_folder() 是 GraphModule 中的一种方法,允许你将生成的FX代码转储到一个文件夹中。虽然将前向传递复制到代码中通常就足够了(如 打印生成的代码 所示),但使用 to_folder 检查模块和参数可能会更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上述示例后,我们可以查看foo/module.py中的代码并根据需要进行修改(例如添加print语句或使用pdb)以调试生成的代码。
调试转换¶
现在我们已经确定转换过程生成了错误的代码,是时候调试转换本身了。首先,我们将检查文档中的符号跟踪的局限性部分。一旦我们确认跟踪按预期工作,目标就变成了找出在我们的GraphModule转换过程中出了什么问题。可能在编写转换中有快速的答案,但如果没有,有几种方法可以检查我们的跟踪模块:
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上述实用函数,我们可以在应用转换之前和之后比较我们的跟踪模块。有时候,简单的视觉比较就足以追踪到错误。如果仍然不清楚问题出在哪里,像pdb这样的调试器可能是下一步的好选择。
继续以上面的例子为基础,考虑以下代码:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的例子,假设对print(traced)的调用显示我们的转换中存在错误。我们想通过调试器找出问题所在。我们开始一个pdb会话。我们可以通过在transform_graph(traced)处设置断点,然后按s“进入”对transform_graph(traced)的调用来查看转换过程中发生了什么。
我们也可以通过编辑print_tabular方法来打印图中节点的不同属性而获得好运。(例如,我们可能想要查看节点的input_nodes和users。)
可用的调试器¶
最常见的Python调试器是
pdb。你可以通过在命令行中输入
pdb 来启动你的程序的“调试模式”,输入
python -m pdb FILENAME.py,其中 FILENAME
是你想要调试的文件名。之后,你可以使用
pdb 调试命令
逐步运行你的程序。通常的做法是在启动 pdb 时设置一个断点(b LINE-NUMBER),然后调用 c
来运行程序直到该断点。这可以防止你必须逐行执行(使用 s 或 n)才能到达你想检查的代码部分。或者,你可以在想要中断的行前写上
import pdb; pdb.set_trace()。如果你添加了 pdb.set_trace(),你的程序将在运行时自动进入调试模式。(换句话说,你只需在命令行中输入
python FILENAME.py 而不是 python -m pdb FILENAME.py)。一旦你在调试模式下运行文件,你可以逐步执行代码并使用某些命令检查程序的内部状态。网上有很多优秀的
pdb 教程,包括 RealPython 的 “Python 调试与 Pdb”。
像PyCharm或VSCode这样的IDE通常内置有调试器。在你的IDE中,你可以选择:a) 使用pdb,通过在IDE中打开终端窗口(例如,在VSCode中,视图 → 终端),或者 b) 使用内置的调试器(通常是pdb的一个图形化封装)。
符号追踪的局限性¶
FX 使用一种符号跟踪系统(也称为 符号执行)来捕获程序的语义,以便以可转换/可分析的形式表示。该系统是跟踪的,因为它会执行程序(实际上是一个torch.nn.Module 或函数)以记录操作。它是符号化的,因为在执行过程中流经程序的数据不是真实数据,而是符号(在 FX 术语中为 Proxy)。
尽管符号追踪适用于大多数神经网络代码,但它也有一些限制。
动态控制流¶
符号式追踪的主要限制是它目前不支持动态控制流。也就是说,循环或if语句中的条件可能依赖于程序的输入值。
例如,让我们来分析以下程序:
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
条件依赖于if语句的值,该值依赖于x.sum()的值,
而x.sum()又依赖于x的值,这是一个函数输入。由于
x可以改变(例如,如果你传递一个新的输入张量给跟踪的函数),这就是动态控制流。回溯会沿着你的代码向上追溯,向你展示这种情况发生的位置。
静态控制流¶
另一方面,所谓的静态控制流是受支持的。静态控制流是指值在调用之间不能改变的循环或if语句。通常,在PyTorch程序中,这种控制流出现在根据超参数对模型架构进行决策的代码中。具体例子如下:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if语句 if self.do_activation 不依赖于任何函数输入,因此它是静态的。do_activation 可以被视为一个超参数,具有不同参数值的MyModule的不同实例的跟踪代码各不相同。这是符号跟踪支持的有效模式。
许多动态控制流的情况在语义上是静态控制流。这些情况可以通过移除对输入值的数据依赖来支持符号追踪,例如将值移动到Module属性中,或在符号追踪期间将具体值绑定到参数:
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
在真正动态控制流的情况下,包含此代码的程序部分可以被追踪为对方法的调用(参见
使用 Tracer 类自定义追踪) 或函数(参见
wrap()),而不是通过它们进行追踪。
非torch函数¶
FX 使用 __torch_function__ 作为拦截调用的机制(有关更多信息,请参阅 技术概述)。一些函数,例如内置的 Python 函数或 math 模块中的函数,不在 __torch_function__ 的覆盖范围内,但我们仍希望在符号跟踪中捕获它们。例如:
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
错误告诉我们内置函数 len 不被支持。
我们可以通过使用 wrap() API,将此类函数记录在跟踪中作为直接调用:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用Tracer类自定义跟踪¶
Tracer 类是实现 symbolic_trace 的基础类。通过子类化 Tracer,可以自定义跟踪行为,如下所示:
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶模块¶
叶模块是在符号跟踪中作为调用出现的模块,而不是被跟踪通过的模块。默认的叶模块集合是标准的torch.nn模块实例的集合。例如:
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以通过覆盖Tracer.is_leaf_module()来定制叶模块集。
杂项¶
张量构造函数(例如
torch.zeros、torch.ones、torch.rand、torch.randn、torch.sparse_coo_tensor) 目前不可追踪。确定性构造函数(
zeros,ones)可以使用, 它们生成的值将作为常量嵌入到跟踪中。只有在这些构造函数的参数引用动态输入大小时,才会出现问题。 在这种情况下,ones_like或zeros_like可能是可行的替代方案。非确定性构造函数(
rand,randn)将在跟踪中嵌入一个随机值。这可能不是预期的行为。一种解决方法是将torch.randn包装在一个torch.fx.wrap函数中并调用该函数。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行为可能会在未来的版本中得到修复。
类型注解
Python 3 风格的类型注解(例如
func(x : torch.Tensor, y : int) -> torch.Tensor) 被支持 并且将通过符号跟踪被保留。Python 2风格的注释类型注解
# type: (torch.Tensor, int) -> torch.Tensor目前不支持。函数中局部名称的注释目前不支持。
在
training个标志和子模块周围抓取在使用像
torch.nn.functional.dropout这样的函数时,通常会将训练参数作为self.training传递进来。在FX跟踪过程中,这很可能会被固化为一个常量值。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
然而,当使用标准的
nn.Dropout()子模块时,训练标志被封装起来——由于nn.Module对象模型的保留——可以被更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
Because of this difference, consider marking modules that interact with the
trainingflag dynamically as leaf modules.
API 参考¶
- torch.fx.symbolic_trace(root, concrete_args=None)[source][source]¶
符号追踪 API
给定一个
nn.Module或函数实例root,此函数将返回一个通过记录在跟踪root时看到的操作而构建的GraphModule。concrete_args允许你部分特化你的函数,无论是为了去除控制流还是数据结构。例如:
def f(a, b): if b == True: return a else: return a * 2
FX 通常无法通过控制流对此进行跟踪。但是,我们可以使用 concrete_args 对 b 的值进行专门化以跟踪此过程:
f = fx.symbolic_trace(f, concrete_args={"b": False}) assert f(3, False) == 6
请注意,尽管您仍然可以传入不同的b值,但它们将被忽略。
我们也可以使用 concrete_args 来消除函数中的数据结构处理。 这将使用 pytrees 来展平您的输入。为了避免过度特化,对于不应特化的值,请传入 fx.PH。例如:
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) assert f({"a": 1, "b": 2, "c": 4}) == 7
- Parameters
根 (联合[torch.nn.Module, 可调用对象]) – 要跟踪并转换为图形表示的模块或函数。
concrete_args (可选[字典[字符串, 任意类型]]) – 部分专门化的输入
- Returns
从
root记录的操作创建的模块。- Return type
注意
此 API 的向后兼容性得到保证。
- torch.fx.wrap(fn_or_name)[source][source]¶
此函数可以在模块级别范围内调用,以将 fn_or_name 注册为“叶子函数”。 “叶子函数”在 FX 跟踪中将被保留为 CallFunction 节点,而不会被跟踪通过:
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap("my_custom_function") def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函数也可以用作装饰器:
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
封装的函数可以被认为是一个“叶函数”,类似于“叶模块”的概念,也就是说,它们是在FX跟踪中被保留为调用的函数,而不是被跟踪通过的函数。
- Parameters
fn_or_name (Union[str, Callable]) – 当调用时要插入到图中的全局函数的名称或函数
注意
此 API 的向后兼容性得到保证。
- class torch.fx.GraphModule(*args, **kwargs)[source][source]¶
GraphModule 是从 fx.Graph 生成的 nn.Module。Graphmodule 具有
graph属性,以及从该graph生成的code和forward属性。警告
当
graph被重新赋值时,code和forward将自动生成。 然而,如果你在不重新赋值graph属性本身的情况下编辑graph的内容, 你必须调用recompile()来更新生成的代码。注意
此 API 的向后兼容性得到保证。
- __init__(root, graph, class_name='GraphModule')[source][source]¶
构造一个 GraphModule。
- Parameters
根 (联合[torch.nn.Module, 字典[字符串, 任意类型]) –
root可以是 nn.Module 实例或映射字符串到任意属性类型的字典。 如果root是一个 Module,那么在 Graph 的 Nodes 的target字段中通过限定名称引用的基于 Module 的对象将从root的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。 如果root是一个字典,在 Node 的target中找到的限定名称将直接在字典的键中查找。由字典映射的对象将被复制到 GraphModule 的模块层次结构中的适当位置。图 (图) –
graph包含此 GraphModule 应用于代码生成的节点类名 (字符串) –
name表示此 GraphModule 的名称,用于调试目的。如果未设置,则所有错误消息都将报告为源自GraphModule。将此设置为root的原始名称或在转换上下文中具有意义的名称可能会有所帮助。
注意
此 API 的向后兼容性得到保证。
- add_submodule(target, m)[source][source]¶
将给定的子模块添加到
self。这将安装空模块(如果它们是
target的子路径且尚不存在)。- Parameters
- Returns
- Whether or not the submodule could be inserted. For
此方法返回 True,链中的每个对象 由
target表示的必须要么 a) 尚不存在, 要么 b) 引用一个nn.Module(不是参数或其他属性)
- Return type
注意
此 API 的向后兼容性得到保证。
- delete_all_unused_submodules()[source][source]¶
从
self中删除所有未使用的子模块。如果以下任何一个条件为真,则模块被视为“已使用”: 1. 它有被使用的子模块 2. 它的前向传播直接通过
call_module节点调用 3. 它有一个非模块属性从get_attr节点被使用此方法可以调用以清理一个
nn.Module,而无需手动调用每个未使用的子模块上的delete_submodule。注意
此 API 的向后兼容性得到保证。
- delete_submodule(target)[source][source]¶
从
self中删除给定的子模块。如果
target不是一个有效的目标,模块将不会被删除。- Parameters
目标 (字符串) – 新子模块的完全限定字符串名称 (请参阅
nn.Module.get_submodule中的示例以了解如何指定完全限定字符串。)- Returns
- Whether or not the target string referenced a
要删除的子模块。返回值
False表示target不是对子模块的有效引用。
- Return type
注意
此 API 的向后兼容性得到保证。
- print_readable(print_output=True, include_stride=False, include_device=False, colored=False)[source][source]¶
返回为当前 GraphModule 及其子 GraphModules 生成的 Python 代码
警告
此API为实验性功能,且不向后兼容。
- recompile()[source][source]¶
从其
graph属性重新编译此GraphModule。应在编辑包含的graph之后调用此方法,否则此GraphModule生成的代码将过时。注意
此 API 的向后兼容性得到保证。
- Return type
PythonCode
- to_folder(folder, module_name='FxModule')[source][source]¶
- Dumps out module to
folderwithmodule_nameso that it can be 使用
from <folder> import <module_name>导入Args:
folder (Union[str, os.PathLike]): The folder to write the code out to
- module_name (str): Top-level name to use for the
Modulewhile writing out the code
- module_name (str): Top-level name to use for the
警告
此API为实验性功能,且不向后兼容。
- Dumps out module to
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source][source]¶
Graph是 FX 中间表示中使用的主要数据结构。 它由一系列Node组成,每个代表调用点(或其他语法结构)。将这些Node组合在一起构成一个有效的 Python 函数。例如,以下代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk( torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 ) m = MyModule() gm = torch.fx.symbolic_trace(m)
将生成以下图形:
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1对于在
Graph中表示的操作语义,请参阅Node。注意
此 API 的向后兼容性得到保证。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source][source]¶
构造一个空的图。
注意
此 API 的向后兼容性得到保证。
- call_function(the_function, args=None, kwargs=None, type_expr=None)[source][source]¶
将
call_functionNode插入到Graph中。一个call_function节点 表示对由the_function指定的 Python 可调用对象的调用。- Parameters
the_function (Callable[..., Any]) – 要调用的函数。可以是任何PyTorch操作符、Python函数,或者是
builtins或operator命名空间中的成员。args (可选[元组[参数, ...]]) – 要传递给被调用函数的位置参数。
kwargs (可选[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数
type_expr (可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。
- Returns
新创建并插入的
call_function节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此 API 的向后兼容性得到保证。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[source][source]¶
将
call_methodNode插入到Graph中。一个call_method节点 表示对args的第 0 个元素调用给定的方法。- Parameters
- Returns
新创建并插入的
call_method节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此 API 的向后兼容性得到保证。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[source][source]¶
将一个
call_moduleNode插入到Graph中。一个call_module节点 表示在Module层次结构中调用Module的 forward() 函数。- Parameters
- Returns
新创建并插入的
call_module节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node()相同。注意
此 API 的向后兼容性得到保证。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source][source]¶
创建一个
Node并将其添加到当前插入点的Graph中。 请注意,当前插入点可以通过Graph.inserting_before()和Graph.inserting_after()设置。- Parameters
操作 (字符串) – 此节点的操作码。可以是 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’ 中的一个。这些操作码的语义在
Graph的文档字符串中有所描述。args (可选[元组[参数, ...]]) – 是传递给此节点的参数元组。
kwargs (可选[Dict[str, Argument]]) – 此节点的kwargs
名称 (可选[字符串]) – 一个可选的字符串名称,用于
Node。 这将影响在生成的Python代码中赋值的变量名称。type_expr (可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。
- Returns
新创建并插入的节点。
- Return type
注意
此 API 的向后兼容性得到保证。
- eliminate_dead_code(is_impure_node=None)[source][source]¶
根据每个节点的用户数量以及节点是否有任何副作用,从图中删除所有无效代码。在调用之前,图必须进行拓扑排序。
- Parameters
- Returns
图是否由于传递而发生了变化。
- Return type
Example:
在消除死代码之前,a 来自 a = x + 1 下面没有使用者, 因此可以从图中消除而不会产生影响。
def forward(self, x): a = x + 1 return x + self.attr_1
消除无效代码后,a = x + 1 已被移除,forward 的其余部分仍然保留。
def forward(self, x): return x + self.attr_1
警告
死代码消除有一些启发式方法来避免移除有副作用的节点(见 Node.is_impure),但总体来说覆盖率非常差,因此你应该假设这种方法在调用时是不安全的,除非你知道你的 FX 图完全由函数操作组成,或者你提供自己的自定义函数来检测有副作用的节点。
注意
此 API 的向后兼容性得到保证。
- erase_node(to_erase)[source][source]¶
从
Graph中删除一个Node。如果该节点在Graph中仍有使用者,则抛出异常。- Parameters
要擦除的 (节点) – 从
Graph中擦除Node。
注意
此 API 的向后兼容性得到保证。
- find_nodes(*, op, target=None, sort=True)[source][source]¶
允许快速查询节点
- Parameters
- Returns
包含请求的操作和目标的节点的可迭代对象。
警告
此API为实验性功能,且不向后兼容。
- get_attr(qualified_name, type_expr=None)[source][source]¶
将一个
get_attr节点插入到图中。一个get_attrNode表示从Module层次结构中获取属性。- Parameters
限定名称 (字符串) – 要检索的属性的完全限定名称。 例如,如果追踪的模块有一个名为
foo的子模块,该子模块有一个名为bar的子模块,该子模块有一个名为baz的属性,则应将限定名称foo.bar.baz作为qualified_name传递。type_expr (可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。
- Returns
新创建并插入的
get_attr节点。- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node相同。注意
此 API 的向后兼容性得到保证。
- graph_copy(g, val_map, return_output_node=False)[source][source]¶
将给定图中的所有节点复制到
self。- Parameters
- Returns
The value in
self的值现在等同于g中的输出值, 如果g有一个output节点。否则为None。- Return type
可选[联合[元组[可选[联合[元组[参数, …], 序列[参数], 映射[字符串, 参数], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]], …], 序列[可选[联合[元组[参数, …], 序列[参数], 映射[字符串, 参数], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]]], 映射[字符串, 可选[联合[元组[参数, …], 序列[参数], 映射[字符串, 参数], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]]], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]]
注意
此 API 的向后兼容性得到保证。
- inserting_after(n=None)[source][source]¶
- Set the point at which create_node and companion methods will insert into the graph.
当在 'with' 语句中使用时,这将临时设置插入点,然后在 with 语句退出时恢复它:
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
Args:
- n (Optional[Node]): The node before which to insert. If None this will insert after
the beginning of the entire graph.
- Returns:
一个资源管理器,它将在
__exit__上恢复插入点。
注意
此 API 的向后兼容性得到保证。
- inserting_before(n=None)[source][source]¶
- Set the point at which create_node and companion methods will insert into the graph.
当在 'with' 语句中使用时,这将临时设置插入点,然后在 with 语句退出时恢复它:
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
Args:
- n (Optional[Node]): The node before which to insert. If None this will insert before
the beginning of the entire graph.
- Returns:
一个资源管理器,它将在
__exit__上恢复插入点。
注意
此 API 的向后兼容性得到保证。
- lint()[source][source]¶
对这个图运行各种检查,以确保其结构良好。特别是: - 检查节点是否有正确的所有权(属于此图) - 检查节点是否按拓扑顺序出现 - 如果此图有 owning GraphModule,检查 targets 是否存在于该 GraphModule 中
注意
此 API 的向后兼容性得到保证。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[source][source]¶
将一个节点从一个图复制到另一个图。
arg_transform需要将参数从节点的图转换为自身的图。示例:# Copying all the nodes in `g` into `new_graph` g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
- Parameters
- Return type
注意
此 API 的向后兼容性得到保证。
- property nodes: _node_list¶
获取构成此图的节点列表。
请注意,此
Node列表表示形式是一个双向链表。迭代期间的更改(例如删除节点、添加节点)是安全的。- Returns
双向链表的节点。请注意,可以在此列表上调用
reversed以切换迭代顺序。
- on_generate_code(make_transformer)[source][source]¶
注册一个转换函数,当生成Python代码时
- Args:
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
a function that returns a code transformer to be registered. This function is called by on_generate_code to obtain the code transformer.
This function is also given as its input the currently registered code transformer (or None if nothing is registered), in case it is not desirable to overwrite it. This is useful to chain code transformers together.
- Returns:
a context manager that when used in a with statement, to automatically restore the previously registered code transformer.
Example:
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb(current_trans(body) if current_trans else body) ) ) gm.recompile() gm(*inputs) # drops into pdb
This function can also be used as a context manager, with the benefit to automatically restores the previously registered code transformer:
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此API为实验性功能,且不向后兼容。
- output(result, type_expr=None)[source][source]¶
在
outputNode中插入一个Graph。一个output节点表示 Python 代码中的return语句。result是应该返回的值。- Parameters
结果 (参数) – 要返回的值。
type_expr (可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。
注意
此方法的插入点和类型表达式规则与
Graph.create_node相同。注意
此 API 的向后兼容性得到保证。
- placeholder(name, type_expr=None, default_value)[source][source]¶
将一个
placeholder节点插入图中。一个placeholder表示一个函数输入。- Parameters
名称 (字符串) – 输入值的名称。这对应于该函数的位置参数名称,此
Graph表示。type_expr (可选[任意])– 一个可选的类型注释,表示此节点输出的Python类型。在某些情况下,为了正确生成代码(例如,当函数在TorchScript编译中后续使用时),这是必需的。
默认值 (任意类型) – 该函数参数应采用的默认值。注意:为了允许将 None 作为默认值,应该传递 inspect.Signature.empty 作为此参数,以指定该参数 _没有_ 默认值。
- Return type
注意
此方法的插入点和类型表达式规则与
Graph.create_node相同。注意
此 API 的向后兼容性得到保证。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[source][source]¶
将此
Graph转换为有效的 Python 代码。- Parameters
根模块 (str) – 用于查找限定名称目标的根模块的名称。这通常是‘self’。
- Returns
src: 代表对象的Python源代码 globals: 字典中的全局名称 src -> 它们所引用的对象。
- Return type
一个 PythonCode 对象,包含两个字段
注意
此 API 的向后兼容性得到保证。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source][source]¶
Node是表示深度学习框架中各个操作的数据结构。大多数情况下,节点代表对各种实体的调用点,例如操作符、方法和模块(一些例外包括指定函数输入和输出的节点)。每个Node都有一个由其op属性指定的函数。每个op值的Node语义如下:placeholder表示函数输入。name属性指定该值将采用的名称。target是参数的名称。args包含:1)什么都没有,或 2)表示函数输入默认参数的单个参数kwargs是无关紧要的。占位符对应于图输出中的函数参数(例如x)。get_attr从模块层次结构中检索参数。name类似地是获取结果被赋值的名称。target是参数在模块层次结构中的完全限定名称。args和kwargs是无关紧要的call_function应用一个自由函数到某些值。name类似地是赋值的目标名称。target是要应用的函数。args和kwargs表示函数的参数,遵循 Python 调用约定call_module在模块层次结构的forward()方法中应用一个模块到给定的参数。name如之前所述。target是要调用的模块在模块层次结构中的完全限定名称。args和kwargs表示调用模块时的参数,不包括 self 参数。call_method调用值上的方法。name类似。target是要应用于self参数的方法的字符串名称。args和kwargs表示调用模块时的参数, 包括 self 参数output包含其args[0]属性中的跟踪函数的输出。这对应于图形打印输出中的“返回”语句。
注意
此 API 的向后兼容性得到保证。
- property all_input_nodes: List[Node]¶
返回作为此节点输入的所有节点。这相当于迭代
args和kwargs并仅收集那些是节点的值。- Returns
列表中的
Nodes出现在此Node的args和kwargs中,按此顺序。
- append(x)[source][source]¶
在图的节点列表中,此节点后插入
x。 等同于self.next.prepend(x)- Parameters
x (节点) – 将此节点之后放置的节点。必须是同一图的成员。
注意
此 API 的向后兼容性得到保证。
- property args: Tuple[Optional[Union[Tuple[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], ...], Sequence[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[str, Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], ...]¶
此
Node的参数元组。参数的解释取决于节点的操作码。请参阅Node的文档字符串以获取更多信息。允许对这一属性进行赋值。所有使用情况和用户的记录将在赋值时自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[source][source]¶
返回
self的描述性字符串表示。此方法可以用作调试工具,无需参数调用。
此函数也在
__str__方法中内部使用 于Graph。一起,placeholder_names和maybe_return_typename中的字符串构成了这个图的周围 GraphModule中自动生成的forward函数的签名。placeholder_names和maybe_return_typename不应在其他情况下使用。- Parameters
- Returns
- If 1) we’re using
format_nodeas an internal helper 在
__str__方法的Graph中,以及 2)self是一个占位符节点,返回None。否则, 返回当前节点的描述性字符串表示。
- If 1) we’re using
- Return type
注意
此 API 的向后兼容性得到保证。
- insert_arg(idx, arg)[source][source]¶
将一个位置参数插入到给定索引的参数列表中。
- Parameters
idx (int) – 要在
self.args中插入的元素的索引。参数 (参数) – 要插入到
args的新参数值
注意
此 API 的向后兼容性得到保证。
- is_impure()[source][source]¶
返回此操作是否为不纯的,即如果其操作是占位符或输出,或者如果调用的是不纯的 call_function 或 call_module。
- Returns
如果操作是不纯的或不是。
- Return type
警告
此API为实验性功能,且不向后兼容。
- property kwargs: Dict[str, Optional[Union[Tuple[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], ...], Sequence[Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], Mapping[str, Optional[Union[Tuple[Argument, ...], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]]¶
此
Node的关键字参数字典。参数的解释取决于节点的 opcode。更多信息请参见Node的文档字符串。允许对这一属性进行赋值。所有使用情况和用户的记录将在赋值时自动更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source][source]¶
返回标准化的参数给Python目标。这意味着 args/kwargs 将与模块/功能的签名匹配,并按位置顺序返回仅关键字参数 如果 normalize_to_only_use_kwargs 为真。 同时填充默认值。不支持仅位置参数或可变参数。
支持模块调用。
可能需要arg_types和kwarg_types来消除重载的歧义。
- Parameters
根 (torch.nn.Module) – 用于解析模块目标的模块。
arg_types (可选[Tuple[任何类型]]) – 参数类型的元组
kwarg_types (可选[字典[字符串, 任意类型]]) – 关键字参数类型的字典
normalize_to_only_use_kwargs (布尔值) – 是否仅使用关键字参数进行标准化。
- Returns
返回 NamedTuple ArgsKwargsPair,如果未成功则返回 None。
- Return type
可选[ArgsKwargsPair]
警告
此API为实验性功能,且不向后兼容。
- prepend(x)[source][source]¶
在图的节点列表中此节点之前插入 x。示例:
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- Parameters
x (节点) – 将此节点之前的节点。必须是同一图的成员。
注意
此 API 的向后兼容性得到保证。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source][source]¶
将图中的所有
self替换为节点replace_with。- Parameters
- Returns
对此更改进行的节点列表。
- Return type
注意
此 API 的向后兼容性得到保证。
- replace_input_with(old_input, new_input)[source][source]¶
循环遍历输入节点
self,并将所有实例中的old_input替换为new_input。注意
此 API 的向后兼容性得到保证。
- property stack_trace: Optional[str]¶
返回在跟踪期间记录的Python堆栈跟踪(如果有)。 当使用fx.Tracer进行跟踪时,此属性通常由Tracer.create_proxy填充。要为调试目的记录跟踪期间的堆栈跟踪, 请在Tracer实例上设置record_stack_traces = True。 当使用dynamo进行跟踪时,默认情况下此属性将由OutputGraph.create_proxy填充。
堆栈跟踪会在字符串的末尾显示最内层的帧。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source][source]¶
Traceris the class that implements the symbolic tracing functionality oftorch.fx.symbolic_trace. A call tosymbolic_trace(m)is equivalent toTracer().trace(m).Tracer can be subclassed to override various behaviors of the tracing process. The different behaviors that can be overridden are described in the docstrings of the methods on this class.
注意
此 API 的向后兼容性得到保证。
- call_module(m, forward, args, kwargs)[source][source]¶
指定此
Tracer在遇到对nn.Module实例的调用时的行为的方法。默认情况下,行为是检查被调用的模块是否为叶模块 通过
is_leaf_module。如果是,则发出一个call_module节点,引用m在Graph中。否则,正常调用Module,跟踪其forward函数中的操作。此方法可以被重写以实现例如创建嵌套的追踪GraphModules,或在跨越
Module边界时实现任何你希望的行为。- Parameters
m (模块) – 正在发出调用的模块
forward (可调用对象) – 要调用的
Module的 forward() 方法参数 (元组) – 模块调用站点的参数
kwargs (字典) – 模块调用站点的kwargs
- Returns
模块调用的返回值。如果发出了一个
call_module节点,则这是一个Proxy值。否则,它是从Module调用返回的任何值。- Return type
注意
此 API 的向后兼容性得到保证。
- create_arg(a)[source][source]¶
在准备将值用作
Graph中的节点参数时,指定跟踪行为的方法。默认行为包括:
遍历集合类型(例如元组、列表、字典)并递归地对元素调用
create_args。给定一个代理对象,返回对底层IR的引用
Node给定一个非代理张量对象,为各种情况发出IR:
For a Parameter, emit a
get_attrnode referring to that ParameterFor a non-Parameter Tensor, store the Tensor away in a special attribute referring to that attribute.
此方法可以被重写以支持更多类型。
- Parameters
a (任意) – 要作为
Argument在Graph中发出的值。- Returns
该值
a转换为适当的Argument- Return type
参数
注意
此 API 的向后兼容性得到保证。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source][source]¶
创建
placeholder个节点,对应于root模块的签名。此方法检查根签名并相应地生成这些节点,同时支持*args和**kwargs。警告
此API为实验性功能,且不向后兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]¶
根据目标、参数、关键字参数和名称插入一个图节点。
此方法可以被重写以进行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止将就地操作记录下来。
注意
此 API 的向后兼容性得到保证。
- Return type
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]¶
从给定的参数创建一个节点,然后返回包装在代理对象中的节点。
如果 kind = 'placeholder',那么我们正在创建一个表示函数参数的节点。如果我们需要编码默认参数,我们使用
args元组。args对于placeholder节点来说是空的。注意
此 API 的向后兼容性得到保证。
- get_fresh_qualname(prefix)[source][source]¶
获取前缀的新鲜名称并返回它。此函数确保它不会与图上现有的属性冲突。
注意
此 API 的向后兼容性得到保证。
- Return type
- getattr(attr, attr_val, parameter_proxy_cache)[source][source]¶
当我们在调用
Tracer的 getattr 方法时,指定此方法的行为 在调用nn.Module实例时。默认情况下,行为是为属性返回一个代理值。它还把代理值存储在
parameter_proxy_cache中,以便将来调用时重用该代理而不是创建一个新的代理。此方法可以被重写,例如,在查询参数时不返回代理。
- Parameters
- Returns
getattr 调用的返回值。
警告
此API为实验性功能,且不向后兼容。
- is_leaf_module(m, module_qualified_name)[source][source]¶
指定给定
nn.Module是否为“叶子”模块的方法。叶模块是出现在 IR 中的原子单元,由
call_module调用引用。默认情况下, PyTorch 标准库命名空间(torch.nn)中的模块 是叶模块。所有其他模块都会被追踪并通过 记录其组成操作,除非通过此参数另行指定。- Parameters
- Return type
注意
此 API 的向后兼容性得到保证。
- iter(obj)[source]¶
- Called when a proxy object is being iterated over, such as
当在控制流中使用时。通常我们不知道该做什么,因为我们不知道代理的值,但是自定义跟踪器可以使用 create_node 将更多信息附加到图节点上,并可以选择返回一个迭代器。
注意
此 API 的向后兼容性得到保证。
- Return type
- keys(obj)[source]¶
- Called when a proxy object is has the keys() method called.
这是当在代理上调用**时发生的情况。这应该返回一个迭代器,**应该在你的自定义追踪器中工作。
注意
此 API 的向后兼容性得到保证。
- Return type
- path_of_module(mod)[source][source]¶
辅助方法,用于在
mod的模块层次结构中找到合格名称root。例如,如果root有一个名为foo的子模块, 该子模块又有一个名为bar的子模块,将bar传递给此函数将返回字符串“foo.bar”。注意
此 API 的向后兼容性得到保证。
- to_bool(obj)[source]¶
- Called when a proxy object is being converted to a boolean, such as
当在控制流中使用时。通常我们不知道该做什么,因为我们不知道代理的值,但是自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。
注意
此 API 的向后兼容性得到保证。
- Return type
- class torch.fx.Proxy(node, tracer=None)[source][source]¶
Proxy个对象是Node个包装器,在符号跟踪过程中流经程序并记录它们所接触的所有操作(torch函数调用、方法调用、操作符)到不断增长的 FX 图中。如果你正在进行图变换,你可以围绕原始的
Proxy方法编写自己的方法, 这样你可以使用重载的操作符向Graph中添加额外的内容。Proxy对象无法被迭代。换句话说,如果在循环中使用Proxy或将其作为*args/**kwargs函数参数,符号跟踪器将会抛出错误。有两种主要的方法可以解决这个问题: 1. 将无法追踪的逻辑提取到顶级函数中,并对其使用
fx.wrap。 2. 如果控制流是静态的(即循环次数基于某些超参数),则可以将代码保留在其原始位置,并重构为类似以下内容:for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有关代理内部结构的更详细描述,请参阅torch/fx/README.md中的“代理”部分
注意
此 API 的向后兼容性得到保证。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source][source]¶
解释器逐节点执行FX图。这种模式对于许多事情都很有用,包括编写代码转换以及分析过程。
Interpreter 类中的方法可以被重写以自定义执行行为。可重写方法的调用层次结构图:
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们想要将所有
torch.neg实例与torch.sigmoid互换,反之亦然(包括它们的Tensor方法等效项)。我们可以像这样继承 Interpreter:class NegSigmSwapInterpreter(Interpreter): def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- Parameters
模块 (torch.nn.Module) – 要执行的模块
garbage_collect_values (布尔值) – 是否在模块执行过程中在其最后一次使用后删除值。这确保了执行期间的最佳内存使用。可以禁用此功能,例如,通过查看
Interpreter.env属性来检查执行中的所有中间值。图 (可选[Graph]) – 如果传递了此参数,解释器将执行此图而不是module.graph, 使用提供的module参数来满足任何状态请求。
注意
此 API 的向后兼容性得到保证。
- boxed_run(args_list)[source][source]¶
运行 module 通过解释并返回结果。这使用“boxed”调用约定,您传递一个参数列表,解释器将清除这些参数。这确保输入张量被及时释放。
注意
此 API 的向后兼容性得到保证。
- call_function(target, args, kwargs)[source][source]¶
执行一个
call_function节点并返回结果。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Return type
- Return
任意:函数调用返回的值
注意
此 API 的向后兼容性得到保证。
- call_method(target, args, kwargs)[source][source]¶
执行一个
call_method节点并返回结果。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Return type
- Return
任意:方法调用返回的值
注意
此 API 的向后兼容性得到保证。
- call_module(target, args, kwargs)[source][source]¶
执行一个
call_module节点并返回结果。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Return type
- Return
任意:模块调用返回的值
注意
此 API 的向后兼容性得到保证。
- fetch_args_kwargs_from_env(n)[source][source]¶
从当前执行环境中获取节点
n的args和kwargs的具体值。- Parameters
n (节点) – 应获取
args和kwargs的节点。- Returns
args和kwargs与具体的n值。- Return type
元组[元组, 字典]
注意
此 API 的向后兼容性得到保证。
- fetch_attr(target)[source][source]¶
从
Module层次结构中的self.module获取一个属性。- Parameters
目标 (字符串) – 要获取的属性的完全限定名称
- Returns
属性的值。
- Return type
任何
注意
此 API 的向后兼容性得到保证。
- get_attr(target, args, kwargs)[source][source]¶
执行一个
get_attr节点。将从Module层次结构的self.module中检索属性值。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Returns
检索到的属性的值
- Return type
任何
注意
此 API 的向后兼容性得到保证。
- map_nodes_to_values(args, n)[source][source]¶
递归地向下遍历
args并在当前执行环境中查找每个Node的具体值。- Parameters
参数 (参数) – 用于查找具体值的数据结构
n (节点) –
args所属的节点。这仅用于错误报告。
- Return type
可选[联合[元组[可选[联合[元组[参数, …], 序列[参数], 映射[字符串, 参数], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]], …], 序列[可选[联合[元组[参数, …], 序列[参数], 映射[字符串, 参数], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]]], 映射[字符串, 可选[联合[元组[参数, …], 序列[参数], 映射[字符串, 参数], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]]], 切片, 范围, 节点, 字符串, 整数, 浮点数, 布尔值, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 操作重载, 符号整数, 符号布尔值, 符号浮点数]]
注意
此 API 的向后兼容性得到保证。
- output(target, args, kwargs)[source][source]¶
执行一个
output节点。这实际上只是检索由output节点引用的值并返回它。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Returns
输出节点所引用的返回值
- Return type
任何
注意
此 API 的向后兼容性得到保证。
- placeholder(target, args, kwargs)[source][source]¶
执行一个
placeholder节点。请注意这是有状态的:Interpreter保持传递给run的参数的内部迭代器,并且此方法返回该迭代器上的 next()。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Returns
检索到的参数值。
- Return type
任何
注意
此 API 的向后兼容性得到保证。
- class torch.fx.Transformer(module)[source][source]¶
Transformer是一种特殊的解释器,它生成一个新的Module。它暴露了一个transform()方法,该方法返回转换后的Module。Transformer运行时不需要参数,而Interpreter需要。Transformer完全以符号方式工作。示例
假设我们想要将所有
torch.neg的实例与torch.sigmoid互换(包括它们的Tensor方法等价形式)。我们可以像这样子类化Transformer:class NegSigmSwapXformer(Transformer): def call_function( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- Parameters
模块 (图模块) – 要转换的
Module。
注意
此 API 的向后兼容性得到保证。
- get_attr(target, args, kwargs)[source][source]¶
执行一个
get_attr节点。在Transformer中,这被重写以将一个新的get_attr节点插入输出图中。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Return type
注意
此 API 的向后兼容性得到保证。
- placeholder(target, args, kwargs)[source][source]¶
执行一个
placeholder节点。在Transformer中,这被重写以在输出图中插入一个新的placeholder。- Parameters
目标 (目标) – 此节点的调用目标。详见 节点 的语义说明
args (元组) – 此调用的位置参数元组
kwargs (字典) – 此调用的关键字参数的字典
- Return type
注意
此 API 的向后兼容性得到保证。
- torch.fx.replace_pattern(gm, pattern, replacement)[source][source]¶
匹配所有可能的非重叠运算符集及其数据依赖(
pattern)在GraphModule的图中(gm),然后将每个匹配的子图替换为另一个子图(replacement)。- Parameters
gm (图模块) – 包装了用于操作的图的图模块
模式 (Union[Callable, GraphModule]) – 要在
gm中匹配以进行替换的子图替换 (Union[Callable, GraphModule]) – 用于替换
pattern的子图
- Returns
一个包含
Match个对象的列表,表示在原始图中与pattern匹配的位置。如果没有匹配,则列表为空。Match定义为:class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- Return type
List[Match]
Examples:
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述代码将首先匹配
pattern在forward方法中的traced_module。模式匹配是基于使用-定义关系,而不是节点名称。例如,如果你在p = torch.cat([a, b])中有pattern,你可以匹配原始forward函数中的m = torch.cat([a, b]),尽管变量名称不同(pvsm)。The
return语句在pattern中仅根据其值进行匹配;它可能与较大图中的return语句匹配,也可能不匹配。换句话说,模式不必扩展到较大图的末尾。当匹配模式时,它将从较大的函数中移除并替换为
replacement。如果在较大的函数中有多个pattern的匹配项,则每个非重叠匹配项都将被替换。在匹配项重叠的情况下,将在重叠匹配项集合中找到的第一个匹配项将被替换。(这里的“第一个”是指根据节点使用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的内容。)值得注意的是,
patternCallable 的参数必须在 Callable 本身中使用, 并且replacementCallable 的参数必须匹配模式。第一个规则解释了为什么在上面的代码块中,forward函数有参数x, w1, w2,而pattern函数只有参数w1, w2。pattern不使用x,因此不应将其指定为参数。 作为第二个规则的一个例子,考虑替换def pattern(x, y): return torch.neg(x) + torch.relu(y)
与
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement需要与pattern相同数量的参数 (x和y),尽管参数y在replacement中未被使用。调用
subgraph_rewriter.replace_pattern后,生成的 Python 代码如下所示:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
此 API 的向后兼容性得到保证。