美文网首页
深度学习模型转换onnx2ncnn

深度学习模型转换onnx2ncnn

作者: 半笔闪 | 来源:发表于2020-05-26 17:55 被阅读0次

    我们知道现在的深度学习训练框架(如tensorflow、caffe、pytorch、MXNet等等)都有自己的模型存储格式,那他们之间的转换就是一个常见的需求了,但是如果每个框架都要写转换到其他所有框架的代码,那就麻烦了,如果出现一种新框架,每种框架都要再写一种转换。所以最好的方式应该就是每个框架都有向一种统一的框架转换的代码,等需要转换的时候,只需要先转换成这种统一的模型格式,再由这种格式转到你的目标框架格式,这种统一的模型格式就是ONNX。而这种思想就类似与编译器中IR(Intermediate Representation,中间表达形式)的思想。
    本篇就以ncnn源码中的onnx2ncnn为例,讲解一下onnx的基础以及onnx向其他框架转换的知识。如下图是ncnn源码下tools/onnx:


    image.png

    protobuf中的.proto语法

    这里先通过onnx.proto来讲解一下.proto文件的语法以及onnx的基本数据结构。打开onnx.proto文件,去掉注释,除了头几行外,其他都是一个一个的message,这里截取一个message NodeProto。如下:

    // 定义语法类型,通常proto3好于proto2,proto2好于proto1
    syntax = "proto2";
    // 定义作用域
    package onnx;
    // 类class NodeProto
    message NodeProto {
      repeated string input = 1;    
      repeated string output = 2;  
      optional string name = 3;    
      optional string op_type = 4;  
      optional string domain = 7;   
      repeated AttributeProto attribute = 5;
      optional string doc_string = 6;
    }
    
    1. required: 必须赋值的字符(onnx.proto中没有)
    2. optional: 可有可无的字段,可以使用[default = xxx]配置默认值
    3. repeated: 可重复变长字段,类似数组
      上面我们可以看到没有字段后面都有一个=数字,这是每个字段在每个message里独一无二的tag,tag 1-15是字节编码,16-2047使用2字节编码,所以1-15给频繁使用的字段,这里的tag都没有超过15。而上面的字段类型可以参考下图:


      image.png

      有了.proto文件,那如何使用呢?一般是编译的时候,通过.proto文件生成C++源文件来来使用,这个具体可参考protobuf官网。基本像下面这样一句代码就能生成:

    protoc -I=input_dir --cpp_out=output_dir input_dir/onnx.proto
    

    ONNX基础

    从onnx.proto文件中,我们可以看到,onnx的数据结构,onnx的网络的每一层的数据结构是Node,由这些Node组成Graph,然后Graph和onnx模型的一些其他信息组成一个model,也就是最终的.onnx模型。

    1. NodeProto
    message NodeProto {
    //存放节点输入的名字 [类型:字符串列表]
      repeated string input = 1;   
    // 存放节点输出的名字 [类型:字符串列表]
      repeated string output = 2;  
    //节点名
      optional string name = 3;    
    //节点的算子类型 [类型:字符串]
      optional string op_type = 4;  
    //算子域[类型:字符串]
      optional string domain = 7;   
    //存放节点的属性attributes [类型:任意]
      repeated AttributeProto attribute = 5;
    //描述文档的字符串,这个默认为None [类型:字符串]
      optional string doc_string = 6;
    }
    
    1. GraphProto
    message GraphProto {
    //生成的节点列表 [类型:NodeProto列表
      repeated NodeProto node = 1;
    //graph的名字 [类型:字符串]
      optional string name = 2; 
    //存放超参数 [类型:TensorProto列表]
      repeated TensorProto initializer = 5;
    //描述文档的字符串,这个默认为None [类型:字符串]
      optional string doc_string = 10;
    //存放graph的输入数据信息 [类型:ValueInfoProto列表]
      repeated ValueInfoProto input = 11;
    //存放graph的输出数据信息 [类型:ValueInfoProto列表]
      repeated ValueInfoProto output = 12;
    //存放中间层产生的输出数据的信息 [类型:ValueInfoProto列表]
      repeated ValueInfoProto value_info = 13;
      repeated TensorAnnotation quantization_annotation = 14;
      // repeated string input = 3;
      // repeated string output = 4;
      // optional int64 ir_version = 6;
      // optional int64 producer_version = 7;
      // optional string producer_tag = 8;
      // optional string domain = 9;
    }
    
    1. ModelProto
    message ModelProto {
      optional int64 ir_version = 1;
      repeated OperatorSetIdProto opset_import = 8;
      optional string producer_name = 2;
      optional string producer_version = 3;
      optional string domain = 4;
      optional int64 model_version = 5;
      optional string doc_string = 6;
    //生成的graph
      optional GraphProto graph = 7;
      repeated StringStringEntryProto metadata_props = 14;
    };
    

    onnx2ncnn

    看到onnx2ncnn.cpp,从main进入后,调用read_proto_from_binary载入并解析.onnx文件到onnx::ModelProto model。来看一下read_proto_from_binary:

    static bool read_proto_from_binary(const char* filepath, google::protobuf::Message* message)
    {
        //以都字节的形式打开文件
        std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
        if (!fs.is_open())
        {
            fprintf(stderr, "open failed %s\n", filepath);
            return false;
        }    
        //读入
        google::protobuf::io::IstreamInputStream input(&fs); 
        //反序列化字节流
        google::protobuf::io::CodedInputStream codedstr(&input);
        //限制最大字节数
        codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
        //解析出message
        bool success = message->ParseFromCodedStream(&codedstr);
        //关闭字节流
        fs.close();
        return success;
    }
    

    然后,在main中就是按解析出的node对应ncnn中的一层来写入ncnn模型文件,完成转换。

    相关文章

      网友评论

          本文标题:深度学习模型转换onnx2ncnn

          本文链接:https://www.haomeiwen.com/subject/sysvahtx.html