目录

优化视觉变换器模型以进行部署

创建日期: 2021年3月15日 | 最后更新日期: 2024年1月19日 | 最后验证日期: 2024年11月5日

唐杰夫, 乔塔·乔汉

Vision Transformer 模型应用了自然语言处理中引入的基于注意力机制的前沿 transformer 模型,以实现各种最先进的(SOTA)成果,应用于计算机视觉任务。Facebook 数据高效图像变压器 DeiT 是一个在 ImageNet 上训练用于图像分类的 Vision Transformer 模型。

在本教程中,我们将首先介绍DeiT是什么以及如何使用它,然后逐步讲解脚本编写、量化、优化以及在iOS和Android应用中使用模型的完整步骤。我们还将比较量化、优化和未量化、未优化模型的性能,并展示在每一步应用量化和优化的好处。

什么是DeiT

自2012年深度学习兴起以来,卷积神经网络(CNNs)一直是图像分类的主要模型,但CNNs通常需要数百万张图像进行训练,才能达到最佳表现。DeiT是一种视觉变换器模型,它在训练时所需的图像数据和计算资源要少得多,从而能够与领先的CNNs在执行图像分类任务方面相竞争,这得益于DeiT的两个关键组成部分:

  • 模拟在更大数据集上进行训练的数据增强;

  • 允许变压器网络从CNN的输出中学习的原生蒸馏技术。

DeiT 显示,即使在有限的数据和资源访问下,Transformer 也可以成功应用于计算机视觉任务。有关 DeiT 的更多信息,请参阅 repo论文

使用DeiT进行图像分类

Follow the README.md 在DeiT仓库中获取有关如何使用DeiT进行图像分类的详细信息,或者进行快速测试,请先安装所需的包:

pip install torch torchvision timm pandas requests

要在Google Colab中运行,请运行以下命令安装依赖项:

!pip install timm pandas requests

然后运行以下脚本:

from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0


model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
2.5.0+cu124
Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /var/lib/ci-user/.cache/torch/hub/main.zip
/usr/local/lib/python3.10/dist-packages/timm/models/registry.py:4: FutureWarning:

Importing from timm.models.registry is deprecated, please import via timm.models

/usr/local/lib/python3.10/dist-packages/timm/models/layers/__init__.py:48: FutureWarning:

Importing from timm.models.layers is deprecated, please import via timm.layers

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:63: UserWarning:

Overwriting deit_tiny_patch16_224 in registry with models.deit_tiny_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:78: UserWarning:

Overwriting deit_small_patch16_224 in registry with models.deit_small_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:93: UserWarning:

Overwriting deit_base_patch16_224 in registry with models.deit_base_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:108: UserWarning:

Overwriting deit_tiny_distilled_patch16_224 in registry with models.deit_tiny_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:123: UserWarning:

Overwriting deit_small_distilled_patch16_224 in registry with models.deit_small_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:138: UserWarning:

Overwriting deit_base_distilled_patch16_224 in registry with models.deit_base_distilled_patch16_224. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:153: UserWarning:

Overwriting deit_base_patch16_384 in registry with models.deit_base_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

/var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main/models.py:168: UserWarning:

Overwriting deit_base_distilled_patch16_384 in registry with models.deit_base_distilled_patch16_384. This is because the name being registered conflicts with an existing name. Please check if this is not expected.

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth

  0%|          | 0.00/330M [00:00<?, ?B/s]
  6%|6         | 20.5M/330M [00:00<00:01, 214MB/s]
 13%|#2        | 41.8M/330M [00:00<00:01, 219MB/s]
 19%|#9        | 63.0M/330M [00:00<00:01, 220MB/s]
 26%|##5       | 84.2M/330M [00:00<00:01, 221MB/s]
 32%|###1      | 106M/330M [00:00<00:01, 221MB/s]
 38%|###8      | 127M/330M [00:00<00:00, 221MB/s]
 45%|####4     | 148M/330M [00:00<00:00, 222MB/s]
 51%|#####1    | 169M/330M [00:00<00:00, 221MB/s]
 58%|#####7    | 190M/330M [00:00<00:00, 222MB/s]
 64%|######4   | 212M/330M [00:01<00:00, 222MB/s]
 71%|#######   | 233M/330M [00:01<00:00, 222MB/s]
 77%|#######6  | 254M/330M [00:01<00:00, 222MB/s]
 83%|########3 | 275M/330M [00:01<00:00, 222MB/s]
 90%|########9 | 297M/330M [00:01<00:00, 222MB/s]
 96%|#########6| 318M/330M [00:01<00:00, 222MB/s]
