参考 building-a-simple-keras-deep-learning-rest-api
文章中比较重要的代码是2个地方
1. 模型加载
这里貌似是通过网络下载的,所以可能加载比较慢
def load_model():
# load the pre-trained Keras model (here we are using a model
# pre-trained on ImageNet and provided by Keras, but you can
# substitute in your own networks just as easily)
global model
model = ResNet50(weights="imagenet")
2.模型调用
通过post请求上传图片。
调用模型就一行代码
model.predict(image)
if flask.request.method == "POST":
if flask.request.files.get("image"):
# read the image in PIL format
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image))
# preprocess the image and prepare it for classification
image = prepare_image(image, target=(224, 224))
# classify the input image and then initialize the list
# of predictions to return to the client
preds = model.predict(image)
results = imagenet_utils.decode_predictions(preds)
data["predictions"] = []
# loop over the results and add them to the list of
# returned predictions
for (imagenetID, label, prob) in results[0]:
r = {"label": label, "probability": float(prob)}
data["predictions"].append(r)
# indicate that the request was a success
data["success"] = True
但我在实际操作时,就遇到了问题。
首先,不知道keras版本是多少。还好到文章中给出的github链接找到了依赖:
Keras 2.2.4
TF 1.13.1
其次,跑起来报错:
Tensor is not an element of this graph
还是通过文章中给出的github链接当中,找到了 解决办法,看来是并发导致的问题。
最后的代码改动点在这里,注意加粗部分:
def my_load_model():
# load the pre-trained Keras model (here we are using a model
# pre-trained on ImageNet and provided by Keras, but you can
# substitute in your own networks just as easily)
global model
model = ResNet50(weights="imagenet")
global graph
graph = tf.get_default_graph()
def predict():
# initialize the data dictionary that will be returned from the
# view
data = {"success": False}
global graph
with graph.as_default():
# ensure an image was properly uploaded to our endpoint
if flask.request.method == "POST":
网友评论