
Machine Learning: CNN Step by Step (TensorFlow edition)

Author: 小楼又春风 | Published 2018-07-04 16:46

Previous article

TensorFlow version

Personally, I find TensorFlow more flexible than Caffe, but also harder: each step requires writing actual code, so you need some familiarity with the TensorFlow API. That said, you can reuse open-source training code from GitHub, such as slim; on top of it, only a few steps are needed to train a model for image classification.

This article assumes TensorFlow is already installed; if not, refer to the official documentation.

The steps in this article are based on slim, so first git clone the code, then modify and train on top of it.

Step 1: Prepare the dataset

As with the Caffe version, you need training images and a label-index txt file. Each line of the label file describes one image: the image path and its class label, separated by a space. A train.txt looks like this:

/Users/fanchun/Desktop/视频文件分析/分屏花屏/huapin/flurglitch/page1/Num1.jpg_0.jpg 0
/Users/fanchun/Desktop/视频文件分析/分屏花屏/huapin/flurglitch/page1/Num10.jpg_0.jpg 0
/Users/fanchun/Desktop/视频文件分析/分屏花屏/huapin/flurglitch/page1/Num100.jpg_1.jpg 1
/Users/fanchun/Desktop/视频文件分析/分屏花屏/huapin/flurglitch/page1/Num11.jpg_1.jpg 1
/Users/fanchun/Desktop/视频文件分析/分屏花屏/huapin/flurglitch/page1/Num14.jpg_1.jpg 1
/Users/fanchun/Desktop/视频文件分析/分屏花屏/huapin/flurglitch/page1/Num15.jpg_0.jpg 0

To convert the images to TFRecord format, you can start from the example on GitHub; with small changes it becomes fairly generic conversion code, shown below. The parameters mean the following:

  • list_path: the path of the label-index txt file
  • data_dir: the directory where the images are stored; the paths in the label-index file are interpreted relative to it
  • output_dir: the directory where the generated TFRecord files go
  • _NUM_SHARDS: how many TFRecord files to generate; a single file would be too large, so the data can be split into several shards
  • shuffle: whether to shuffle the order
  • resize_width and resize_height: the width and height to resize to; leave the defaults if no resize is needed
  • tmp_path: a directory for temporary files created during conversion; any writable path will do

In short, this method reads image paths and label values from the label-index txt file, calls the image_to_tfexample method provided in dataset_utils.py to build the Example proto, then calls tfrecord_writer.write(example.SerializeToString()) to write it into the TFRecord file, with optional shuffling and resizing along the way.

import math
import os
import random
import sys

import tensorflow as tf
from PIL import Image

from datasets import dataset_utils

_RANDOM_SEED = 0  # fixed seed so shuffling is reproducible

def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5, shuffle=True,
                    resize_width=0, resize_height=0, tmp_path=''):
    # Read the label file: one "<image path> <label>" pair per line.
    with open(list_path) as fd:
        lines = [line.split() for line in fd]
    if shuffle:
        random.seed(_RANDOM_SEED)
        random.shuffle(lines)

    need_resize = 0
    if resize_width * resize_height:
        print("resize image ({}x{})".format(resize_width, resize_height))
        need_resize = 1

    num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
    with tf.Graph().as_default():
        # Decode op used only to discover the original height/width.
        decode_jpeg_data = tf.placeholder(dtype=tf.string)
        decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
        with tf.Session('') as sess:
            for shard_id in range(_NUM_SHARDS):
                output_path = os.path.join(output_dir,
                    'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
                tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
                start_ndx = shard_id * num_per_shard
                end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
                for i in range(start_ndx, end_ndx):
                    # '\r' keeps the progress on a single line.
                    sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
                        i + 1, len(lines), shard_id))
                    sys.stdout.flush()
                    filepath = lines[i][0]
                    if need_resize:
                        # Resize via PIL, save into tmp_path, then read the resized bytes.
                        img = Image.open(os.path.join(data_dir, filepath))
                        filepath = os.path.join(tmp_path, os.path.basename(filepath))
                        img.resize([resize_width, resize_height]).save(filepath)
                        height, width = resize_height, resize_width
                        image_data = tf.gfile.FastGFile(filepath, 'rb').read()
                    else:
                        image_data = tf.gfile.FastGFile(os.path.join(data_dir, filepath), 'rb').read()
                        image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
                        height, width = image.shape[0], image.shape[1]

                    example = dataset_utils.image_to_tfexample(
                        image_data, b'jpg', height, width, int(lines[i][1]))
                    tfrecord_writer.write(example.SerializeToString())
                tfrecord_writer.close()
    sys.stdout.write('\n')
    sys.stdout.flush()

Running this method produces the TFRecord files under output_dir; an example call is shown below.
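
For example, a call like the following converts train.txt into five shuffled shards resized to 224x224 (the paths here are placeholders for illustration, not from the article):

convert_dataset(list_path='train.txt',
                data_dir='/path/to/images',           # paths in train.txt are relative to this
                output_dir='/path/to/tfrecord/train', # shards are written here
                _NUM_SHARDS=5,
                shuffle=True,
                resize_width=224, resize_height=224,  # leave both at 0 to keep original sizes
                tmp_path='/tmp')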

Step 2: Read the data

Reading the data can likewise follow the GitHub code with slight modifications.
Training is driven by train_image_classifier.py, which goes through dataset_factory.py and calls the matching get_split to obtain a dataset object. We can write a generic module of our own for dataset_factory.py to call:

import os

import tensorflow as tf

slim = tf.contrib.slim

_FILE_PATTERN = '*.tfrecord'

# Number of samples in each split; adjust these to your own dataset.
SPLITS_TO_SIZES = {'train': 39588, 'test': 9980}

_NUM_CLASSES = 2

def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
    """Gets a dataset tuple with instructions for reading picture data.
    
    Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources
      (a glob such as '*.tfrecord').
    reader: The TensorFlow reader type.
    
    Returns:
    A `Dataset` namedtuple.
    
    Raises:
    ValueError: if `split_name` is not a valid train/validation split.
    """
    if split_name not in SPLITS_TO_SIZES:
        raise ValueError('split name %s was not recognized.' % split_name)

    if not file_pattern:
        file_pattern = _FILE_PATTERN

    # Allowing None in the signature so that dataset_factory can use the default.
    if reader is None:
        reader = tf.TFRecordReader

    num = SPLITS_TO_SIZES[split_name]
    # TFRecord shards for each split live in <dataset_dir>/<split_name>/.
    path = os.path.join(dataset_dir, split_name)
    return get_dataset(dataset_dir=path, num_samples=num,
                       num_classes=_NUM_CLASSES, file_pattern=file_pattern,
                       reader=reader)
    
def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern=_FILE_PATTERN, reader=None):
    file_pattern = os.path.join(dataset_dir, file_pattern)
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    items_to_descriptions = {
        'image': 'A color image of varying size.',
        'label': 'A single integer between 0 and ' + str(num_classes - 1),
    }
    labels_to_names = None
    if labels_to_names_path is not None:
        fd = open(labels_to_names_path)
        labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
        fd.close()
    if reader is None:
        reader = tf.TFRecordReader
    return slim.dataset.Dataset(
            data_sources=file_pattern,
            reader=reader,
            decoder=decoder,
            num_samples=num_samples,
            items_to_descriptions=items_to_descriptions,
            num_classes=num_classes,
            labels_to_names=labels_to_names)

Then register our generic module under a name in the datasets_map dictionary in dataset_factory.py (the required import is shown after the map); that name can later be passed on the command line to select this reader when launching training.

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'general': general,
}
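
Besides the map entry, dataset_factory.py must also import the new module. Assuming the reader above is saved as datasets/general.py (the layout slim uses for its bundled datasets), the import block becomes:

from datasets import cifar10
from datasets import flowers
from datasets import general   # our generic reader from Step 2
from datasets import imagenet
from datasets import mnist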

Step 3: Training

This step is mainly about writing a shell script to run training; you can also refer to the scripts in the GitHub source.
Here is the script I used for training, followed by the meaning of each parameter.

# Where the checkpoint and logs will be saved to.
TRAIN_DIR=/Users/fanchun/Documents/机器学习/models/research/slim/tmp

# Where the dataset is saved to.
DATASET_DIR=/Users/fanchun/Documents/MachinelearningDATA/tensordata/flur/

# Run training.
python ../train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=general \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=mobilenet_v1 \
  --preprocessing_name=mobilenet_v1 \
  --checkpoint_path=/Users/fanchun/Documents/机器学习/models/research/slim/pretrained/mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
  --checkpoint_exclude_scopes=MobilenetV1/Logits/Conv2d_1c_1x1 \
  --batch_size=64 \
  --save_interval_secs=3600 \
  --save_summaries_secs=3600 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --rmsprop_decay=0.9 \
  --opt_epsilon=1.0 \
  --learning_rate=0.001 \
  --learning_rate_decay_factor=0.1 \
  --momentum=0.9 \
  --num_epochs_per_decay=3.0 \
  --weight_decay=0.0 \
  --num_clones=2 \
  --clone_on_cpu=True

These parameters are quite similar in spirit to the ones configured for Caffe training; the important ones are:

  • train_dir: the directory where checkpoints and training logs are saved
  • dataset_name: the name we registered above for the generic dataset reader
  • dataset_split_name: distinguishes the training set from the validation set. Build TFRecord files for each split and put them in two folders named after dataset_split_name under dataset_dir; the reader joins these names into the path, e.g.:
.
├── test
│   ├── data_00000-of-00005.tfrecord
│   ├── data_00001-of-00005.tfrecord
│   ├── data_00002-of-00005.tfrecord
│   ├── data_00003-of-00005.tfrecord
│   └── data_00004-of-00005.tfrecord
├── test_score.txt
├── train
│   ├── data_00000-of-00005.tfrecord
│   ├── data_00001-of-00005.tfrecord
│   ├── data_00002-of-00005.tfrecord
│   ├── data_00003-of-00005.tfrecord
│   └── data_00004-of-00005.tfrecord
└── train_score.txt
  • dataset_dir: the root directory holding the training and validation sets; dataset_dir and dataset_split_name are joined to locate the TFRecord files
  • model_name: the model architecture to train. slim already implements many well-known models, as listed below
networks_map = {'alexnet_v2': alexnet.alexnet_v2,
                'cifarnet': cifarnet.cifarnet,
                'overfeat': overfeat.overfeat,
                'vgg_a': vgg.vgg_a,
                'vgg_16': vgg.vgg_16,
                'vgg_19': vgg.vgg_19,
                'inception_v1': inception.inception_v1,
                'inception_v2': inception.inception_v2,
                'inception_v3': inception.inception_v3,
                'inception_v4': inception.inception_v4,
                'inception_resnet_v2': inception.inception_resnet_v2,
                'lenet': lenet.lenet,
                'resnet_v1_50': resnet_v1.resnet_v1_50,
                'resnet_v1_101': resnet_v1.resnet_v1_101,
                'resnet_v1_152': resnet_v1.resnet_v1_152,
                'resnet_v1_200': resnet_v1.resnet_v1_200,
                'resnet_v2_50': resnet_v2.resnet_v2_50,
                'resnet_v2_101': resnet_v2.resnet_v2_101,
                'resnet_v2_152': resnet_v2.resnet_v2_152,
                'resnet_v2_200': resnet_v2.resnet_v2_200,
                'mobilenet_v1': mobilenet_v1.mobilenet_v1,
                'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075,
                'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050,
                'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025,
                'nasnet_cifar': nasnet.build_nasnet_cifar,
                'nasnet_mobile': nasnet.build_nasnet_mobile,
                'nasnet_large': nasnet.build_nasnet_large,
               }

Of course, you can also implement a network of your own in the same format; a bare-bones sketch follows.
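
For reference, "the same format" roughly means a function that takes an image batch and num_classes and returns (logits, end_points), plus a default_image_size attribute used by the training script. The network below is a toy example of mine, not one of slim's models:

import tensorflow as tf

slim = tf.contrib.slim

def my_net(images, num_classes, is_training=True, scope='MyNet'):
    """A toy slim-style network: conv -> pool -> dropout -> logits."""
    with tf.variable_scope(scope, 'MyNet', [images]):
        net = slim.conv2d(images, 32, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.flatten(net)
        net = slim.dropout(net, 0.5, is_training=is_training, scope='dropout1')
        logits = slim.fully_connected(net, num_classes,
                                      activation_fn=None, scope='logits')
    end_points = {'Logits': logits,
                  'Predictions': tf.nn.softmax(logits, name='Predictions')}
    return logits, end_points

my_net.default_image_size = 224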

  • clone_on_cpu: in theory TensorFlow picks the GPU by itself when one is available; if you are training on GPU, remove this flag, otherwise training will still run on the CPU

----Fine-tuning

If you are fine-tuning, you need to configure the following two parameters:

  • checkpoint_path: the path of a model checkpoint pre-trained on another dataset; training restores the corresponding weights from it. For checkpoint files like the ones in the folder below, specify the common file prefix, e.g. $PATH/mobilenet_v1_1.0_224.ckpt
.
├── mobilenet_v1_1.0_224.ckpt.data-00000-of-00001
├── mobilenet_v1_1.0_224.ckpt.index
├── mobilenet_v1_1.0_224.ckpt.meta
├── mobilenet_v1_1.0_224.tflite
├── mobilenet_v1_1.0_224_eval.pbtxt
├── mobilenet_v1_1.0_224_frozen.pb
└── mobilenet_v1_1.0_224_info.txt
  • checkpoint_exclude_scopes: the last layer of the model we actually train may have been changed (for example, our output has two classes) and may no longer match the pre-trained model's structure; restoring all weights would then fail. This parameter names the scopes that should not be restored; a sketch of how it takes effect follows.
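
For intuition, this is roughly how train_image_classifier.py filters the restore list internally (a simplified sketch, not the exact slim source; the flag values below are taken from the script above, and checkpoint_path is a placeholder):

import tensorflow as tf

slim = tf.contrib.slim

# Variables whose names start with an excluded scope are skipped on restore.
checkpoint_exclude_scopes = 'MobilenetV1/Logits/Conv2d_1c_1x1'
checkpoint_path = '/path/to/mobilenet_v1_1.0_224.ckpt'  # placeholder prefix
exclusions = [s.strip() for s in checkpoint_exclude_scopes.split(',')]
variables_to_restore = [
    var for var in slim.get_model_variables()  # requires the model graph to be built first
    if not any(var.op.name.startswith(ex) for ex in exclusions)
]
init_fn = slim.assign_from_checkpoint_fn(checkpoint_path, variables_to_restore)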

----Evaluation

slim also provides eval_image_classifier.py to measure the trained model's accuracy on the validation set, which tells you whether training has converged and whether overfitting has occurred.

You can write a script and register it as a crontab job to periodically evaluate the accuracy of the latest checkpoint; the script is shown below.

# Where the checkpoint and logs will be saved to.
TRAIN_DIR=/Users/fanchun/Documents/机器学习/models/research/slim/tmp

# Where the dataset is saved to.
DATASET_DIR=/Users/fanchun/Documents/MachinelearningDATA/tensordata/flur/


cd /Users/fanchun/Documents/机器学习/models/research/slim
# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=general \
  --dataset_split_name=test \
  --dataset_dir=${DATASET_DIR} \
  --model_name=mobilenet_v1

Run crontab -e to add the scheduled job, save its output to a log file, and grep the log to watch how accuracy evolves:

ETHANFAN-MC1:tmp fanchun$ grep 'Accuracy' accuracy.log
2018-05-07 21:20:10.877568: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.5848]
2018-05-07 22:20:07.974321: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9316]
2018-05-07 23:20:04.408895: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9561]
2018-05-08 00:20:12.106144: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9676]
2018-05-08 01:20:15.131215: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9761]
2018-05-08 02:20:19.184057: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9804]
2018-05-08 03:20:15.421232: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9862]
2018-05-08 04:20:07.181298: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9886]
2018-05-08 05:20:06.910575: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.992]
2018-05-08 06:20:10.145253: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9949]
2018-05-08 07:20:09.542855: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9948]
2018-05-08 08:20:10.381687: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9956]
2018-05-08 09:20:09.922440: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9962]
2018-05-08 10:20:20.420422: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9968]
2018-05-08 11:20:17.745162: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.997]
2018-05-08 12:20:18.251829: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9971]
2018-05-08 13:20:26.789431: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9971]
2018-05-08 14:20:39.059187: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9972]
2018-05-08 15:21:00.792348: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9973]
2018-05-08 16:20:38.790176: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9973]
2018-05-08 22:20:47.657169: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9973]
2018-05-08 23:20:22.674791: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9974]
2018-05-09 00:20:20.719059: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9975]
2018-05-09 01:20:13.749383: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9976]
2018-05-09 02:20:13.777951: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9976]
2018-05-09 03:20:16.734230: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9977]
2018-05-09 04:20:20.524365: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9977]
2018-05-09 05:20:19.858391: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.998]
2018-05-09 06:20:17.753914: I tensorflow/core/kernels/logging_ops.cc:79] eval/Accuracy[0.9981]

Step 4: Using the model

When saving models with TensorFlow we usually use ckpt-format files, written with a statement like:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)

and all variables are restored with:

saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))

This approach has a few drawbacks. First, such model files depend on TensorFlow and can only be used within that framework. Second, before restoring you must define the network structure all over again, and only then can the variable values be restored into it, as in the sketch below.
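
To make the second drawback concrete, here is a minimal restore sketch assuming a two-class MobileNet like the one trained above: the whole graph has to be rebuilt before the ckpt can be loaded.

import tensorflow as tf
from nets import mobilenet_v1  # from the slim repo

# The network must be defined again before restoring the checkpoint into it.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits, _ = mobilenet_v1.mobilenet_v1(images, num_classes=2, is_training=False)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))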

Google's recommended way to save a model is as a PB file. It is language-independent and self-contained: a closed serialization format that any language can parse, which allows other languages and deep-learning frameworks to read, continue training, and migrate TensorFlow models.
Its main use case is decoupling model creation from model use, so that the forward-inference code stays uniform.
A further benefit is that when saving to a PB file the model's variables are frozen into constants, which shrinks the model considerably and makes it suitable for running on mobile devices.

slim already provides export_inference_graph.py to export the inference graph (the architecture only) as a pb file; we then call TensorFlow's freeze_graph.py to freeze that graph together with the trained ckpt file. The script is shown below; you can also refer to the reference article.

# Where the checkpoint is saved to
CHECKPOINT=$1

if [[ ${CHECKPOINT} = "" ]]; then
    echo "the CHECKPOINT path is missing! [${CHECKPOINT}]"
    exit 1
fi


# Where the dataset is saved to.
DATASET_DIR=/Users/fanchun/Documents/MachinelearningDATA/tensordata/definition_data/

# Where the pb file is saved to.
PB_DIR=/tmp/frozen_mobilenet_v1.pb


cd /Users/fanchun/Documents/机器学习/models/research/slim

# Export the inference graph (architecture only, no weights).
python export_inference_graph.py \
  --alsologtostderr \
  --model_name=mobilenet_v1 \
  --image_size=224 \
  --dataset_name=general \
  --dataset_dir=${DATASET_DIR}   \
  --output_file=/tmp/mobilenet_v1_224.pb

python /Users/fanchun/Documents/MachinelearningDATA/flur/mobilenet/tensorflow/tensorflow/python/tools/freeze_graph.py \
  --input_graph=/tmp/mobilenet_v1_224.pb \
  --input_checkpoint=${CHECKPOINT} \
  --input_binary=true \
  --output_graph=${PB_DIR} \
  --output_node_names=MobilenetV1/Predictions/Reshape_1

echo "pb file saved in ${PB_DIR}"

With the frozen pb model file in hand, using the model becomes much simpler, as shown below:

import tensorflow as tf

model_path = '/tmp/frozen_mobilenet_v1.pb'  # the PB_DIR output of the script above

# Parse the frozen GraphDef (note: open in binary mode).
with tf.gfile.GFile(model_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# import_graph_def must run inside the target graph; it does not return one.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

sess = tf.Session(graph=graph)
input_tensor = sess.graph.get_tensor_by_name('input:0')  # model input
output_tensor = sess.graph.get_tensor_by_name('MobilenetV1/Predictions/Reshape_1:0')
# input_image: a preprocessed image batch, see below.
probs = sess.run(output_tensor, feed_dict={input_tensor: input_image})
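
One caveat: input_image above must be preprocessed the same way as during training. For mobilenet_v1 preprocessing that means resizing to 224x224 and scaling pixels to [-1, 1]; a rough equivalent with PIL and numpy is sketched below (not slim's exact preprocessing pipeline, which also central-crops).

import numpy as np
from PIL import Image

img = Image.open('test.jpg').convert('RGB').resize((224, 224))
input_image = np.asarray(img, dtype=np.float32)
input_image = input_image / 127.5 - 1.0       # scale pixel values to [-1, 1]
input_image = np.expand_dims(input_image, 0)  # add the batch dimension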

And with that, you have trained an image-classification model with TensorFlow.