100%|##########| 330M/330M [00:01<00:00, 221MB/s]
269

输出应该是269,根据ImageNet的类别索引到标签文件,映射到timber wolf, grey wolf, gray wolf, Canis lupus

现在我们已经验证了可以使用DeiT模型对图像进行分类,让我们看看如何修改该模型以便在iOS和Android应用中运行。

Scripting DeiT

要将模型用于移动设备,我们首先需要对模型进行脚本化。 请参阅脚本和优化指南以快速了解。运行以下代码将上一步中使用的DeiT模型转换为可在移动设备上运行的TorchScript格式。

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
Using cache found in /var/lib/ci-user/.cache/torch/hub/facebookresearch_deit_main

脚本化的模型文件 fbdeit_scripted.pt 大小约为 346MB 是生成的。

DeiT的量化

为了在保持推理准确性大致相同的同时显著减少训练模型的大小,可以对模型应用量化。由于DeiT中使用的transformer模型,我们可以轻松地对模型应用动态量化,因为动态量化最适合LSTM和transformer模型(详见这里以获取更多详情)。

现在运行以下代码:

# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and ``qnnpack`` for mobile inference.
backend = "x86" # replaced with ``qnnpack`` causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
/usr/local/lib/python3.10/dist-packages/torch/ao/quantization/observer.py:229: UserWarning:

Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.

这生成了模型fbdeit_quantized_scripted.pt的脚本化和量化版本,大小约为89MB,比非量化模型的346MB大小减少了约74%!

您可以使用 scripted_quantized_model 生成相同的推理结果:

out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed
269

优化DeiT

在将量化和脚本化的模型用于移动设备之前,需要对其进行优化:

from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")

生成的 fbdeit_optimized_scripted_quantized.pt 文件大小与量化、转脚本但未优化的模型大致相同。推理结果保持不变。

out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed
269

使用Lite解释器

让我们创建 Lite 版本的模型,以查看 Lite 解释器可以带来多少模型大小缩减和推理速度提升。

optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")

虽然轻量级模型的大小与非轻量级版本相当,但在移动设备上运行轻量级版本时,预期会提高推理速度。

比较推理速度

要查看四种模型——原始模型、脚本化模型、量化并脚本化模型以及优化的量化并脚本化模型——的推理速度差异,请运行以下代码:

with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))
original model: 171.81ms
scripted model: 107.33ms
scripted & quantized model: 128.83ms
scripted & quantized & optimized model: 146.44ms
lite model: 149.71ms

在Google Colab上运行的结果是:

original model: 1236.69ms
scripted model: 1226.72ms
scripted & quantized model: 593.19ms
scripted & quantized & optimized model: 598.01ms
lite model: 600.72ms

以下结果总结了每个模型的推理时间以及相对于原始模型的百分比减少量。

import pandas as pd
import numpy as np

df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
    ["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
    ["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
    ["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
     "{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
    columns=['Inference Time', 'Reduction'])], axis=1)

print(df)

"""
        Model                             Inference Time    Reduction
0   original model                             1236.69ms           0%
1   scripted model                             1226.72ms        0.81%
2   scripted & quantized model                  593.19ms       52.03%
3   scripted & quantized & optimized model      598.01ms       51.64%
4   lite model                                  600.72ms       51.43%
"""
                                    Model  ... Reduction
0                          original model  ...        0%
1                          scripted model  ...    37.53%
2              scripted & quantized model  ...    25.02%
3  scripted & quantized & optimized model  ...    14.77%
4                              lite model  ...    12.87%

[5 rows x 3 columns]

'\n        Model                             Inference Time    Reduction\n0\toriginal model                             1236.69ms           0%\n1\tscripted model                             1226.72ms        0.81%\n2\tscripted & quantized model                  593.19ms       52.03%\n3\tscripted & quantized & optimized model      598.01ms       51.64%\n4\tlite model                                  600.72ms       51.43%\n'

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源