目录

使用 Flask 通过 REST API 在 Python 中部署 PyTorch

创建时间: Jul 03, 2019 |上次更新时间:2024 年 1 月 19 日 |上次验证: Nov 05, 2024

作者Avinash Sajjanshetty

在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开一个 用于模型推理的 REST API。具体而言,我们将部署一个预训练的 检测图像的 DenseNet 121 模型。

提示

这里使用的所有代码都是在 MIT 许可下发布的,并且可以在 Github 上找到。

这是有关部署 PyTorch 模型的一系列教程中的第一个 在生产中。以这种方式使用 Flask 是迄今为止最简单的开始方式 提供 PyTorch 模型,但它不适用于用例 具有很高的性能要求。为此:

API 定义

我们将首先定义我们的 API 端点,请求和响应类型。我们 API 端点将位于 接受 HTTP POST 请求,其中包含包含图像的参数。响应将为 JSON response 包含预测:/predictfile

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

依赖

通过运行以下命令安装所需的依赖项:

pip install Flask==2.0.1 torchvision==0.10.0

简单的 Web 服务器

以下是一个简单的 Web 服务器,摘自 Flask 的文档

from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello():
    return 'Hello World!'

我们还将更改响应类型,以便它返回 JSON 响应 包含 ImageNet 类 id 和 name。更新后的文件将 be now:app.py

from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

推理

在接下来的部分中,我们将重点介绍如何编写推理代码。这将 涉及两个部分,一个部分是我们准备图像以便可以馈送它 到 DenseNet,接下来,我们将编写代码来获得实际的预测 从模型中。

准备映像

DenseNet 模型要求图像为大小为 3 通道 RGB 图像 224 x 224 像素。我们还将使用所需的平均值对图像张量进行归一化 和标准差值。您可以在此处阅读更多相关信息。

我们将使用 from 库并构建一个 transform 管道,它根据需要转换我们的图像。你 可以在此处阅读有关 Transforms 的更多信息。transformstorchvision

import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

上述方法以字节为单位获取图像数据,应用一系列转换 并返回一个 Tensor。要测试上述方法,请在 字节模式(首先替换 ../_static/img/sample_file.jpeg 替换为实际的 路径),并查看是否返回了张量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)

预测

现在将使用预先训练的 DenseNet 121 模型来预测图像类。我们 将使用 library 中的一个,加载模型并获取一个 推理。虽然在此示例中我们将使用预训练模型,但您可以 对您自己的模型使用相同的方法。了解有关加载 models 中。torchvision

from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

张量将包含预测类 id 的索引。 但是,我们需要一个人类可读的类名。为此,我们需要一个类 ID 进行名称映射。下载此文件并记住您的保存位置(或者,如果您 正在按照本教程中的确切步骤进行操作,请将其保存在 tutorials/_static) 中。此文件包含 ImageNet 类 id 到 ImageNet 类名。我们将加载此 JSON 文件并获取 预测指数。y_hatimagenet_class_index.json

import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

在使用字典之前,首先我们将 tensor 值转换为字符串值,因为字典中的键是字符串。 我们将测试上述方法:imagenet_class_indeximagenet_class_index

with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

您应该会收到如下响应:

['n02124075', 'Egyptian_cat']

数组中的第一项是 ImageNet 类 id,第二项是人类 可读名称。

将模型集成到我们的 API 服务器中

在最后一部分中,我们将把我们的模型添加到我们的 Flask API 服务器中。因为 我们的 API 服务器应该获取一个图像文件,我们将更新我们的方法以从请求中读取文件:predict

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()
FLASK_ENV=development FLASK_APP=app.py flask run

library 向我们的应用程序发送 POST 请求:

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

现在,打印 resp.json() 将显示以下内容:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}

我们编写的服务器非常简单,可能无法完成所有工作 您需要用于您的生产应用程序。所以,这里有一些事情 可以做得更好:

  • 终端节点假定始终存在图像文件 在请求中。这可能不适用于所有请求。我们的用户可以 使用不同的参数发送图像或根本不发送任何图像。/predict

  • 用户也可以发送非图像类型的文件。由于我们不处理 错误,这将破坏我们的服务器。添加显式错误处理 path 将引发异常,这将使我们能够更好地处理 错误的输入

  • 尽管模型可以识别大量图像类别, 它可能无法识别所有图像。增强实施 处理模型无法识别图像中任何内容的情况。

  • 我们在开发模式下运行 Flask 服务器,这并不适合 在生产环境中部署。您可以查看此教程,了解如何在生产环境中部署 Flask 服务器。

  • 您还可以通过创建包含包含图像和 显示预测。查看类似项目的演示及其源代码

  • 在本教程中,我们只展示了如何构建一个可以返回 一次一张图像。我们可以修改我们的服务,以便能够返回 一次多张图片。此外,service-streamer 库会自动将服务请求排队,并将它们采样为小批量 可以输入到您的模型中。您可以查看此教程

  • 最后,我们鼓励您查看有关部署 PyTorch 模型的其他教程 linked-to 的 URL 中。

脚本总运行时间:(0 分 0.000 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源