目录

torchvision.models.feature_extraction

特征提取实用程序让我们可以利用我们的模型来访问中间体 我们输入的转换。这可能对各种 计算机视觉中的应用。仅举几个例子:

  • 可视化特征图。

  • 提取特征以计算面部等任务的图像描述符 识别、复制检测或图像检索。

  • 将所选特征传递到下游子网络进行端到端训练 牢记特定任务。例如,传递特征的层次结构 到具有对象检测头的特征金字塔网络。

Torchvision 为此目的提供了。 它的工作原理大致遵循以下步骤:

  1. 以符号方式跟踪模型以获得 它如何逐步转换输入。

  2. 将用户选择的图形节点设置为输出。

  3. 删除所有冗余节点(输出节点下游的任何内容)。

  4. 从生成的图形生成 python 代码并将其捆绑到 PyTorch 模块与图形本身一起创建。


torch.fx 文档提供了上述过程的更通用和详细的解释,并且 符号描摹的内部运作。

关于节点名称

为了指定哪些节点应该是要提取的输出节点 功能,您应该熟悉此处使用的节点命名约定 (这与 中使用的略有不同)。节点名称为 指定为从顶层遍历模块层次结构的单独路径 module down 到 leaf operation 或 leaf module。例如,在 ResNet-50 中,表示第 4 个块的第 2 个块的 ReLU 输出 层。以下是一些需要记住的细节:torch.fx."layer4.2.relu"ResNet

  • 指定节点名称时,您可以 提供节点名称的截断版本作为快捷方式。要了解如何执行此操作 有效,请尝试创建 ResNet-50 模型并使用 和 请注意,与 相关的最后一个节点是 。可以指定为 return node,或者按照惯例,它指的是最后一个节点 (按执行顺序) 的 .train_nodes, _ = get_graph_node_names(model) print(train_nodes)layer4"layer4.2.relu_2""layer4.2.relu_2""layer4"layer4

  • 如果某个模块或操作重复多次,则节点名称将获取 一个额外的后缀来消除歧义。例如,也许 addition () 操作在同一方法中使用 3 次。然后是 , , 。计数器是 在直接父级的范围内维护。所以在 ResNet-50 中 a 和 a .因为添加 操作位于不同的块中,则不需要 postfix 来 消除歧义。_{int}+forward"path.to.module.add""path.to.module.add_1""path.to.module.add_2""layer4.1.add""layer4.2.add"

示例

以下是我们如何为 MaskRCNN 提取特征的示例:

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
train_nodes, eval_nodes = get_graph_node_names(resnet50())

# The lists returned, are the names of all the graph nodes (in order of
# execution) for the input model traced in train mode and in eval mode
# respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# for this example. But if the model contains control flow that's dependent
# on the training mode, they may be different.

# To specify the nodes you want to extract, you could select the final node
# that appears in each of the main layers:
return_nodes = {
    # node_name: user-specified key for output dict
    'layer1.2.relu_2': 'layer1',
    'layer2.3.relu_2': 'layer2',
    'layer3.5.relu_2': 'layer3',
    'layer4.2.relu_2': 'layer4',
}

# But `create_feature_extractor` can also accept truncated node specifications
# like "layer1", as it will just pick the last node that's a descendent of
# of the specification. (Tip: be careful with this, especially when a layer
# has multiple outputs. It's not always guaranteed that the last operation
# performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.)
return_nodes = {
    'layer1': 'layer1',
    'layer2': 'layer2',
    'layer3': 'layer3',
    'layer4': 'layer4',
}

# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
#     'layer1': output of layer 1,
#     'layer2': output of layer 2,
#     'layer3': output of layer 3,
#     'layer4': output of layer 4,
# }
create_feature_extractor(m, return_nodes=return_nodes)

# Let's put all that together to wrap resnet50 with MaskRCNN

# MaskRCNN requires a backbone with an attached FPN
class Resnet50WithFPN(torch.nn.Module):
    def __init__(self):
        super(Resnet50WithFPN, self).__init__()
        # Get a resnet50 backbone
        m = resnet50()
        # Extract 4 main layers (note: MaskRCNN needs this particular name
        # mapping for return nodes)
        self.body = create_feature_extractor(
            m, return_nodes={f'layer{k}': str(v)
                             for v, k in enumerate([1, 2, 3, 4])})
        # Dry run to get number of channels for FPN
        inp = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            out = self.body(inp)
        in_channels_list = [o.shape[1] for o in out.values()]
        # Build FPN
        self.out_channels = 256
        self.fpn = FeaturePyramidNetwork(
            in_channels_list, out_channels=self.out_channels,
            extra_blocks=LastLevelMaxPool())

    def forward(self, x):
        x = self.body(x)
        x = self.fpn(x)
        return x


# Now we can build our model!
model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()

API 参考

