Extractor
完成模型网络结构和模型权重参数的载入,就可以运行网络了,这就需要Extractor了。在第一篇中,执行网络的代码:
ncnn::Extractor ex = xxxnet.create_extractor(); //创建网络执行器
ex.set_num_threads(2); //设置执行器使用的线程数
ex.set_light_mode(true); //设置执行器是否使用轻量模式
ex.input("data", in_img); //设置执行器输入
ex.extract("prob", out_img); //提取执行器输出
先看一下Extractor的声明,在net.h中
class Extractor
{
public:
// enable light mode
// intermediate blob will be recycled when enabled
// enabled by default
//是否设置轻量模式
void set_light_mode(bool enable);
// set thread count for this extractor
// this will overwrite the global setting
// default count is system depended
//设置运行的线程数
void set_num_threads(int num_threads);
// set blob memory allocator
//设置blob的内存分配
void set_blob_allocator(Allocator* allocator);
// set workspace memory allocator
//设置工作区的内存分配
void set_workspace_allocator(Allocator* allocator);
#if NCNN_STRING
// set input by blob name
// return 0 if success
//根据输入blob的名字 返回输入
int input(const char* blob_name, const Mat& in);
// get result by blob name
// return 0 if success
//执行器,根据输出blob的名字返回输出
int extract(const char* blob_name, Mat& feat);
#endif // NCNN_STRING
// set input by blob index
// return 0 if success
//根据输入blob的索引 返回输入
int input(int blob_index, const Mat& in);
// get result by blob index
// return 0 if success
//执行器,根据输出blob的索引返回输出
int extract(int blob_index, Mat& feat);
protected:
//友元函数,使其他类可以调用Extractor类中的create_extractor函数
friend Extractor Net::create_extractor() const;
//create_extractor中调用,执行器
Extractor(const Net* net, int blob_count);
private:
const Net* net;
//一个放blob的mat的vector,作用就是存储blob的mat
std::vector<Mat> blob_mats;
//配置设置
Option opt;
};
根据顺序,先来看
ncnn::Extractor ex = xxxnet.create_extractor(); //创建网络执行器
在net.cpp中看到create_extractor()的实现
Extractor Net::create_extractor() const
{
return Extractor(this, blobs.size());
}
如上面注释中说明的,这里调用了Extractor():
Extractor::Extractor(const Net* _net, int blob_count) : net(_net)
{
blob_mats.resize(blob_count);
opt = net->opt;
}
: net(_net)是c++中的变量初始化列表的形式初始化net对象,其实就是把之前载入初始化好的网络结构net对象传进来。
把blob_mats resize到blob的数量的大小。
设置线程数和轻量模式之前提过这里不再赘述,接着看ex.input("data", in_img):
int Extractor::input(const char* blob_name, const Mat& in)
{
int blob_index = net->find_blob_index_by_name(blob_name);
if (blob_index == -1)
return -1;
return input(blob_index, in);
}
注意这里先调用的是根据blob名字输入输入,先用find_blob_index_by_name找到对应的索引,return的时候再调用根据blob索引输入输入。
int Extractor::input(int blob_index, const Mat& in)
{
if (blob_index < 0 || blob_index >= (int)blob_mats.size())
return -1;
blob_mats[blob_index] = in;
return 0;
}
最后到ex.extract("prob", out_img); //提取执行器输出
int Extractor::extract(const char* blob_name, Mat& feat)
{
int blob_index = net->find_blob_index_by_name(blob_name);
if (blob_index == -1)
return -1;
return extract(blob_index, feat);
}
同样的这里根据find_blob_index_by_name找到索引,return的时候再调用根据索引执行输入结果的extract。
int Extractor::extract(int blob_index, Mat& feat)
{
//如果blob的索引小于0或大于blob_mats
的size,说明数据不一致
if (blob_index < 0 || blob_index >= (int)blob_mats.size())
return -1;
int ret = 0;
//如果输出的blob为空
if (blob_mats[blob_index].dims == 0)
{
//查找输出blob对应的producer,这里可以看下blob.h和blob.cpp来了解下blob的数据结构,简单来看就是一个name,一个producer(谁输出的)和一个consumers(谁需要这个blob作为输入)
int layer_index = net->blobs[blob_index].producer;
//开始前向运算
ret = net->forward_layer(layer_index, blob_mats, opt);
}
//输出
feat = blob_mats[blob_index];
if (opt.use_packing_layout)
{
Mat bottom_blob_unpacked;
convert_packing(feat, bottom_blob_unpacked, 1, opt);
feat = bottom_blob_unpacked;
}
return ret;
}
这里重点看一下ret = net->forward_layer(layer_index, blob_mats, opt),可以看下net中的forward_layer函数,forward_layer会不断运行每一层的forward_layer直到最后一层输出。
网友评论