1 下载tensorflow源码
git clone https://github.com/tensorflow/tensorflow.git
checkout到本地安装的tensorflow一样版本的分支
2 下载retrain.py脚步
cd /home/xxx/github/tensorflow/tensorflow/examples/image_retraining
curl https://github.com/tensorflow/hub/blob/master/examples/image_retraining/retrain.py
3 重新训练inception v3
python retrain.py --image_dir=/home/test/data/flower_photos/
5 将pb模型转化为tflite模型
注意:在tensorflow1.9以上版本,替换input_file为 graph_def_file,同时去掉--input_format=TENSORFLOW_GRAPHDEF.否则会报如下错:
toco: error: one of the arguments --graph_def_file --saved_model_dir --keras_model_file is required.
input_arrays有时候会报错,可先将pb文件转为tensorboard可看的图,用tensorboard查看输入node名.
在python3环境下:
import tensorflow as tf
model ='model.pb'#请将这里的pb文件路径改为自己的graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph) #log为指定保存graph的位置
在命令行运行:
tensorboard --logdir=log
如果运行报端口占用,可增加参数--port=8008
5.1 float数据格式转换
toco
--graph_def_file=output_graph.pb
--output_file=/tmp/inception_v3.tflite
--output_format=TFLITE
--inference_type=FLOAT
--input_type=FLOAT
--input_arrays=Placeholder
--output_arrays=final_result
--input_shapes=1,299,299,3
5.2 QUANTIZED_UINT8格式
toco
--graph_def_file=/tmp/output_graph.pb
--output_file=/tmp/inception_v3.tflite
--output_format=TFLITE
--input_arrays=Placeholder
--output_arrays=final_result
--input_shapes=1,299,299,3
--inference_type=QUANTIZED_UINT8
--inference_input_type=QUANTIZED_UINT8
--mean_value=128
--std_dev_values=128
--default_ranges_min=0
--default_ranges_max=6
网友评论