目录

优化 Vision Transformer 模型以进行部署

创建时间: 2021年3月15日 |上次更新时间:2024 年 1 月 19 日 |上次验证: Nov 05, 2024

Jeff Tang吉塔·乔汉

Vision Transformer 模型应用了最前沿的基于注意力的模型 transformer 模型,在 自然语言处理 中引入以实现 各种最先进的 (SOTA) 结果,到计算机视觉 任务。Facebook Data-efficient Image Transformers DeiT 是在 ImageNet 上训练的图像 Vision Transformer 模型 分类。

在本教程中,我们将首先介绍什么是 DeiT 以及如何使用它, 然后完成脚本编写、量化、优化、 以及在 iOS 和 Android 应用程序中使用模型。我们还将比较 量化、优化和非量化、非优化的性能 模型,并展示应用量化和优化的好处 到模型。

什么是 DeiT

卷积神经网络 (CNN) 一直是图像的主要模型 分类,但 CNN 通常 需要数亿张图像进行训练才能实现 SOTA 结果。DeiT 是一个视觉转换器模型,它需要的要少得多 用于培训的数据和计算资源,以与领先的 CNN 执行图像分类,这可以通过两个 DeiT 的关键组成部分:

  • 在更大的数据集上模拟训练的数据增强;

  • 允许变压器网络学习的原生蒸馏 a CNN 的输出。

DeiT 表明 Transformers 可以成功地应用于计算机 Vision 任务,对数据和资源的访问权限有限。了解更多 有关 DeiT 的详细信息,请参阅 repopaper

使用 DeiT 对图像进行分类

请关注 DeiT 存储库,了解有关如何 使用 DeiT 对图像进行分类,或者为了快速测试,请先安装 所需软件包:README.md

pip install torch torchvision timm pandas requests

要在 Google Colab 中运行,请通过运行以下命令来安装依赖项:

!pip install timm pandas requests

然后运行以下脚本:

from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0


model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.5.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/usr/local/lib/python3.10/dist-packages/timm/models/registry.py:4: FutureWarning:

Importing from timm.models.registry is deprecated, please import via timm.models

/usr/local/lib/python3.10/dist-packages/timm/models/layers/__init__.py:48: FutureWarning:

Importing from timm.models.layers is deprecated, please import via timm.layers

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:

Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:

Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:

Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:

Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:

Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:

Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:

Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:

Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth

  0%|          | 0.00/330M [00:00<?, ?B/s]
  6%|6         | 20.5M/330M [00:00<00:01, 214MB/s]
 13%|#2        | 41.8M/330M [00:00<00:01, 219MB/s]
 19%|#9        | 63.0M/330M [00:00<00:01, 220MB/s]
 26%|##5       | 84.2M/330M [00:00<00:01, 221MB/s]
 32%|###1      | 106M/330M [00:00<00:01, 221MB/s]
 38%|###8      | 127M/330M [00:00<00:00, 221MB/s]
 45%|####4     | 148M/330M [00:00<00:00, 222MB/s]
 51%|#####1    | 169M/330M [00:00<00:00, 221MB/s]
 58%|#####7    | 190M/330M [00:00<00:00, 222MB/s]
 64%|######4   | 212M/330M [00:01<00:00, 222MB/s]
 71%|#######   | 233M/330M [00:01<00:00, 222MB/s]
 77%|#######6  | 254M/330M [00:01<00:00, 222MB/s]
 83%|########3 | 275M/330M [00:01<00:00, 222MB/s]
 90%|########9 | 297M/330M [00:01<00:00, 222MB/s]
 96%|#########6| 318M/330M [00:01<00:00, 222MB/s]
100%|##########| 330M/330M [00:01<00:00, 221MB/s]
269

输出应为 269,根据类的 ImageNet 列表 索引到标签文件,映射到 。timber wolf, grey wolf, gray wolf, Canis lupus

现在我们已经验证了我们可以使用 DeiT 模型进行分类 图像,让我们看看如何修改模型,使其可以在 iOS 上运行,并且 Android 应用程序。

脚本设计

要在移动设备上使用该模型,我们首先需要编写 型。请参阅 Script and Optimize 配方中的 快速概述。运行下面的代码,将 previous step 添加到可在移动设备上运行的 TorchScript 格式。

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main

大小约为 346MB 的脚本化模型文件为 生成。fbdeit_scripted.pt

量化 DeiT

要显著减小训练模型的大小,而 保持推理精度大致相同,量化可以是 应用于模型。得益于 DeiT 中使用的 transformer 模型,我们 可以很容易地将动态量化应用于模型,因为动态 量化最适合 LSTM 和 Transformer 模型(有关更多详细信息,请参阅此处)。

现在运行以下代码:

# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/usr/local/lib/python3.10/dist-packages/torch/ao/quantization/observer.py:229: UserWarning:

Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.

这将生成模型的脚本化和量化版本 ,大小约为 89MB,减少了 74% 非量化模型大小为 346MB!fbdeit_quantized_scripted.pt

您可以使用 生成相同的 推理结果:scripted_quantized_model

out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed
269

优化 DeiT

使用量化和脚本化 model 就是优化它:

from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")

生成的文件具有大约 与量化、脚本化但未优化的模型大小相同。这 推理结果保持不变。fbdeit_optimized_scripted_quantized.pt

out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed
269

使用 Lite Interpreter

要了解模型大小减小和推理的速度有多快,Lite Interpreter 可以产生,让我们创建模型的 lite 版本。

optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")

尽管 lite 模型大小与非 lite 版本相当,但当 在移动设备上运行 Lite 版本,推理速度是预期的。

比较推理速度

要查看这四个模型的推理速度有何不同,请使用 原始模型、脚本化模型、量化和脚本化模型、 优化量化和脚本化模型 - 运行以下代码:

with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 171.81ms
scripted model: 107.33ms
scripted & quantized model: 128.83ms
scripted & quantized & optimized model: 146.44ms
lite model: 149.71ms

在 Google Colab 上运行的结果是:

original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms

以下结果总结了每个模型所花费的推理时间 以及每个模型相对于原始模型的减少百分比 型。

import pandas as pd
import numpy as np

df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)

print(df)

"""
        Model                             Inference Time    Reduction
0   original model                             1236.69ms           0%
1   scripted model                             1226.72ms        0.81%
2   scripted & quantized model                  593.19ms       52.03%
3   scripted & quantized & optimized model      598.01ms       51.64%
4   lite model                                  600.72ms       51.43%
"""
                                    Model  ... Reduction
0                          original model  ...        0%
1                          scripted model  ...    37.53%
2              scripted & quantized model  ...    25.02%
3  scripted & quantized & optimized model  ...    14.77%
4                              lite model  ...    12.87%

[5 rows x 3 columns]

'\n        Model                             Inference Time    Reduction\n0\toriginal model                             1236.69ms           0%\n1\tscripted model                             1226.72ms        0.81%\n2\tscripted & quantized model                  593.19ms       52.03%\n3\tscripted & quantized & optimized model      598.01ms       51.64%\n4\tlite model                                  600.72ms       51.43%\n'

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源