Deploying PyTorch in Python via a REST API with Flask
Created On: Jul 03, 2019 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024
Author: Avinash Sajjanshetty
In this tutorial, we will deploy a PyTorch model using Flask and expose a REST API for model inference. In particular, we will deploy a pretrained DenseNet 121 model that classifies images.
Tip

All the code used here is released under the MIT license and is available on GitHub.
This represents the first in a series of tutorials on deploying PyTorch models in production. Using Flask in this way is by far the easiest way to start serving your PyTorch models, but it will not work for a use case with high performance requirements. For that:
If you’re already familiar with TorchScript, you can jump straight into our Loading a TorchScript Model in C++ tutorial.
If you first need a refresher on TorchScript, check out our Introduction to TorchScript tutorial.
API Definition
We will first define our API endpoint, and the request and response types. Our API endpoint will be at /predict, which accepts HTTP POST requests with a file parameter containing the image. The response will be a JSON object containing the prediction:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
Simple Web Server

Following is a simple web server, taken from Flask's documentation:
from flask import Flask
app = Flask(__name__)

@app.route('/')
def hello():
    return 'Hello World!'
Save the above snippet in a file called app.py. We will also change the response type, so that it returns a JSON response containing the ImageNet class id and name. The updated app.py file will now be:
from flask import Flask, jsonify
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
Inference
In the next sections, we will focus on writing the inference code. This will involve two parts: one where we prepare the image so that it can be fed to DenseNet, and next, we will write the code to get the actual prediction from the model.
Preparing the Image
The DenseNet model requires the image to be a 3-channel RGB image of size 224 x 224. We will also normalize the image tensor with the required mean and standard deviation values. You can read more about it here.
We will use transforms from the torchvision library and build a transform pipeline, which transforms our images as required. You can read more about transforms here.
import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)
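To make the Normalize step concrete, here is a small plain-Python sketch of the arithmetic it performs on each pixel: subtract the per-channel mean and divide by the per-channel standard deviation. The normalize_pixel helper is ours, purely for illustration; the real transform operates on whole tensors.

```python
# ImageNet per-channel mean and std, as used by the pretrained DenseNet:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(c - m) / s for c, m, s in zip(rgb, mean, std)]

# A pure white pixel maps to positive values, since 1.0 is above every mean:
print(normalize_pixel([1.0, 1.0, 1.0]))
```

After this step the channel values are roughly centered around zero, which is what the pretrained weights expect.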
The above method takes image data in bytes, applies the series of transforms and returns a tensor. To test the above method, read an image file in bytes mode (first replacing ../_static/img/sample_file.jpeg with the actual path to the file on your computer) and see if you get a tensor back:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor)
Prediction
Now we will use a pretrained DenseNet 121 model to predict the image class. We will use one from the torchvision library, load the model and get an inference. While we'll be using a pretrained model in this example, you can use this same approach for your own models. See more about loading your models in this tutorial.
from torchvision import models

# Make sure to set `weights` as `'IMAGENET1K_V1'` to use the pretrained weights:
model = models.densenet121(weights='IMAGENET1K_V1')
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat
The tensor y_hat will contain the index of the predicted class id. However, we need a human readable class name. For that we need a class id to name mapping. Download this file and save it as imagenet_class_index.json, remembering where you saved it (or, if you are following the exact steps in this tutorial, save it in tutorials/_static). This file contains the mapping of ImageNet class ids to ImageNet class names. We will load this JSON file and get the class name of the predicted index.
import json

imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]
Before using the imagenet_class_index dictionary, we first convert the tensor value to a string value, since the keys in the imagenet_class_index dictionary are strings.
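To see why the str() conversion matters, here is a hypothetical two-entry excerpt mimicking imagenet_class_index.json (the indices shown are for illustration): JSON object keys are always strings, so looking up with a plain int would raise a KeyError.

```python
# A tiny excerpt mimicking imagenet_class_index.json; note the string keys:
imagenet_class_index = {
    "281": ["n02123045", "tabby"],
    "285": ["n02124075", "Egyptian_cat"],
}

y_hat_item = 285                         # what y_hat.item() might return
assert 285 not in imagenet_class_index   # an int lookup would raise KeyError

predicted_idx = str(y_hat_item)          # convert to a string first
class_id, class_name = imagenet_class_index[predicted_idx]
print(class_id, class_name)              # n02124075 Egyptian_cat
```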
We will test our above method:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))
You should get a response like this:
['n02124075', 'Egyptian_cat']
The first item in the array is the ImageNet class id and the second item is the human readable name.
Integrating the Model in Our API Server
In this final part we will add our model to our Flask API server. Since our API server is supposed to take an image file, we will update our predict method to read files from the requests:

from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request
        file = request.files['file']
        # convert that to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})
The app.py file is now complete. Following is the full version; replace the paths with the paths where you saved your files and it should run:

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(weights='IMAGENET1K_V1')
model.eval()

def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})

if __name__ == '__main__':
    app.run()
Let's test our web server! Run:

FLASK_ENV=development FLASK_APP=app.py flask run
We can use the requests library to send a POST request to our app:
import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg', 'rb')})
Printing resp.json() will now show the following:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
The server we wrote is quite trivial and may not do everything you need for your production application. So, here are some things you can do to make it better:
The endpoint /predict assumes that there will always be an image file in the request. This may not hold true for all requests. Our user may send an image with a different parameter or send no images at all. The user may send non-image files too. Since we are not handling errors, this will break our server. Adding an explicit error handling path that throws an exception would allow us to better handle bad inputs.
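One way to start guarding against bad inputs is to validate the uploaded filename before running inference. The sketch below is our own illustration, not part of the tutorial code: the allowed_file helper and its extension whitelist are assumptions. Inside predict you would check that request.files contains the 'file' key and return an error response for anything that fails this test.

```python
import os

# Hypothetical whitelist of file extensions we are willing to accept:
ALLOWED_EXTENSIONS = {'.jpg', '.jpeg', '.png'}

def allowed_file(filename):
    """Return True if the filename carries an image extension we accept."""
    _, ext = os.path.splitext(filename.lower())
    return ext in ALLOWED_EXTENSIONS

print(allowed_file('cat.jpg'))    # True
print(allowed_file('notes.txt'))  # False
```

Checking the extension does not prove the bytes are a valid image, so you would still want to catch exceptions from PIL's Image.open as a second line of defense.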
Even though the model can recognize a large number of classes of images, it may not be able to recognize all images. Enhance the implementation to handle cases when the model does not recognize anything in the image.
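One common approach, sketched below as our own illustration rather than part of the tutorial, is to apply a softmax to the model's raw outputs and refuse to answer when the top probability falls below a threshold; get_prediction could then return None and the endpoint could report "unknown".

```python
import math

def softmax(logits):
    """Convert a list of raw scores to probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_with_threshold(logits, threshold=0.5):
    """Return the index of the best class, or None if confidence is too low."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best if probs[best] >= threshold else None

print(predict_with_threshold([5.0, 0.1, 0.1]))  # 0 (confident)
print(predict_with_threshold([0.1, 0.1, 0.1]))  # None (near-uniform, unsure)
```

The threshold value is a judgment call; too high and the service rejects valid images, too low and it confidently mislabels out-of-distribution inputs.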
We run the Flask server in the development mode, which is not suitable for deploying in production. You can check out this tutorial for deploying a Flask server in production.
You can also add a UI by creating a page with a form which takes the image and displays the prediction. Check out the demo of a similar project and its source code.
In this tutorial, we only showed how to build a service that could return predictions for a single image at a time. We could modify our service to be able to return predictions for multiple images at once. In addition, the service-streamer library automatically queues requests to your service and samples them into mini-batches that can be fed into your model. You can check out this tutorial.
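A minimal building block for such a batched service is a helper that chunks incoming items into mini-batches. The batch helper below is our own sketch: in the Flask handler you would read multiple uploads with request.files.getlist('file'), transform each image, concatenate the resulting tensors with torch.cat, and run the model once per chunk.

```python
def batch(items, batch_size):
    """Split a list into successive mini-batches of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

# Five images served in mini-batches of two:
print(batch(['a', 'b', 'c', 'd', 'e'], 2))  # [['a', 'b'], ['c', 'd'], ['e']]
```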
Finally, we encourage you to check out our other tutorials on deploying PyTorch models linked-to at the top of the page.
Total running time of the script: (0 minutes 0.000 seconds)