0. The multi-label problem
Multi-label learning comes up in many deep learning tasks, for example object detection. In the two annotation lines below, the first image contains an object of class 1 and the remaining four numbers are its position coordinates; the second image contains an object of class 2, again followed by four position coordinates. So even when an image contains only a single object, this is still a multi-label learning problem.
0000001.jpg 1 72 79 232 273
0000002.jpg 2 67 59 155 161
Another kind of multi-label problem uses certain facial attributes to assist facial landmark localization (five landmarks in total: the two eyes, the nose, and the two mouth corners). The labels below are read as follows: the first field is the image name; fields 2 and 3 give the position of the first landmark, and likewise the next 8 fields give the positions of the remaining 4 landmarks (as fractions of the image size, measured from the top-left corner). The subsequent fields are attributes such as whether the person wears a hat: each column represents one attribute, 1 if present and 0 if not. This too is a multi-label learning problem.
image name (x offset from top-left y offset from top-left) × 5 landmarks, followed by the attribute columns
000002.jpg 0.264367816092 0.724137931034 0.425287356322 0.298850574713 0.758620689655 0.241379310345 0.241379310345 0.528735632184 0.729885057471 0.72988505747
1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1
000003.jpg 0.391891891892 0.662162162162 0.648648648649 0.27027027027 0.486486486486 0.256756756757 0.27027027027 0.486486486486 0.662162162162 0.743243243243
0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 1 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 1
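To make the record layout concrete, here is a minimal parsing sketch. It assumes each record is one whitespace-separated line (the samples above are wrapped for display), with the 10 landmark coordinates followed by the binary attributes; the field counts come from the description above, and the record string is truncated for brevity.
record = ("000002.jpg 0.264367816092 0.724137931034 0.425287356322 "
          "0.298850574713 0.758620689655 0.241379310345 0.241379310345 "
          "0.528735632184 0.729885057471 0.72988505747 "
          "1 0 0 0 1 0 0 0 1 0")  # attribute list truncated for brevity
fields = record.split()
img_name = fields[0]                          # image file name
landmarks = [float(v) for v in fields[1:11]]  # 5 landmarks x (x, y), as fractions
attributes = [int(v) for v in fields[11:]]    # one 0/1 flag per attribute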
1. Multi-label data and training with HDF5 files
(1) Generating an HDF5 file with Python
In Caffe, the LMDB data format by default supports only data of the form "image + single integer label". If training requires other kinds of data or labels, such as those introduced in part 0 (floating-point data, multiple labels, and so on), the data can be written in HDF5 format instead. There is a trade-off, of course: HDF5 is flexible, but its drawback is that it takes up considerably more space.
The code below shows how to create an HDF5 file in Python with h5py. Note that it is only a demonstration of the mechanics; it does not follow the exact multi-label formats described in part 0.
import os
import numpy as np
from PIL import Image
import h5py

# One-hot encodings for the five classes
LABELS = {"0": (1, 0, 0, 0, 0), "1": (0, 1, 0, 0, 0), "2": (0, 0, 1, 0, 0),
          "3": (0, 0, 0, 1, 0), "4": (0, 0, 0, 0, 1)}
# Alternative file names for train/test splits (not used below)
IMAGE_DIR = ['image_train', 'image_test']
HDF5_FILE = ['hdf5_train.h5', 'hdf5_test.h5']
LIST_FILE = ['list_train.txt', 'list_test.txt']

filename = "/home/YL/lenet/traintestTxt/train11_abspath.txt"
LABEL_SIZE = 5  # multi-label width
MEAN_VALUE = 128
setname = 'train'

# Each line of the list file: "<absolute image path> <class id>"
with open(filename, 'r') as f:
    lines = f.readlines()
np.random.shuffle(lines)

sample_size = len(lines)
datas = np.zeros((sample_size, 3, 56, 56))
labels = np.zeros((sample_size, LABEL_SIZE))
h5_filename = '{}.h5'.format(setname)

for i, line in enumerate(lines):
    data_line = line.split(" ")[0]    # image path
    labels_line = line.split(" ")[1]  # class id
    temp_img = np.array(Image.open(data_line)).astype(np.float32)
    temp_img = temp_img[:, :, ::-1]           # turn RGB into BGR (Caffe convention)
    temp_img = temp_img.transpose((2, 0, 1))  # turn (56, 56, 3) into (3, 56, 56)
    temp_img = temp_img.copy()
    temp_img.resize(3, 56, 56)
    datas[i, :, :, :] = (temp_img - MEAN_VALUE) / 256.0  # subtract mean, scale
    labels[i, :] = np.array(LABELS[labels_line.strip("\n")]).astype(int)
    print('processed {} images!'.format(i))

# Write images and labels into the HDF5 file
with h5py.File(h5_filename, 'w') as h:
    h['data'] = datas
    h['label'] = labels

# Write the absolute path of the .h5 file into the list file Caffe reads
with open('{}.txt'.format(setname), 'w') as f:
    f.write(os.path.abspath(h5_filename) + '\n')
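After the script runs, the result can be checked quickly with a short h5py read-back; the dataset names and shapes below are the ones the script writes.
import h5py

# Open the generated file and confirm dataset names and shapes
with h5py.File('train.h5', 'r') as h:
    print(h['data'].shape)   # expected: (sample_size, 3, 56, 56)
    print(h['label'].shape)  # expected: (sample_size, 5)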
(2) Training with HDF5
As mentioned above, HDF5 files tend to be large, and Caffe limits the size of any single HDF5 file it loads. So when there is a lot of training data, it usually has to be written into several HDF5 files, whose paths are all listed in the same train.txt. Concretely, the data layer looks like this:
layer {
  name: "data"
  type: "HDF5Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  hdf5_data_param {
    source: "train.txt"
    batch_size: 128
    shuffle: true
  }
}
Note also how shuffle behaves here: Caffe randomizes the order of the listed H5 files and also permutes the rows within each file as it loads it, but it never mixes rows from different files in a single pass, since only one file is held in memory at a time. Because there may be several HDF5 files, the HDF5Data layer takes its input as a list read from a TXT file; an example train.txt:
train1.h5
train2.h5
...
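To produce a train.txt like the one above, the arrays built by the script in (1) can be written out in chunks, with every chunk's path appended to the list file. A minimal sketch follows; chunk_size is an arbitrary choice, and datas, labels, and sample_size are the arrays and count from the script above.
import os
import h5py

chunk_size = 10000  # samples per .h5 file, kept small to respect the size limit
with open('train.txt', 'w') as list_f:
    for k, start in enumerate(range(0, sample_size, chunk_size)):
        name = 'train{}.h5'.format(k + 1)
        with h5py.File(name, 'w') as h:
            h['data'] = datas[start:start + chunk_size]
            h['label'] = labels[start:start + chunk_size]
        list_f.write(os.path.abspath(name) + '\n')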
2. Multi-label data and training with LMDB combined with HDF5
As noted above, the drawback of HDF5 files is their disk footprint, and Caffe limits the size of a single HDF5 file. The txt list above works around the size limit, but storing multi-label image data in HDF5 still wastes a lot of space. Hence another strategy: store the image files in LMDB format, which is fast and space-efficient, and store only the label files in HDF5 format. The network definition prototxt then uses a Data layer and an HDF5Data layer at the same time.
(1) Image files in LMDB
Converting image files into an LMDB needs no further explanation here; for details see these two posts:
Caffe | Your First Classification Network: Training with Caffe
Caffe | Your First Classification Network: Data Preparation
(2) Label files in HDF5
import h5py
import numpy as np
import os

label_dim = 45
# Text file holding the image names and label values
list_txt = 'name_label.txt'
# Name of the HDF5 file to generate
hdf5_file_name = 'hdf5_train.h5'

with open(list_txt, 'r') as f:
    lines = f.readlines()
    samples_num = len(lines)
    # The dtype could be specified here, e.g. dtype=float
    labels = np.zeros((len(lines), label_dim))
    for index in range(samples_num):
        img_name, label = lines[index].strip().split()
        # The label field is a string of 45 concatenated 0/1 digits
        label_int = [int(l) for l in label]
        labels[index, :] = label_int

# Write the label data into the HDF5 file
h5_file = h5py.File(hdf5_file_name, 'w')
# 'multi_label' must match the top name of the HDF5Data layer in the net definition
h5_file.create_dataset('multi_label', data=labels)
h5_file.close()
Here, name_label.txt holds the image names and their labels, though this script only reads the labels. Afterwards, hdf5_train.h5 stores the label vector for every image, each consisting of several 0/1 values. Note that the file names in the name.txt used to build the image LMDB must correspond one-to-one, in the same order, with the entries in name_label.txt.
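Because the script iterates over the characters of the second field, each line of name_label.txt is assumed to contain the image name followed by the 45 attribute digits concatenated without separators; file names and values below are hypothetical:
000001.jpg 101000000100000000001001100100000100000100001
000002.jpg 010000100010000000001001000010000001000010001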
(3) Training with LMDB and HDF5
name: "MY_NET"
layer {
name: "data"
type: "Data"
top: "data"
top: "label"
data_param {
source: "train_lmdb"
backend: LMDB
batch_size: 1
}
transform_param {
mean_value: 104.0
mean_value: 117.0
mean_value: 123.0
}
}
layer {
name: "multi_label"
type: "HDF5Data"
top: "label"
hdf5_data_param {
source: "train.txt"
batch_size: 1
}
}
An example of the contents of train.txt:
train1.h5
train2.h5
...
3. Multi-label data and training with LMDB files
In fact, both strategies so far have limitations: the first one (HDF5, with a Slice layer to split the labels) and the second one (LMDB combined with HDF5). There is a third method, more convenient in use though somewhat more troublesome to set up: modify the Caffe source code directly so that it accepts multi-label input.
(1) Modifying convert_imageset
// This program converts a set of images to a lmdb/leveldb by storing them
// as Datum proto buffers.
// Usage:
//   convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
//
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// should be a list of files as well as their labels, in the format as
//   subfolder1/file1.JPEG 7
//   ....

#include <algorithm>
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>

#include "boost/scoped_ptr.hpp"
#include "gflags/gflags.h"
#include "glog/logging.h"

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/rng.hpp"

using namespace caffe;  // NOLINT(build/namespaces)
using std::pair;
using boost::scoped_ptr;

DEFINE_bool(gray, false,
    "When this option is on, treat images as grayscale ones");
DEFINE_bool(shuffle, false,
    "Randomly shuffle the order of images and their labels");
DEFINE_string(backend, "lmdb",
    "The backend {lmdb, leveldb} for storing the result");
DEFINE_int32(resize_width, 0, "Width images are resized to");
DEFINE_int32(resize_height, 0, "Height images are resized to");
DEFINE_bool(check_size, false,
    "When this option is on, check that all the datum have the same size");
DEFINE_bool(encoded, false,
    "When this option is on, the encoded image will be save in datum");
DEFINE_string(encode_type, "",
    "Optional: What type should we encode the image as ('png','jpg',...).");

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);

#ifndef GFLAGS_GFLAGS_H_
  namespace gflags = google;
#endif

  gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"
        "format used as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (argc < 4) {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");
    return 1;
  }

  const bool is_color = !FLAGS_gray;
  const bool check_size = FLAGS_check_size;
  const bool encoded = FLAGS_encoded;
  const string encode_type = FLAGS_encode_type;

  std::ifstream infile(argv[2]);
  // Each entry holds ((filename, class label), 50 extra label values).
  std::vector<std::pair<std::pair<std::string, int>, std::vector<float> > > lines;
  std::string filename;
  int label;
  float ppoint[50] = {0};
  while (infile >> filename >> label) {
    // Read the 50 extra label values that follow the class label.
    for (int ii = 0; ii < 50; ii++) {
      infile >> ppoint[ii];
    }
    if (!infile) break;  // stop on a short or malformed record
    std::vector<float> point(ppoint, ppoint + 50);
    lines.push_back(std::make_pair(std::make_pair(filename, label), point));
  }
  if (FLAGS_shuffle) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  if (encode_type.size() && !encoded)
    LOG(INFO) << "encode_type specified, assuming encoded=true.";

  int resize_height = std::max<int>(0, FLAGS_resize_height);
  int resize_width = std::max<int>(0, FLAGS_resize_width);

  // Create new DB
  scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));
  db->Open(argv[3], db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db
  std::string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  int data_size = 0;
  bool data_size_initialized = false;

  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    bool status;
    std::string enc = encode_type;
    if (encoded && !enc.size()) {
      // Guess the encoding type from the file name
      string fn = lines[line_id].first.first;
      size_t p = fn.rfind('.');
      if ( p == fn.npos )
        LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";
      enc = fn.substr(p);
      std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);
    }
    // clear_sim()/add_sim() are generated from the "sim" field added to
    // the Datum message in caffe.proto; see step (2) below.
    datum.clear_data();
    datum.clear_sim();
    status = ReadImageToDatum(root_folder + lines[line_id].first.first,
        lines[line_id].first.second, resize_height, resize_width, is_color,
        enc, &datum);
    if (status == false) continue;
    // Append the 50 extra label values to the datum's sim field.
    const std::vector<float>& pp = lines[line_id].second;
    for (int ii = 0; ii < 50; ii++) {
      datum.add_sim(pp[ii]);
    }
    if (check_size) {
      if (!data_size_initialized) {
        data_size = datum.channels() * datum.height() * datum.width();
        data_size_initialized = true;
      } else {
        const std::string& data = datum.data();
        CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
            << data.size();
      }
    }
    // sequential
    int length = snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.first.c_str());

    // Put in db
    string out;
    CHECK(datum.SerializeToString(&out));
    txn->Put(string(key_cstr, length), out);

    if (++count % 1000 == 0) {
      // Commit db
      txn->Commit();
      txn.reset(db->NewTransaction());
      LOG(ERROR) << "Processed " << count << " files.";
    }
  }
  // write the last batch
  if (count % 1000 != 0) {
    txn->Commit();
    LOG(ERROR) << "Processed " << count << " files.";
  }
  return 0;
}
(2) Modifying caffe.proto
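The post gives no code for this step. Since the modified convert_imageset above calls clear_sim(), add_sim(), and sim_size(), the corresponding change is a repeated float field named sim added to the Datum message in src/caffe/proto/caffe.proto. A minimal sketch follows; fields 1-7 are Caffe's existing Datum fields, and the field number 8 is an assumption (it just has to be unused within Datum). After editing the proto, Caffe must be rebuilt so that protoc regenerates caffe.pb.h/caffe.pb.cc and the new accessors become available; the data layer also needs a small change (not shown in the post) to copy datum.sim() into a blob at batch-loading time.
message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
  // New: the 50 extra multi-label values written by the modified
  // convert_imageset (field number assumed)
  repeated float sim = 8;
}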
(3) Generating the LMDB with an sh script
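No script is given in the post for this step either. Below is a minimal sketch modeled on Caffe's stock create_imagenet.sh; CAFFE_ROOT, DATA_ROOT, LIST_FILE, DB_NAME, and the resize dimensions are all placeholders, and the modified convert_imageset is assumed to have been rebuilt into build/tools.
#!/usr/bin/env bash
# Build the multi-label LMDB with the modified convert_imageset.
CAFFE_ROOT=/path/to/caffe
DATA_ROOT=/path/to/images/      # root folder holding the images
LIST_FILE=train_multilabel.txt  # "name label v0 v1 ... v49" per line
DB_NAME=train_lmdb

rm -rf $DB_NAME
$CAFFE_ROOT/build/tools/convert_imageset \
    --backend=lmdb \
    --shuffle=true \
    --resize_height=224 \
    --resize_width=224 \
    $DATA_ROOT \
    $LIST_FILE \
    $DB_NAME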
(4) Training with the LMDB
Note that once the sim values have been read in, a Slice layer is needed to split the multi-label blob into its parts; a sketch is given below.
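In this sketch, the bottom blob name "sim" and the slice point are assumptions: they depend on how the modified data layer exposes the sim values and on the label layout, taken here as 10 landmark coordinates followed by 40 attribute flags.
layer {
  name: "slice_label"
  type: "Slice"
  bottom: "sim"
  top: "landmark_points"
  top: "attributes"
  slice_param {
    axis: 1
    slice_point: 10  # first 10 channels -> landmarks, remaining 40 -> attributes
  }
}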