Caffe | Three Strategies for Multi-Label Training


Author: yuanCruise | Published 2019-01-13 22:13

    0. The Multi-Label Problem

    Many deep learning tasks involve multi-label learning. Take object detection: in the listing below, the first image contains an object of class 1, and the remaining four numbers are its bounding-box coordinates; the second image contains an object of class 2, followed by its four coordinates. So even an image with only a single object is still a multi-label learning problem.

    0000001.jpg 1 72 79 232 273
    0000002.jpg 2 67 59 155 161
    

    Another multi-label problem uses facial attributes to assist facial landmark localization (five landmarks in total: two eyes, the nose, and two mouth corners). In the label format below, the first field is the image name; fields 2 and 3 give the position of the first landmark, and by analogy the following 8 fields give the remaining 4 landmarks (each coordinate expressed as a fraction of the distance from the top-left corner). The fields after that are binary attributes such as whether the person wears a hat; each column is one attribute, 1 if present and 0 if not. This, too, is a multi-label learning problem.

    image_name    (x y offsets from the top-left corner, as fractions)
    000002.jpg 0.264367816092 0.724137931034 0.425287356322 0.298850574713 0.758620689655 0.241379310345 0.241379310345 0.528735632184 0.729885057471 0.72988505747
    1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1
    000003.jpg 0.391891891892 0.662162162162 0.648648648649 0.27027027027 0.486486486486 0.256756756757 0.27027027027 0.486486486486 0.662162162162 0.743243243243 
    0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 1 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1
    
    

    1. Multi-Label Training with HDF5 Files

    (1) Generating an HDF5 file with Python
    When using the LMDB format, Caffe by default supports only data of the form "image + single integer label". If the network needs other kinds of data or labels, such as those described in Part 0 (floating-point values, multiple labels, and so on), you can package them in HDF5 format. There is a trade-off, of course: the HDF5 format is flexible, but it occupies considerably more disk space.

    The code below builds an HDF5 file in Python. Note that it is only a demonstration of how to create an h5py file; it does not follow the exact multi-label format described in Part 0.

    import os
    import numpy as np
    from PIL import Image
    import h5py

    # one-hot encodings for the five classes in this demo
    LABELS = {"0": (1, 0, 0, 0, 0), "1": (0, 1, 0, 0, 0), "2": (0, 0, 1, 0, 0),
              "3": (0, 0, 0, 1, 0), "4": (0, 0, 0, 0, 1)}

    filename = "/home/YL/lenet/traintestTxt/train11_abspath.txt"
    LABEL_SIZE = 5  # multi-label dimension
    MEAN_VALUE = 128

    setname = 'train'

    with open(filename, 'r') as f:
        lines = f.readlines()

    np.random.shuffle(lines)

    sample_size = len(lines)

    datas = np.zeros((sample_size, 3, 56, 56))
    labels = np.zeros((sample_size, LABEL_SIZE))

    h5_filename = '{}.h5'.format(setname)

    for i, line in enumerate(lines):
        data_line = line.split(" ")[0]
        labels_line = line.split(" ")[1]
        temp_img = np.array(Image.open(data_line)).astype(np.float32)
        temp_img = temp_img[:, :, ::-1]           # turn RGB into BGR
        temp_img = temp_img.transpose((2, 0, 1))  # turn (56,56,3) into (3,56,56)
        temp_img = temp_img.copy()
        temp_img.resize(3, 56, 56)

        # subtract the mean and scale to roughly [-0.5, 0.5]
        datas[i, :, :, :] = (temp_img - MEAN_VALUE) / 256

        labels[i, :] = np.array(LABELS[labels_line.strip("\n")]).astype(int)
        print('processed {} images!'.format(i))

    # the 'with' blocks close both files; no explicit close() calls are needed
    with h5py.File(h5_filename, 'w') as h:
        h['data'] = datas
        h['label'] = labels

    # Caffe's HDF5Data layer reads HDF5 paths from a text-file list
    with open('{}.txt'.format(setname), 'w') as f:
        f.write(os.path.abspath(h5_filename) + '\n')
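
    As a quick sanity check (not part of the original script), the file can be read back with h5py to confirm that the dataset names and shapes match what the HDF5Data layer will expect:

    import h5py

    # hypothetical check: 'train.h5' is the file written by the script above
    with h5py.File('train.h5', 'r') as h:
        print(list(h.keys()))    # expect ['data', 'label']
        print(h['data'].shape)   # expect (N, 3, 56, 56)
        print(h['label'].shape)  # expect (N, 5)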
    
    

    (2) Training with HDF5
    As mentioned above, HDF5 files tend to be large, and Caffe limits the size of any single HDF5 file it loads. When there is a lot of training data, the data usually has to be split across several HDF5 files, with the paths of all of them listed in one train.txt. The data layer is defined as follows:

    layer {
      name: "data"
      type: "HDF5Data"
      top: "data"
      top: "label"
      include {
        phase: TRAIN
      }
      hdf5_data_param {
        source: "train.txt"
        batch_size: 128
        shuffle: true
      }
    }
    

    Note also that shuffle shuffles only the order of the H5 files; the sample order inside each file is left untouched. Because there may be multiple HDF5 files, the HDF5Data layer reads its input list from a text file; a sample train.txt looks like this:

    train1.h5
    train2.h5
    ...
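
    Since large datasets must be split as described above, here is a minimal sketch (not from the original article) of how the single-file script from step (1) could be chunked into several HDF5 files plus the train.txt list; it assumes the same datas and labels arrays built there:

    import os
    import h5py

    CHUNK = 10000  # samples per HDF5 file; an arbitrary choice for this sketch

    def write_h5_chunks(datas, labels, prefix='train'):
        # split (datas, labels) into CHUNK-sized files listed in <prefix>.txt
        with open('{}.txt'.format(prefix), 'w') as list_f:
            for n, start in enumerate(range(0, len(datas), CHUNK)):
                h5_name = '{}{}.h5'.format(prefix, n + 1)
                with h5py.File(h5_name, 'w') as h:
                    h['data'] = datas[start:start + CHUNK]
                    h['label'] = labels[start:start + CHUNK]
                list_f.write(os.path.abspath(h5_name) + '\n')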
    

    2. Multi-Label Training with LMDB Combined with HDF5

    As noted above, the drawback of HDF5 is its size, and Caffe limits the size of a single HDF5 file. The txt-file listing above works around the size limit, but doing multi-label training entirely in HDF5 still wastes space. A second strategy is therefore: store the image files in LMDB format, which is fast and compact, store only the label files in HDF5 format, and use a Data layer and an HDF5Data layer together in the network prototxt.
    (1) Image LMDB
    Converting image files to LMDB is not explained again here; see the following two posts for details:
    Caffe | 你的第一个分类网络之Caffe训练
    Caffe | 你的第一个分类网络之数据准备

    (2) Label file in HDF5

    import h5py
    import numpy as np

    label_dim = 45

    # text file holding the image names and label values
    list_txt = 'name_label.txt'
    # name of the HDF5 file to generate
    hdf5_file_name = 'hdf5_train.h5'

    with open(list_txt, 'r') as f:
        lines = f.readlines()
        samples_num = len(lines)

        # the dtype can be specified here, e.g. dtype=float
        labels = np.zeros((len(lines), label_dim))

        for index in range(samples_num):
            img_name, label = lines[index].strip().split()
            # label is a contiguous string of 0/1 characters, one per attribute
            label_int = [int(l) for l in label]
            labels[index, :] = label_int

        # write the label array into the HDF5 file
        h5_file = h5py.File(hdf5_file_name, 'w')
        # the dataset name 'multi_label' must match the top name of the
        # HDF5Data layer in the network definition
        h5_file.create_dataset('multi_label', data=labels)
        h5_file.close()
    
    
    

    Here name_label.txt holds the image names and their labels; only the labels are read. hdf5_train.h5 then stores the label for every image, each label being a vector of 0/1 values. The train.txt passed to the HDF5Data layer must list the path of hdf5_train.h5 (see the sample further below). Most importantly, the file names in the name.txt used to build the image LMDB must correspond one-to-one, in the same order, with the names in name_label.txt.
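
    Because an ordering mismatch would silently pair images with the wrong labels, a small check is worth running. The sketch below is an illustration under assumptions: it supposes name.txt was the listing used for the LMDB and that both files carry the image name in the first column of each line:

    # hypothetical check: both files list one sample per line, image name first
    with open('name.txt') as f1, open('name_label.txt') as f2:
        names_lmdb = [line.split()[0] for line in f1]
        names_h5 = [line.split()[0] for line in f2]

    assert names_lmdb == names_h5, "LMDB and HDF5 label orders differ!"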

    (3) Training with LMDB and HDF5

    name: "MY_NET"
    
    layer {
      name: "data"
      type: "Data"
      top: "data"
      top: "label"
      data_param {
        source: "train_lmdb"
        backend: LMDB
        batch_size: 1 
      }
      transform_param {
        mean_value: 104.0
        mean_value: 117.0
        mean_value: 123.0
      }
    }
    
    layer {
      name: "multi_label"
      type: "HDF5Data"
      # the top name must match the dataset name inside the HDF5 file
      top: "multi_label"
      hdf5_data_param {
        source: "train.txt"
        batch_size: 1
      }
    }
    

    A sample train.txt:

    train1.h5
    train2.h5
    ...
    

    3. Multi-Label Training with LMDB Alone

    Both the first strategy (HDF5, with the labels later split by a Slice layer) and the second (LMDB combined with HDF5) have their limitations. There is therefore a third approach, more convenient to train with but somewhat more laborious to set up: modify the Caffe source itself so that LMDB accepts multi-label input directly.

    (1) Modifying convert_imageset


    // This program converts a set of images to a lmdb/leveldb by storing them
    // as Datum proto buffers, extended here to also carry 50 float labels.
    // Usage:
    //   convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
    //
    // where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
    // should be a list of files with their integer label and 50 float values,
    // in the format
    //   subfolder1/file1.JPEG 7 f1 f2 ... f50
    //   ....
    
    #include <algorithm>
    #include <fstream>  // NOLINT(readability/streams)
    #include <string>
    #include <utility>
    #include <vector>
    
    #include "boost/scoped_ptr.hpp"
    #include "gflags/gflags.h"
    #include "glog/logging.h"
    
    #include "caffe/proto/caffe.pb.h"
    #include "caffe/util/db.hpp"
    #include "caffe/util/io.hpp"
    #include "caffe/util/rng.hpp"
    
    using namespace caffe;  // NOLINT(build/namespaces)
    using std::pair;
    using boost::scoped_ptr;
    
    DEFINE_bool(gray, false,
        "When this option is on, treat images as grayscale ones");
    DEFINE_bool(shuffle, false,
        "Randomly shuffle the order of images and their labels");
    DEFINE_string(backend, "lmdb",
            "The backend {lmdb, leveldb} for storing the result");
    DEFINE_int32(resize_width, 0, "Width images are resized to");
    DEFINE_int32(resize_height, 0, "Height images are resized to");
    DEFINE_bool(check_size, false,
        "When this option is on, check that all the datum have the same size");
    DEFINE_bool(encoded, false,
        "When this option is on, the encoded image will be save in datum");
    DEFINE_string(encode_type, "",
        "Optional: What type should we encode the image as ('png','jpg',...).");
    
    int main(int argc, char** argv) {
      ::google::InitGoogleLogging(argv[0]);
    
    #ifndef GFLAGS_GFLAGS_H_
      namespace gflags = google;
    #endif
    
      gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
            "format used as input for Caffe.\n"
            "Usage:\n"
            "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
            "The ImageNet dataset for the training demo is at\n"
            "    http://www.image-net.org/download-images\n");
      gflags::ParseCommandLineFlags(&argc, &argv, true);
    
      if (argc < 4) {
        gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
        return 1;
      }
    
      const bool is_color = !FLAGS_gray;
      const bool check_size = FLAGS_check_size;
      const bool encoded = FLAGS_encoded;
      const string encode_type = FLAGS_encode_type;
      
      std::cout << "starting........" << std::endl;
      std::ifstream infile(argv[2]);
      std::cout << argv[2] << std::endl;
      // each record: image name, integer label, then 50 extra float labels
      std::vector<std::pair<std::pair<std::string, int>, std::vector<float> > > lines;
      std::string filename;
      int label;
      std::vector<float> point;
      float ppoint[50] = {0};

      while (infile >> filename >> label) {
        bool ok = true;
        for (int ii = 0; ii < 50; ii++) {
          if (!(infile >> ppoint[ii])) {
            ok = false;
            break;
          }
        }
        if (!ok) break;
        for (int ii = 0; ii < 50; ii++) {
          point.push_back(ppoint[ii]);
        }
        lines.push_back(std::make_pair(std::make_pair(filename, label), point));
        point.clear();
      }
      if (FLAGS_shuffle) {
        // randomly shuffle data
        LOG(INFO) << "Shuffling data";
        shuffle(lines.begin(), lines.end());
      }
      LOG(INFO) << "A total of " << lines.size() << " images.";
    
      if (encode_type.size() && !encoded)
        LOG(INFO) << "encode_type specified, assuming encoded=true.";
    
      int resize_height = std::max<int>(0, FLAGS_resize_height);
      int resize_width = std::max<int>(0, FLAGS_resize_width);
    
      // Create new DB
      scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
      db->Open(argv[3], db::NEW);
      scoped_ptr<db::Transaction> txn(db->NewTransaction());
    
      // Storing to db
      std::string root_folder(argv[1]);
      Datum datum;
      int count = 0;
      const int kMaxKeyLength = 256;
      char key_cstr[kMaxKeyLength];
      int data_size = 0;
      bool data_size_initialized = false;
    
      for (int line_id = 0; line_id < lines.size(); ++line_id) {
        bool status;
        std::string enc = encode_type;
        if (encoded && !enc.size()) {
          // Guess the encoding type from the file name
          string fn = lines[line_id].first.first;
          size_t p = fn.rfind('.');
          if ( p == fn.npos )
            LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
          enc = fn.substr(p);
          std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
        }
    
        datum.clear_data();
        // 'sim' is the repeated float field added to Datum in caffe.proto (step 2)
        datum.clear_sim();
        status = ReadImageToDatum(root_folder + lines[line_id].first.first,
            lines[line_id].first.second, resize_height, resize_width, is_color,
            enc, &datum);
        if (status == false) continue;

        // append the 50 extra float labels to the datum
        const std::vector<float>& pp = lines[line_id].second;
        for (int ii = 0; ii < 50; ii++) {
          datum.add_sim(pp[ii]);
        }
    
        if (check_size) {
          if (!data_size_initialized) {
            data_size = datum.channels() * datum.height() * datum.width();
            data_size_initialized = true;
          } else {
            const std::string& data = datum.data();
            CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
                << data.size();
          }
        }
        // sequential
        int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
            lines[line_id].first.first.c_str());
    
        // Put in db
        string out;
        CHECK(datum.SerializeToString(&out));
        txn->Put(string(key_cstr, length), out);
    
        if (++count % 1000 == 0) {
          // Commit db
          txn->Commit();
          txn.reset(db->NewTransaction());
          LOG(ERROR) << "Processed " << count << " files.";
        }
      }
    
      // write the last batch
      if (count % 1000 != 0) {
        txn->Commit();
        LOG(ERROR) << "Processed " << count << " files.";
      }
      return 0;
    }
    
    

    (2) Modifying caffe.proto
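
    The original post leaves this step empty. Since the modified convert_imageset calls datum.clear_sim(), datum.add_sim(), and datum.sim_size(), the Datum message in src/caffe/proto/caffe.proto must gain a repeated float field named sim. The fragment below is a minimal sketch; the field number 8 is an assumption and simply has to be one that your Datum message does not already use:

    // src/caffe/proto/caffe.proto -- sketch of the required change
    message Datum {
      optional int32 channels = 1;
      optional int32 height = 2;
      optional int32 width = 3;
      optional bytes data = 4;
      optional int32 label = 5;
      repeated float float_data = 6;
      optional bool encoded = 7 [default = false];
      // new: extra float labels written by the modified convert_imageset
      repeated float sim = 8;  // field number 8 is an assumption
    }

    After editing the proto, rebuild Caffe so that caffe.pb.h and caffe.pb.cc are regenerated. For training, the Data layer must also be taught to expose the sim values as an extra top blob, a change the original post does not show.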


    (3) Generating the LMDB with an sh script
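
    This step is also left empty in the original post. Below is a minimal sketch modeled on Caffe's standard create_imagenet.sh; every path in it is a placeholder to replace with your own, and the flags are the ones defined in the modified convert_imageset above:

    #!/usr/bin/env sh
    # sketch only: all paths are placeholders
    set -e

    TOOLS=./build/tools                   # where the rebuilt convert_imageset lives
    DATA_ROOT=/path/to/images/            # must end with '/'
    LIST=/path/to/train_multilabel.txt    # one 'name label f1 ... f50' per line
    DB=/path/to/train_lmdb

    $TOOLS/convert_imageset \
        --resize_height=256 \
        --resize_width=256 \
        --shuffle \
        --backend=lmdb \
        $DATA_ROOT \
        $LIST \
        $DB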


    (4) Training with the LMDB
    Note that once the sim values are read out of the LMDB, a Slice layer is needed to split the multi-label blob into its parts, as sketched below.
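
    The original post stops here, so the prototxt fragment below is only a sketch. It assumes the Data layer has been modified to expose the sim values as a second top blob (named "sim" here, which is an assumption), and shows how a Slice layer could then split the 50 values into, say, 10 landmark coordinates and 40 attribute bits:

    layer {
      name: "data"
      type: "Data"
      top: "data"
      top: "label"
      top: "sim"   # assumed extra top from the modified Data layer
      data_param {
        source: "train_lmdb"
        backend: LMDB
        batch_size: 64
      }
    }

    layer {
      name: "slice_sim"
      type: "Slice"
      bottom: "sim"
      top: "points"       # first 10 values: five (x, y) landmark pairs
      top: "attributes"   # remaining 40 values: binary attributes
      slice_param {
        axis: 1
        slice_point: 10
      }
    }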

