CIFAR-10 分类问题是机器学习领域一种常见的基准问题,其任务是将 RGB 32x32 像素的图像分为以下 10 类:
airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.
源码
从tenforflow model拷贝一份,里面东西太多,只拷贝了models/tutorials/image/cifar10/,这是我的项目,也可以直接fork就是了
源码结构
文件 | 用途 |
---|---|
cifar10_input.py |
读取原生 CIFAR-10 二进制文件格式。 |
cifar10.py |
构建 CIFAR-10 模型。 |
cifar10_train.py |
在 CPU 或 GPU 上训练 CIFAR-10 模型。 |
cifar10_multi_gpu_train.py |
在多个 GPU 上训练 CIFAR-10 模型。 |
cifar10_eval.py |
评估 CIFAR-10 模型的预测性能。 |
下载数据,可选
我单独下载比较快
wget --no-check-certificate http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
mkdir /tmp/cifar10_data # 默认使用该目录
cp cifar-10-binary.tar.gz /tmp/cifar10_data
训练
我的笔记本运行了30分钟,训练了40000 steps,总的要训练100000 seps
没有GPU的话,估计要更久,可以参考ubuntu 16.04 笔记本双显卡安装tensorflow-gpu
$ python3 cifar10_train.py # 输出的训练文件夹在/tmp/cifar10_train
查看训练进展
$ python3 -m tensorboard.main --logdir=cifar10_train/
浏览器打开就可以看了
image.png
评估
$ python3 cifar10_eval.py
GeForce GTX 850M, pci bus id: 0000:01:00.0, compute capability: 5.0)
2018-11-21 21:39:44.403903: precision @ 1 = 0.850
准确率是85%
网友评论