torchvision.models.feature_extraction.create_feature_extractor(模型 torch.nn.modules.module.Module, return_nodes: 可选[Union[List[str] Dict[str str]]] = train_return_nodes可选[union[list[str] dict[str str]]] = eval_return_nodes:可选[union[列表[str] Dict[str str]]] = tracer_kwargs: Dict = {}suppress_diff_warning:bool = Falsetorch.fx.graph_module。GraphModule [来源]

创建一个新的图形模块,该模块从给定的 model 作为字典,用户指定的键作为字符串,请求的 outputs作为值。这是通过重写 模型以返回所需的节点作为输出。所有未使用的节点 将与其相应的参数一起删除。

所需的输出节点必须指定为单独的 路径 从顶层模块向下遍历模块层次结构到叶子 operation 或 leaf 模块。有关节点命名约定的更多详细信息 此处使用,请参阅文档中相关小标题.

并非所有模型都是可追踪的 FX,尽管通过一些按摩,它们可以 被要求合作。以下是(并非详尽的)提示列表:

  • 如果您不需要追踪特定的、有问题的 sub-module,通过传递一个 list of 将其转换为“叶子模块”(见下面的示例)。 它不会被追踪,而是结果图 保存对该模块的 forward 方法的引用。leaf_modulestracer_kwargs

  • 同样,您可以通过传递 列表作为 之一(请参阅 示例如下)。autowrap_functionstracer_kwargs

  • 一些内置的 Python 函数可能会出现问题。例如,将在跟踪期间引发错误。您可以将它们包装在 own 函数,然后将其作为 这。intautowrap_functionstracer_kwargs

有关 FX 的更多信息,请参阅 torch.fx 文档

参数
  • 模型NN.Module) – 我们将提取特征的模型

  • return_nodeslistdict可选) – a 或 a 包含名称 (或部分名称 - 请参阅上面的注释) 将返回其激活的节点。如果是 a 中,键是节点名称和值 是用户为图形模块返回的 字典。如果它是 ,则将其视为映射 节点规范字符串直接输出名称。在这种情况下 that 和 被指定, 这不应该被指定。ListDictDictListDicttrain_return_nodeseval_return_nodes

  • train_return_nodeslistdict可选) – 类似于 .如果返回节点 的 train 模式与 EVAL 模式不同。 如果指定了此项,则还必须指定, ,不应指定。return_nodeseval_return_nodesreturn_nodes

  • eval_return_nodeslistdict可选) – 类似于 .如果返回节点 的 train 模式与 EVAL 模式不同。 如果指定了此项,则还必须指定, 和 return_nodes 不应指定。return_nodestrain_return_nodes

  • tracer_kwargsdictoptional) – 一个键工作参数字典(将它们传递给它的父类 torch.fx.Tracer)。NodePathTracer

  • suppress_diff_warningbooloptional) – 是否禁止显示警告 当 Train 和 EVAL 版本之间存在差异时 图表。默认为 False。

例子:

>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> model = create_feature_extractor(
>>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>>     [('feat1', torch.Size([1, 64, 56, 56])),
>>>      ('feat2', torch.Size([1, 256, 14, 14]))]

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
>>>     # This would raise a TypeError if traced through
>>>     return int(x)
>>>
>>> class LeafModule(torch.nn.Module):
>>>     def forward(self, x):
>>>         # This would raise a TypeError if traced through
>>>         int(x.shape[0])
>>>         return torch.nn.functional.relu(x + 4)
>>>
>>> class MyModule(torch.nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.conv = torch.nn.Conv2d(3, 1, 3)
>>>         self.leaf_module = LeafModule()
>>>
>>>     def forward(self, x):
>>>         leaf_function(x.shape[0])
>>>         x = self.conv(x)
>>>         return self.leaf_module(x)
>>>
>>> model = create_feature_extractor(
>>>     MyModule(), return_nodes=['leaf_module'],
>>>     tracer_kwargs={'leaf_modules': [LeafModule],
>>>                    'autowrap_functions': [leaf_function]})
torchvision.models.feature_extraction.get_graph_node_names(模型torch.nn.modules.module.Module,tracer_kwargsdict = {}suppress_diff_warning:bool = False元组[列表[str] 列表[str]][来源]

Dev 实用程序按执行顺序返回节点名称。请参阅节点上的注释 names 下的 .用于查看哪个节点 名称可用于特征提取。有两个原因 无法轻松地直接从模型的代码中读取节点名称:

  1. 并非所有子模块都被追踪。来自所有 属于这一类。torch.nn

  2. 表示重复应用同一操作的节点 或 leaf 模块获取 postfix。_{counter}

该模型被跟踪两次:一次在 train 模式,一次在 eval 模式。双 返回节点名称集。

有关此处使用的节点命名约定的更多详细信息,请参阅文档中相关子标题

参数
  • 模型NN.Module) – 我们要打印节点名称的模型

  • tracer_kwargsdict可选) –

    一个 keywork 参数的字典(它们最终被传递给 torch.fx.Tracer)。NodePathTracer

  • suppress_diff_warningbooloptional) – 是否禁止显示警告 当 Train 和 EVAL 版本之间存在差异时 图表。默认为 False。

返回

跟踪 中的 Model 的节点名称列表 train 模式,另一个来自 eval 模式下跟踪模型。

返回类型

tuple列表列表)

例子:

>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源