首页
视频
资源
登录
原
Pytorch Flask服务部署图片识别(学习笔记)
613
人阅读
2024/1/5 14:16
总访问:
2633639
评论:
0
收藏:
0
手机
分类:
pytorch
![](https://img.tnblog.net/arcimg/hb/21f086c80c5d4afda1bc1029dadd8f3a.png) >#Pytorch Flask服务部署图片识别(学习笔记) [TOC] ## Flask 简介 tn2>Flask是一个用Python编写的轻量级Web应用框架。 它简单易用,但同时也足够灵活和强大,能够支持复杂的Web应用。 由于其轻量级的特性,Flask非常适合用作在Web上部署机器学习模型的工具。 tn>简单来讲:启动一个服务,根据传上来的东西进行预测并返回结果。 ## 安装Flask ```python python -m pip install flask ``` ## 实践目录 ![](https://img.tnblog.net/arcimg/hb/46fd7cddc24b4568a29fd34886586152.png) | 文件或文件夹 | 描述 | | ------------ | ------------ | | `flower_data` | 训练的图像数据 | | `best.pth` | 训练好的模型 | | `flask_server.py` | 服务器端代码 | | `flask_predict.py` | 客户端请求代码 | ## 服务器端 tn2>服务器对需要预处理的图片流程如下图所示: ![](https://img.tnblog.net/arcimg/hb/7a2167bd6e97406e9b6ef14875ad5005.png) tn2>`flask_server.py`代码如下所示: ```python import io import json # flask 服务 import flask import torch import torch import torch.nn.functional as F from PIL import Image from torch import nn #from torchvision import transforms as T from torchvision import transforms, models, datasets from torch.autograd import Variable # 初始化Flask app app = flask.Flask(__name__) model = None use_gpu = False # 加载模型进来 def load_model(): """Load the pre-trained model, you can use your model just as easily. """ # 定义一个全局变量 global model #这里我们直接加载官方工具包里提供的训练好的模型(代码会自动下载)括号内参数为是否下载模型对应的配置信息 model = models.resnet18() num_ftrs = model.fc.in_features model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)) # 102类的分类任务 #print(model) 加载模型 checkpoint = torch.load('best.pth') # 加载权重参数 model.load_state_dict(checkpoint['state_dict']) #将模型指定为测试格式 model.eval() #是否使用gpu if use_gpu: model.cuda() # 数据预处理 def prepare_image(image, target_size): """Do image preprocessing before prediction on any data. :param image: original image :param target_size: target image size :return: preprocessed image """ #针对不同模型,image的格式不同,但需要统一至RGB格式 if image.mode != 'RGB': image = image.convert("RGB") # Resize the input image and preprocess it.(按照所使用的模型将输入图片的尺寸修改,并转为tensor) # 图片与训练尺寸大小一致 image = transforms.Resize(target_size)(image) # 转tensor格式 image = transforms.ToTensor()(image) # Convert to Torch.Tensor and normalize. mean与std (RGB三通道)这里的参数和数据集中是对应的,训练过程中一致 # 设置均值和标准差 image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) # Add batch_size axis.增加一个维度,用于按batch测试 本次这里一次测试一张 # 举例:1*3*64*64 image = image[None] if use_gpu: image = image.cuda() return Variable(image, volatile=True) #不需要求导 # 开启服务 这里的predict是API路径、使用POST请求 @app.route("/predict", methods=["POST"]) def predict(): # Initialize the data dictionary that will be returned from the view. #做一个标志,刚开始无图像传入时为false,传入图像时为true data = {"success": False} # 如果收到请求 if flask.request.method == 'POST': #判断是否为图像 if flask.request.files.get("image"): # Read the image in PIL format # 将收到的图像进行读取 image = flask.request.files["image"].read() image = Image.open(io.BytesIO(image)) #二进制数据 # 利用上面的预处理函数将读入的图像进行预处理 image = prepare_image(image, target_size=(64, 64)) # 放入模型中进行预测,softmax得到各个类别的概率 preds = F.softmax(model(image), dim=1) # k找出类别前3高的 results = torch.topk(preds.cpu().data, k=3, dim=1) # 结果转成cpu最后转成numpy results = (results[0].cpu().numpy(), results[1].cpu().numpy()) #将data字典增加一个key,value,其中value为list格式 data['predictions'] = list() # 遍历每一个预测结果 for prob, label in zip(results[0][0], results[1][0]): #label_name = idx2label[str(label)] # label真实值,和probability概率值 r = {"label": str(label), "probability": float(prob)} # 将预测结果添加至data字典 data['predictions'].append(r) # Indicate that the request was a success. data["success"] = True # 将最终结果以json格式文件传出 return flask.jsonify(data) """ test_json = { "status_code": 200, "success": { "message": "image uploaded", "code": 200 }, "video":{ "video_name":opt['source'].split('/')[-1], "video_path":opt['source'], "description":"1", "length": str(hour)+','+str(minute)+','+str(round(second,4)), "model_object_completed":model_flag } "status_txt": "OK" } response = requests.post( 'http://xxx.xxx.xxx.xxx:8090/api/ObjectToKafka/',, data={'json': str(test_json)}) """ if __name__ == '__main__': print("Loading PyTorch model and Flask starting server ...") print("Please wait until server has fully started") #先加载模型 load_model() #再开启服务 app.run(port='5012') ``` tn2>这里我开放的端口是`5012`,通过请求`/predict`链接,通过执行如下命令将程序跑起来: ```python python flask_server.py ``` ![](https://img.tnblog.net/arcimg/hb/9ff98d7c5d80431485ced42218f59147.png) tn>只要把Flask关了模型就没了,如果Flask一直开着的模型就一直都在跑。 ## 客户端 tn2>客户端主要是上传一张`image_06998.jpg`的图片到服务器中去预测,代码如下: ```python import requests import argparse # url和端口携程自己的 flask_url = 'http://127.0.0.1:5012/predict' def predict_result(image_path): #传入本地图片 image = open(image_path, 'rb').read() payload = {'image': image} #request发给server. r = requests.post(flask_url, files=payload).json() # 成功的话在返回. if r['success']: # 输出结果. for (i, result) in enumerate(r['predictions']): print('{}. {}: {:.4f}'.format(i + 1, result['label'], result['probability'])) # 失败了就打印. else: print('Request failed') if __name__ == '__main__': parser = argparse.ArgumentParser(description='Classification demo') # 添加参数 parser.add_argument('--file', default='./flower_data/train_filelist/image_06998.jpg', type=str, help='test image file') args = parser.parse_args() # 开始请求 predict_result(args.file) ``` ```bash python flask_predict.py ``` tn2>预测结果如下所示: ![](https://img.tnblog.net/arcimg/hb/5d69d414c6d54d05bf372795de3bee4b.png) tn2>我们可以看到预测得最相似的label是`34`,准确率`97%`,我们去图片数据中找找这张图片的训练集验证一下。 ![](https://img.tnblog.net/arcimg/hb/e9edae8bec7c47d88497857ac7be15d3.png) tn2>训练的结果与预期的结果一致。 tn><a href="https://download.tnblog.net/resource/index/a6f1480a5ea54461854604818dea347a">代码链接</a>
欢迎加群讨论技术,1群:677373950(满了,可以加,但通过不了),2群:656732739
👈{{preArticle.title}}
👉{{nextArticle.title}}
评价
{{titleitem}}
{{titleitem}}
{{item.content}}
{{titleitem}}
{{titleitem}}
{{item.content}}
尘叶心繁
这一世以无限游戏为使命!
博主信息
排名
6
文章
6
粉丝
16
评论
8
文章类别
.net后台框架
171篇
linux
17篇
linux中cve
1篇
windows中cve
0篇
资源分享
10篇
Win32
3篇
前端
28篇
传说中的c
4篇
Xamarin
9篇
docker
15篇
容器编排
101篇
grpc
4篇
Go
15篇
yaml模板
1篇
理论
2篇
更多
Sqlserver
4篇
云产品
39篇
git
3篇
Unity
1篇
考证
2篇
RabbitMq
23篇
Harbor
1篇
Ansible
8篇
Jenkins
17篇
Vue
1篇
Ids4
18篇
istio
1篇
架构
2篇
网络
7篇
windbg
4篇
AI
18篇
threejs
2篇
人物
1篇
嵌入式
3篇
python
13篇
HuggingFace
8篇
pytorch
9篇
opencv
6篇
Halcon
3篇
最新文章
最新评价
{{item.articleTitle}}
{{item.blogName}}
:
{{item.content}}
关于我们
ICP备案 :
渝ICP备18016597号-1
网站信息:
2018-2024
TNBLOG.NET
技术交流:
群号656732739
联系我们:
contact@tnblog.net
欢迎加群
欢迎加群交流技术