0. 安装 tflite-runtime
ref: https://tensorflow.google.cn/lite/guide/python
pip3 install https://dl.google.com/coral/python/tflite_runtime-2.1.0.post1-cp37-cp37m-linux_armv7l.whl
1. tensorflow官方示例
TensorFlow 官方提供了一个基于 picamera 的示例。
ref: https://github.com/tensorflow/examples/blob/master/lite/examples/object_detection/raspberry_pi/
# 1. Clone
git clone https://github.com/tensorflow/examples --depth 1
# 2. 进入文件夹
cd examples/lite/examples/object_detection/raspberry_pi
# 文件夹里总共5个文件
# README.md #
# annotation.py # 用于绘制方框、标签
# detect_picamera.py # 主程序
# download.sh # 下载 python 依赖包、已训练的模型
# requirements.txt #
# 3. 下载已训练好的模型
bash download.sh /tmp
# - 下载 python 依赖包: numpy picamera Pillow
# - 下载 coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip, 里面有两个文件:detect.tflite, labelmap.txt,这个label文件有乱码
# - 下载正确的label文件: https://dl.google.com/coral/canned_models/coco_labels.txt
# 4. 运行程序
python3 detect_picamera.py --model /tmp/detect.tflite --labels /tmp/coco_labels.txt
2. 使用 opencv 调用 usb camera
我这里没有 picamera,只有一个老式的 USB 摄像头,而 picamera 的 API 不支持 USB 摄像头。
下面修改一下代码,改用 OpenCV 来调用 USB 摄像头。
"""
Example using TF Lite to detect objects with a USB camera on a Raspberry Pi.
Hardware:
- Pi 3b+
- usb camera
Software
- python 3.7.3
- tflite runtime 2.1
- opencv
Dataset
- coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip
"""
import re
import time
import numpy as np
import cv2
from tflite_runtime.interpreter import Interpreter
# Capture resolution (pixels) requested from the camera.
args_camera_width, args_camera_height = 640, 480
# Model and label-map files, resolved relative to the working directory.
args_model, args_labels = 'detect.tflite', 'coco_labels.txt'
# Detections scoring below this value are discarded.
args_threshold = 0.4
def load_labels(path):
    """Load a label-map file into a ``{class_id: name}`` dict.

    Each non-blank line is either ``"<id> <name>"`` / ``"<id>: <name>"``
    (explicit index) or just ``"<name>"`` (the index is then the zero-based
    line number). Supports files with or without index numbers.

    Args:
        path: path to a UTF-8 encoded text file.

    Returns:
        dict mapping int class id -> label string.
    """
    labels = {}
    with open(path, 'r', encoding='utf-8') as f:
        for row_number, content in enumerate(f):
            content = content.strip()
            # Skip blank lines: storing '' under their row number could
            # overwrite a real label that uses the same explicit index.
            if not content:
                continue
            pair = re.split(r'[:\s]+', content, maxsplit=1)
            if len(pair) == 2 and pair[0].isdigit():
                labels[int(pair[0])] = pair[1].strip()
            else:
                # No leading index: the whole line is the name, so
                # multi-word labels such as "traffic light" stay intact.
                labels[row_number] = content
    return labels
def detect_objects(interpreter, image, threshold):
    """Run one inference pass and return detections scoring >= *threshold*.

    Args:
        interpreter: a TF Lite ``Interpreter`` with tensors already allocated.
        image: input tensor matching the model's input shape and dtype
            (for the quantized coco_ssd_mobilenet_v1 model presumably a
            uint8 array of shape (1, 300, 300, 3) -- confirm for other models).
        threshold: minimum score (inclusive) for a detection to be kept.

    Returns:
        list of dicts with keys 'box' (normalised [ymin, xmin, ymax, xmax]),
        'class_id' (int) and 'score'.
    """
    # Read tensor metadata from the interpreter we were given rather than
    # from module globals, so the function works with any interpreter.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Fill the input tensor with the *image argument* and run inference.
    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()

    # Output order assumed: 0 = boxes, 1 = classes, 2 = scores (as for the
    # coco_ssd_mobilenet_v1 model); verify for other detection models.
    # reshape/ravel drop the batch axis without collapsing a single
    # detection to a scalar (which np.squeeze would do).
    boxes = np.reshape(interpreter.get_tensor(output_details[0]['index']), (-1, 4))
    classes = np.ravel(interpreter.get_tensor(output_details[1]['index'])).astype(np.int32)
    scores = np.ravel(interpreter.get_tensor(output_details[2]['index']))

    # Keep only detections at or above the threshold.
    return [
        {'box': boxes[i], 'class_id': classes[i], 'score': score}
        for i, score in enumerate(scores)
        if score >= threshold
    ]
