tnn
TNN_NS::TNN是TNN对外暴露接口的管理类,如下使用
//api接口的管理对象
TNN_NS::TNN tnn;
//配置MODEL_TYPE_TNN和模型文件(tnnparam和tnnmodel文件)
TNN_NS::ModelConfig model_config;
//读取tnnparam和tnnmodel文件
auto proto_tnn = fdLoadFile(model_param);
auto model_tnn = fdLoadFile(model);
model_config.model_type = TNN_NS::MODEL_TYPE_TNN;
model_config.params = {proto_tnn, model_tnn};
初始化
tnn.Init(model_config);
TNN这个类的主要目的是为了对外暴露接口,主要有四个成员函数,分别是Init、DeInit、AddOutput和CreateInst,先来看下include/tnn下的tnn.h,如下
class TNNImpl;
class PUBLIC TNN {
public:
TNN();
~TNN();
// 初始化函数,解析模型文件
Status Init(ModelConfig& config);
// 释放模型解析器.
Status DeInit();
// 增加输出节点,先通过output_name,如果没有找到名字,就通过索引 查找节点.
Status AddOutput(const std::string& output_name, int output_index = 0);
// 用于创建网络执行器Instance
std::shared_ptr<Instance> CreateInst(
NetworkConfig& config, Status& status,
InputShapesMap inputs_shape = InputShapesMap());
private:
//要暴露接口的tnn内部类
std::shared_ptr<TNNImpl> impl_ = nullptr;
};
由于这个类的目的是为了暴露接口,所以它的四个成员函数其实对应了TNNImpl中的四个函数,来看看source/tnn/core下的tnn.cc,
Status TNN::Init(ModelConfig& config) {
//tnn内部需要暴露接口的类的实例对象
impl_ = TNNImplManager::GetTNNImpl(config.model_type);
if (!impl_) {
LOGE("Error: not support mode type: %d\n", config.model_type);
return Status(TNNERR_NET_ERR, "not support mode type");
}
//对应实例对象初始化函数,解析模型文件
return impl_->Init(config);
}
//对应释放模型解析器
Status TNN::DeInit() {
impl_ = nullptr;
return TNN_OK;
}
Status TNN::AddOutput(const std::string& layer_name, int output_index) {
// todo for output index
if (!impl_) {
LOGE("Error: impl_ is nil\n");
return Status(TNNERR_NET_ERR, "tnn impl_ is nil");
}
//对应增加输出节点函数
return impl_->AddOutput(layer_name, output_index);
}
std::shared_ptr<Instance> TNN::CreateInst(NetworkConfig& config, Status& status, InputShapesMap inputs_shape) {
if (!impl_) {
status = Status(TNNERR_NET_ERR, "tnn impl_ is nil");
return nullptr;
}
//对应创建网络执行器Instance,返回Instance对象
return impl_->CreateInst(config, status, inputs_shape);
}
网友评论