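(beta) Building a Simple CPU Performance Profiler with FX
Created On: Mar 04, 2021 | Last Updated: Jan 16, 2024 | Last Verified: Not Verified
Author: James Reed
In this tutorial, we are going to use FX to do the following:
Capture PyTorch Python code in a way that lets us inspect it and gather statistics about the structure and execution of the code
Build out a small class that will serve as a simple performance "profiler", collecting runtime statistics about each part of the model from actual runs.
For this tutorial, we are going to use the torchvision ResNet18 model for demonstration purposes.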
import torch
import torch.fx
import torchvision.models as models
rn18 = models.resnet18()
rn18.eval()
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
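Now that we have our model, we want to look more closely at its performance. That is, for the following invocation, which parts of the model take the longest?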
input = torch.randn(5, 3, 224, 224)
output = rn18(input)
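A common way of answering that question is to go through the program source, add code that collects timestamps at various points in the program, and compare the differences between those timestamps to see how long the regions between them take.
That technique is certainly applicable to PyTorch code, but it would be nicer if we didn't have to copy the model code and edit it, especially for code we haven't written (like this torchvision model). Instead, we are going to use FX to automate this "instrumentation" process without needing to modify any source.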
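For comparison, the manual approach described above might look like the following minimal sketch. `TinyNet` and its `last_timings` attribute are hypothetical names invented for illustration; they are not part of torchvision or of this tutorial's code.

```python
import time

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical two-layer model with timing code edited into forward()."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 16)
        self.fc2 = torch.nn.Linear(16, 4)
        self.last_timings = {}

    def forward(self, x):
        # Take a timestamp before and after every region we care about.
        t0 = time.time()
        x = self.fc1(x)
        t1 = time.time()
        x = torch.relu(x)
        t2 = time.time()
        x = self.fc2(x)
        t3 = time.time()
        # The difference between neighboring timestamps is that region's runtime.
        self.last_timings = {"fc1": t1 - t0, "relu": t2 - t1, "fc2": t3 - t2}
        return x


net = TinyNet().eval()
out = net(torch.randn(2, 8))
for region, seconds in net.last_timings.items():
    print(f"{region}: {seconds:.6f} s")
```

Every timed region requires an edit to `forward`, which is exactly the burden that becomes impractical for code we do not own.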
First, let's get some imports out of the way (we will be using all of these later in the code).
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter
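Note
tabulate is an external library that is not a dependency of PyTorch. We will be using it to more easily visualize performance data. Please make sure you've installed it from your favorite Python package source.
Capturing the Model with Symbolic Tracing
Next, we are going to use FX's symbolic tracing mechanism to capture the definition of our model in a data structure we can manipulate and examine.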
traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)
graph():
%x : torch.Tensor [num_users=1] = placeholder[target=x]
%conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
%bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
%relu : [num_users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
%maxpool : [num_users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
%layer1_0_conv1 : [num_users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
%layer1_0_bn1 : [num_users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
%layer1_0_relu : [num_users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
%layer1_0_conv2 : [num_users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
%layer1_0_bn2 : [num_users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
%add : [num_users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
%layer1_0_relu_1 : [num_users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
%layer1_1_conv1 : [num_users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
%layer1_1_bn1 : [num_users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
%layer1_1_relu : [num_users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
%layer1_1_conv2 : [num_users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
%layer1_1_bn2 : [num_users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
%layer1_1_relu_1 : [num_users=2] = call_module[target=layer1.1.relu](args = (%add_1,), kwargs = {})
%layer2_0_conv1 : [num_users=1] = call_module[target=layer2.0.conv1](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_bn1 : [num_users=1] = call_module[target=layer2.0.bn1](args = (%layer2_0_conv1,), kwargs = {})
%layer2_0_relu : [num_users=1] = call_module[target=layer2.0.relu](args = (%layer2_0_bn1,), kwargs = {})
%layer2_0_conv2 : [num_users=1] = call_module[target=layer2.0.conv2](args = (%layer2_0_relu,), kwargs = {})
%layer2_0_bn2 : [num_users=1] = call_module[target=layer2.0.bn2](args = (%layer2_0_conv2,), kwargs = {})
%layer2_0_downsample_0 : [num_users=1] = call_module[target=layer2.0.downsample.0](args = (%layer1_1_relu_1,), kwargs = {})
%layer2_0_downsample_1 : [num_users=1] = call_module[target=layer2.0.downsample.1](args = (%layer2_0_downsample_0,), kwargs = {})
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%layer2_0_bn2, %layer2_0_downsample_1), kwargs = {})
%layer2_0_relu_1 : [num_users=2] = call_module[target=layer2.0.relu](args = (%add_2,), kwargs = {})
%layer2_1_conv1 : [num_users=1] = call_module[target=layer2.1.conv1](args = (%layer2_0_relu_1,), kwargs = {})
%layer2_1_bn1 : [num_users=1] = call_module[target=layer2.1.bn1](args = (%layer2_1_conv1,), kwargs = {})
%layer2_1_relu : [num_users=1] = call_module[target=layer2.1.relu](args = (%layer2_1_bn1,), kwargs = {})
%layer2_1_conv2 : [num_users=1] = call_module[target=layer2.1.conv2](args = (%layer2_1_relu,), kwargs = {})
%layer2_1_bn2 : [num_users=1] = call_module[target=layer2.1.bn2](args = (%layer2_1_conv2,), kwargs = {})
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%layer2_1_bn2, %layer2_0_relu_1), kwargs = {})
%layer2_1_relu_1 : [num_users=2] = call_module[target=layer2.1.relu](args = (%add_3,), kwargs = {})
%layer3_0_conv1 : [num_users=1] = call_module[target=layer3.0.conv1](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_bn1 : [num_users=1] = call_module[target=layer3.0.bn1](args = (%layer3_0_conv1,), kwargs = {})
%layer3_0_relu : [num_users=1] = call_module[target=layer3.0.relu](args = (%layer3_0_bn1,), kwargs = {})
%layer3_0_conv2 : [num_users=1] = call_module[target=layer3.0.conv2](args = (%layer3_0_relu,), kwargs = {})
%layer3_0_bn2 : [num_users=1] = call_module[target=layer3.0.bn2](args = (%layer3_0_conv2,), kwargs = {})
%layer3_0_downsample_0 : [num_users=1] = call_module[target=layer3.0.downsample.0](args = (%layer2_1_relu_1,), kwargs = {})
%layer3_0_downsample_1 : [num_users=1] = call_module[target=layer3.0.downsample.1](args = (%layer3_0_downsample_0,), kwargs = {})
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%layer3_0_bn2, %layer3_0_downsample_1), kwargs = {})
%layer3_0_relu_1 : [num_users=2] = call_module[target=layer3.0.relu](args = (%add_4,), kwargs = {})
%layer3_1_conv1 : [num_users=1] = call_module[target=layer3.1.conv1](args = (%layer3_0_relu_1,), kwargs = {})
%layer3_1_bn1 : [num_users=1] = call_module[target=layer3.1.bn1](args = (%layer3_1_conv1,), kwargs = {})
%layer3_1_relu : [num_users=1] = call_module[target=layer3.1.relu](args = (%layer3_1_bn1,), kwargs = {})
%layer3_1_conv2 : [num_users=1] = call_module[target=layer3.1.conv2](args = (%layer3_1_relu,), kwargs = {})
%layer3_1_bn2 : [num_users=1] = call_module[target=layer3.1.bn2](args = (%layer3_1_conv2,), kwargs = {})
%add_5 : [num_users=1] = call_function[target=operator.add](args = (%layer3_1_bn2, %layer3_0_relu_1), kwargs = {})
%layer3_1_relu_1 : [num_users=2] = call_module[target=layer3.1.relu](args = (%add_5,), kwargs = {})
%layer4_0_conv1 : [num_users=1] = call_module[target=layer4.0.conv1](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_bn1 : [num_users=1] = call_module[target=layer4.0.bn1](args = (%layer4_0_conv1,), kwargs = {})
%layer4_0_relu : [num_users=1] = call_module[target=layer4.0.relu](args = (%layer4_0_bn1,), kwargs = {})
%layer4_0_conv2 : [num_users=1] = call_module[target=layer4.0.conv2](args = (%layer4_0_relu,), kwargs = {})
%layer4_0_bn2 : [num_users=1] = call_module[target=layer4.0.bn2](args = (%layer4_0_conv2,), kwargs = {})
%layer4_0_downsample_0 : [num_users=1] = call_module[target=layer4.0.downsample.0](args = (%layer3_1_relu_1,), kwargs = {})
%layer4_0_downsample_1 : [num_users=1] = call_module[target=layer4.0.downsample.1](args = (%layer4_0_downsample_0,), kwargs = {})
%add_6 : [num_users=1] = call_function[target=operator.add](args = (%layer4_0_bn2, %layer4_0_downsample_1), kwargs = {})
%layer4_0_relu_1 : [num_users=2] = call_module[target=layer4.0.relu](args = (%add_6,), kwargs = {})
%layer4_1_conv1 : [num_users=1] = call_module[target=layer4.1.conv1](args = (%layer4_0_relu_1,), kwargs = {})
%layer4_1_bn1 : [num_users=1] = call_module[target=layer4.1.bn1](args = (%layer4_1_conv1,), kwargs = {})
%layer4_1_relu : [num_users=1] = call_module[target=layer4.1.relu](args = (%layer4_1_bn1,), kwargs = {})
%layer4_1_conv2 : [num_users=1] = call_module[target=layer4.1.conv2](args = (%layer4_1_relu,), kwargs = {})
%layer4_1_bn2 : [num_users=1] = call_module[target=layer4.1.bn2](args = (%layer4_1_conv2,), kwargs = {})
%add_7 : [num_users=1] = call_function[target=operator.add](args = (%layer4_1_bn2, %layer4_0_relu_1), kwargs = {})
%layer4_1_relu_1 : [num_users=1] = call_module[target=layer4.1.relu](args = (%add_7,), kwargs = {})
%avgpool : [num_users=1] = call_module[target=avgpool](args = (%layer4_1_relu_1,), kwargs = {})
%flatten : [num_users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
%fc : [num_users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
return fc
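This gives us a Graph representation of the ResNet18 model. A Graph consists of a series of Nodes connected to each other. Each Node represents a call-site in the Python code (whether to a function, a module, or a method), and the edges (represented as args and kwargs on each Node) represent the values passed between these call-sites. More information about the Graph representation and the rest of FX's APIs can be found in the FX documentation: https://pytorch.org/docs/master/fx.html.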
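To make the Node fields concrete, here is a small sketch that traces a toy module and walks its Graph. `TinyNet` is a hypothetical stand-in for ResNet18, chosen only to keep the output short.

```python
import torch
import torch.fx


class TinyNet(torch.nn.Module):
    """Hypothetical module: one submodule call and one function call."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


traced = torch.fx.symbolic_trace(TinyNet())

# Each Node records what kind of call it is (op), what is called (target),
# and which values flow into it (args / kwargs).
for node in traced.graph.nodes:
    print(node.op, node.name, node.target, node.args, node.kwargs)

node_ops = [node.op for node in traced.graph.nodes]
```

The args of each Node refer to earlier Nodes, which is how the edges of the Graph are encoded.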
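Creating a Profiling Interpreter
Next, we are going to create a class that inherits from torch.fx.Interpreter. The GraphModule that symbolic_trace produces compiles Python code that is run when you call the GraphModule. An alternative way to run a GraphModule is to execute each Node in its Graph one by one. That is the functionality that Interpreter provides: it interprets the graph node by node.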
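As a quick illustration of the two execution paths, the sketch below runs the same traced module once by calling the GraphModule and once through a plain Interpreter. The small Sequential model is an assumption made for brevity.

```python
import torch
import torch.fx

model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU()).eval()
gm = torch.fx.symbolic_trace(model)
x = torch.randn(2, 4)

# Path 1: call the GraphModule, which runs its generated forward() code.
compiled_out = gm(x)

# Path 2: let Interpreter walk the same Graph one Node at a time.
interpreted_out = torch.fx.Interpreter(gm).run(x)

# Both paths execute the same operations on the same input.
print(torch.allclose(compiled_out, interpreted_out))
```

Because the Interpreter visits every Node explicitly, it offers a natural hook for per-node measurements.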
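By inheriting from Interpreter, we can override various functionality and install the profiling behavior we want. The goal is to have an object to which we can pass a model, invoke the model one or more times, and then get statistics about how long the model and each part of the model took during those runs.
Let's define our ProfilingInterpreter class: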
class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)
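Note
We use Python's time.time function to pull wall-clock timestamps and compare them. This is not the most accurate way to measure performance, and it will only give us a first-order approximation. We use this simple technique only for the purpose of demonstration in this tutorial.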
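If somewhat steadier numbers are wanted, one option is Python's monotonic `time.perf_counter` clock combined with a few untimed warm-up calls and several timed repeats. This is an assumption on our part and not something the profiler above does; `time_callable` is a hypothetical helper, and the lambda is only a placeholder workload.

```python
import statistics
import time


def time_callable(fn, warmup=2, repeats=5):
    """Hypothetical helper: return per-call durations in seconds."""
    for _ in range(warmup):
        fn()  # warm-up calls are executed but not timed
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


# Placeholder workload; any zero-argument callable can be timed this way.
durations = time_callable(lambda: sum(i * i for i in range(10_000)))
print(f"median: {statistics.median(durations):.6f} s over {len(durations)} runs")
```

Reporting the median of several repeats reduces the influence of one-off stalls, though the result is still only an approximation.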
Investigating the Performance of ResNet18
We can now use ProfilingInterpreter to inspect the performance characteristics of our ResNet18 model:
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
Op type Op Average runtime (s) Pct total runtime
------------- --------------------- --------------------- -------------------
call_module maxpool 0.00571299 8.807
call_module conv1 0.0048542 7.48312
call_module layer4_0_conv2 0.00380135 5.86006
call_module layer1_1_conv2 0.00335789 5.17644
call_module layer4_1_conv1 0.0030942 4.76994
call_module layer4_1_conv2 0.00308824 4.76075
call_module layer1_0_conv1 0.00308061 4.74899
call_module layer1_0_conv2 0.0030272 4.66666
call_module layer2_1_conv2 0.00298619 4.60344
call_module layer1_1_conv1 0.00297284 4.58286
call_module layer3_1_conv2 0.00284076 4.37924
call_module layer2_1_conv1 0.00279307 4.30573
call_module layer3_1_conv1 0.00268817 4.14402
call_module layer3_0_conv2 0.00266528 4.10873
call_module layer2_0_conv2 0.00240898 3.71363
call_module layer4_0_conv1 0.00221109 3.40857
call_module layer3_0_conv1 0.00173402 2.67312
call_module bn1 0.00161076 2.4831
call_module layer2_0_conv1 0.00146675 2.26111
call_module layer2_0_downsample_0 0.00102091 1.57381
call_module layer4_0_downsample_0 0.000473499 0.729935
call_module layer3_0_downsample_0 0.000461102 0.710823
call_function add_1 0.000446558 0.688403
call_function add 0.000440359 0.678847
call_module relu 0.000230074 0.354676
call_module layer1_1_bn2 0.000178814 0.275655
call_module fc 0.000176668 0.272347
call_module layer1_0_bn1 0.000176191 0.271612
call_module layer1_0_bn2 0.000160694 0.247722
call_module layer1_1_bn1 0.000154495 0.238166
call_module avgpool 0.000152826 0.235593
call_module layer2_0_downsample_1 0.000143051 0.220524
call_module layer2_1_bn2 0.00014019 0.216114
call_module layer3_1_bn1 0.000132799 0.20472
call_module layer2_0_bn2 0.000131607 0.202882
call_module layer3_1_bn2 0.0001297 0.199942
call_module layer2_1_bn1 0.000128269 0.197737
call_module layer2_0_bn1 0.000127554 0.196634
call_module layer4_0_bn2 0.000125647 0.193694
call_module layer3_0_bn1 0.000121593 0.187446
call_module layer3_0_bn2 0.000118732 0.183035
call_module layer3_0_downsample_1 0.00011754 0.181197
call_module layer4_1_bn1 0.000117064 0.180462
call_module layer4_0_downsample_1 0.000116587 0.179727
call_module layer4_0_bn1 0.000115871 0.178625
call_module layer4_1_bn2 0.000115871 0.178625
call_module layer1_0_relu 0.000108719 0.167598
call_module layer1_1_relu_1 0.000100374 0.154734
call_module layer1_0_relu_1 9.84669e-05 0.151794
call_module layer1_1_relu 9.48906e-05 0.146281
call_function add_2 8.4877e-05 0.130844
call_function add_3 8.29697e-05 0.127904
call_module layer4_1_relu 8.27312e-05 0.127536
call_module layer4_0_relu 8.17776e-05 0.126066
call_module layer2_1_relu 8.05855e-05 0.124229
call_module layer2_0_relu 7.9155e-05 0.122023
call_module layer2_0_relu_1 7.86781e-05 0.121288
call_module layer2_1_relu_1 7.84397e-05 0.120921
call_module layer3_0_relu 7.70092e-05 0.118716
call_module layer3_1_relu 7.51019e-05 0.115775
call_function add_6 7.51019e-05 0.115775
call_module layer3_0_relu_1 7.22408e-05 0.111365
call_function add_5 7.12872e-05 0.109895
call_module layer4_0_relu_1 7.12872e-05 0.109895
call_module layer3_1_relu_1 7.03335e-05 0.108424
call_function add_7 6.98566e-05 0.107689
call_module layer4_1_relu_1 6.96182e-05 0.107322
call_function add_4 6.93798e-05 0.106954
call_function flatten 4.8399e-05 0.0746107
placeholder x 2.6226e-05 0.0404294
output output 1.83582e-05 0.0283006
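There are two things we should call out here:
MaxPool2d takes up the most time. This is a known issue: https://github.com/pytorch/pytorch/issues/51393
BatchNorm2d also takes up significant time. We can continue this line of thinking and optimize this in the Conv-BN Fusion with FX tutorial.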
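Observations like these are about module types rather than individual nodes, so it can help to aggregate per-node timings by type. The sketch below shows one way to do that on a toy model; `NodeTimer`, the three-layer Sequential, and `seconds_by_type` are hypothetical names used for illustration and are separate from the ProfilingInterpreter defined earlier.

```python
import time
from collections import defaultdict

import torch
import torch.fx


class NodeTimer(torch.fx.Interpreter):
    """Hypothetical helper: records wall-clock seconds for every Node."""

    def __init__(self, gm):
        super().__init__(gm)
        self.node_seconds = defaultdict(list)

    def run_node(self, n):
        start = time.perf_counter()
        result = super().run_node(n)
        self.node_seconds[n].append(time.perf_counter() - start)
        return result


model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()
gm = torch.fx.symbolic_trace(model)
timer = NodeTimer(gm)
for _ in range(3):
    timer.run(torch.randn(1, 3, 16, 16))

# Sum the mean runtime of each call_module Node under its module class name.
modules = dict(gm.named_modules())
seconds_by_type = defaultdict(float)
for node, samples in timer.node_seconds.items():
    if node.op == "call_module":
        type_name = type(modules[node.target]).__name__
        seconds_by_type[type_name] += sum(samples) / len(samples)

for type_name, seconds in sorted(seconds_by_type.items(), key=lambda kv: -kv[1]):
    print(f"{type_name}: {seconds:.6f} s")
```

Grouping by type gives one line per kind of layer, which makes it easier to judge how much time a whole category such as BatchNorm2d accounts for.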
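Conclusion
As we can see, using FX we can easily capture PyTorch programs (even ones we don't have the source code for!) in a machine-interpretable format and use that for analysis, such as the performance analysis we've done here. FX opens up an exciting world of possibilities for working with PyTorch programs.
Finally, since FX is still in beta, we would be happy to hear any feedback you have about using it. Please feel free to use the PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker (https://github.com/pytorch/pytorch/issues) to provide any feedback you might have.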
Total running time of the script: (0 minutes 0.425 seconds)