def annotate_objects(image, results, labels=None):
    """Draw a box and a "name score%" caption for each detection onto *image*.

    The drawing is done in place with OpenCV, so *image* is modified.

    Args:
        image: frame array of shape (H, W, channels), e.g. the BGR frame
            returned by ``cv2.VideoCapture.read``.
        results: detections as returned by ``detect_objects``; each 'box'
            is normalised [ymin, xmin, ymax, xmax].
        labels: optional ``{class_id: name}`` mapping; defaults to the
            module-level ``labels_dict``.
    """
    if labels is None:
        labels = labels_dict
    # Scale by the frame's actual size: the camera is not obliged to honour
    # the resolution that was requested from it.
    img_h, img_w = image.shape[:2]
    for rst in results:
        ymin, xmin, ymax, xmax = rst['box']
        class_id = rst['class_id']
        # Fall back to the numeric id instead of raising KeyError.
        name = labels.get(class_id, str(class_id))
        score = rst['score']
        xmin = int(xmin * img_w)
        xmax = int(xmax * img_w)
        ymin = int(ymin * img_h)
        ymax = int(ymax * img_h)
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0))
        txt = f'{name} {score:.2%}'
        cv2.putText(image, txt, (xmin, ymin), 0, 1, (255, 255, 255), 2)
if __name__ == '__main__':
    # 1. Load the label map ({class_id: name}).
    labels_dict = load_labels(args_labels)
    print('labels_dict: \n ', labels_dict)

    # 2. Load the TF Lite model and allocate its tensors.
    interpreter = Interpreter(args_model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Take the expected input size from the model (assumed NHWC, e.g.
    # [1, 300, 300, 3] for coco_ssd_mobilenet_v1) instead of hard-coding it.
    _, input_height, input_width, _ = input_details[0]['shape']

    # 3. Open the first camera and request the capture resolution.
    # Named property ids avoid mixing up width (id 3) and height (id 4).
    camera = cv2.VideoCapture(0)
    if not camera.isOpened():
        raise SystemExit('Cannot open camera 0')
    camera.set(cv2.CAP_PROP_FRAME_WIDTH, args_camera_width)
    camera.set(cv2.CAP_PROP_FRAME_HEIGHT, args_camera_height)
    frame_rate_calc = 1.0
    freq = cv2.getTickFrequency()

    # 4. Detection loop. Press 'q' in the preview window to quit; the
    # finally clause releases the camera even on Ctrl+C or an exception.
    try:
        while True:
            # 4.1 Start the FPS timer.
            t1 = cv2.getTickCount()

            # 4.2 Grab a frame; stop cleanly if the camera returns nothing.
            ret, frame = camera.read()
            if not ret or frame is None:
                print('Failed to read a frame from the camera, exiting.')
                break
            # OpenCV delivers BGR; the model is assumed to expect RGB (the
            # picamera example feeds RGB frames). The uint8 0-255 pixels are
            # fed as-is, which suits the quantized model -- a float model
            # would need normalisation here.
            input_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            input_image = cv2.resize(input_image, (int(input_width), int(input_height)))
            input_image = np.expand_dims(input_image, axis=0)

            # 4.3 Run inference and keep detections above the threshold.
            results = detect_objects(interpreter, input_image, args_threshold)
            print(f'--- {time.strftime("%Y-%m-%d %H:%M:%S")} ---')
            for rst in results:
                box = rst['box']
                score = rst['score']
                name = labels_dict.get(rst['class_id'], rst['class_id'])
                print(f'* {name} : {score:.2%} @ {box}')

            # 4.4 Draw the detections on the original (BGR) frame.
            annotate_objects(frame, results)

            # 4.5 Draw the FPS measured on the previous iteration.
            txt = f'FPS: {frame_rate_calc:.2f}'
            cv2.putText(frame, txt, (20, 30), 0, 1, (0, 255, 255), 2)

            # 4.6 Show the frame.
            cv2.imshow('Object detect', frame)

            # 4.7 Update the FPS estimate.
            t2 = cv2.getTickCount()
            frame_rate_calc = freq / (t2 - t1)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        camera.release()
        cv2.destroyAllWindows()
网友评论