注意
单击此处下载完整的示例代码
使用 Flask 通过 REST API 在 Python 中部署 PyTorch¶
创建时间: Jul 03, 2019 |上次更新时间:2024 年 1 月 19 日 |上次验证: Nov 05, 2024
在本教程中,我们将使用 Flask 部署 PyTorch 模型,并公开一个 用于模型推理的 REST API。具体而言,我们将部署一个预训练的 检测图像的 DenseNet 121 模型。
提示
这里使用的所有代码都是在 MIT 许可下发布的,并且可以在 Github 上找到。
这是有关部署 PyTorch 模型的一系列教程中的第一个 在生产中。以这种方式使用 Flask 是迄今为止最简单的开始方式 提供 PyTorch 模型,但它不适用于用例 具有很高的性能要求。为此:
如果您已经熟悉 TorchScript,则可以直接跳转到我们的在 C++ 中加载 TorchScript 模型教程。
如果您首先需要复习 TorchScript,请查看我们的 TorchScript 简介教程。
API 定义¶
我们将首先定义我们的 API 端点,请求和响应类型。我们
API 端点将位于 接受 HTTP POST 请求,其中包含包含图像的参数。响应将为 JSON
response 包含预测:/predict
file
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
简单的 Web 服务器¶
以下是一个简单的 Web 服务器,摘自 Flask 的文档
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
我们还将更改响应类型,以便它返回 JSON 响应
包含 ImageNet 类 id 和 name。更新后的文件将
be now:app.py
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
推理¶
在接下来的部分中,我们将重点介绍如何编写推理代码。这将 涉及两个部分,一个部分是我们准备图像以便可以馈送它 到 DenseNet,接下来,我们将编写代码来获得实际的预测 从模型中。
准备映像¶
DenseNet 模型要求图像为大小为 3 通道 RGB 图像 224 x 224 像素。我们还将使用所需的平均值对图像张量进行归一化 和标准差值。您可以在此处阅读更多相关信息。
我们将使用 from 库并构建一个
transform 管道,它根据需要转换我们的图像。你
可以在此处阅读有关 Transforms 的更多信息。transforms
torchvision
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
上述方法以字节为单位获取图像数据,应用一系列转换 并返回一个 Tensor。要测试上述方法,请在 字节模式(首先替换 ../_static/img/sample_file.jpeg 替换为实际的 路径),并查看是否返回了张量:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)
预测¶
现在将使用预先训练的 DenseNet 121 模型来预测图像类。我们
将使用 library 中的一个,加载模型并获取一个
推理。虽然在此示例中我们将使用预训练模型,但您可以
对您自己的模型使用相同的方法。了解有关加载
models 中。torchvision
from torchvision import models
# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
return y_hat
张量将包含预测类 id 的索引。
但是,我们需要一个人类可读的类名。为此,我们需要一个类 ID
进行名称映射。下载此文件并记住您的保存位置(或者,如果您
正在按照本教程中的确切步骤进行操作,请将其保存在 tutorials/_static) 中。此文件包含 ImageNet 类 id 到
ImageNet 类名。我们将加载此 JSON 文件并获取
预测指数。y_hat
imagenet_class_index.json
import json
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
在使用字典之前,首先我们将
tensor 值转换为字符串值,因为字典中的键是字符串。
我们将测试上述方法:imagenet_class_index
imagenet_class_index
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))
您应该会收到如下响应:
['n02124075', 'Egyptian_cat']
数组中的第一项是 ImageNet 类 id,第二项是人类 可读名称。
- 将模型集成到我们的 API 服务器中
在最后一部分中,我们将把我们的模型添加到我们的 Flask API 服务器中。因为 我们的 API 服务器应该获取一个图像文件,我们将更新我们的方法以从请求中读取文件:
predict
from flask import request @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': # we will get the file from the request file = request.files['file'] # convert that to bytes img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name})
import io import json from torchvision import models import torchvision.transforms as transforms from PIL import Image from flask import Flask, jsonify, request app = Flask(__name__) imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json')) model = models.densenet121(weights='IMAGENET1K_V1') model.eval() def transform_image(image_bytes): my_transforms = transforms.Compose([transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) image = Image.open(io.BytesIO(image_bytes)) return my_transforms(image).unsqueeze(0) def get_prediction(image_bytes): tensor = transform_image(image_bytes=image_bytes) outputs = model.forward(tensor) _, y_hat = outputs.max(1) predicted_idx = str(y_hat.item()) return imagenet_class_index[predicted_idx] @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': file = request.files['file'] img_bytes = file.read() class_id, class_name = get_prediction(image_bytes=img_bytes) return jsonify({'class_id': class_id, 'class_name': class_name}) if __name__ == '__main__': app.run()
FLASK_ENV=development FLASK_APP=app.py flask run
library 向我们的应用程序发送 POST 请求:
import requests resp = requests.post("http://localhost:5000/predict", files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
现在,打印 resp.json() 将显示以下内容:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
我们编写的服务器非常简单,可能无法完成所有工作 您需要用于您的生产应用程序。所以,这里有一些事情 可以做得更好:
终端节点假定始终存在图像文件 在请求中。这可能不适用于所有请求。我们的用户可以 使用不同的参数发送图像或根本不发送任何图像。
/predict
用户也可以发送非图像类型的文件。由于我们不处理 错误,这将破坏我们的服务器。添加显式错误处理 path 将引发异常,这将使我们能够更好地处理 错误的输入
尽管模型可以识别大量图像类别, 它可能无法识别所有图像。增强实施 处理模型无法识别图像中任何内容的情况。
我们在开发模式下运行 Flask 服务器,这并不适合 在生产环境中部署。您可以查看此教程,了解如何在生产环境中部署 Flask 服务器。
在本教程中,我们只展示了如何构建一个可以返回 一次一张图像。我们可以修改我们的服务,以便能够返回 一次多张图片。此外,service-streamer 库会自动将服务请求排队,并将它们采样为小批量 可以输入到您的模型中。您可以查看此教程。
最后,我们鼓励您查看有关部署 PyTorch 模型的其他教程 linked-to 的 URL 中。
脚本总运行时间:(0 分 0.000 秒)