美文网首页
Tensorflow中float32模型强制转为float16半

Tensorflow中float32模型强制转为float16半

作者: huachao1001 | 来源:发表于2020-12-10 18:11 被阅读0次

    最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】

    在Tensorflow框架训练完成后,部署模型时通常希望对模型进行压缩。一种方案是前面文章介绍的方法《【Ubuntu】Tensorflow对训练后的模型做8位(uint8)量化转换》。另一种方法是半浮点(float16)量化,今天我们主要介绍如何通过修改Tensorflow的pb文件中的计算节点和常量(const),将float32数据类型的模型转换为float16数据类型的模型,从而使模型大小压缩约一半。

    1 加载pb模型

    封装函数,加载pb模型:

    def load_graph(model_path):
        """Load a GraphDef from disk and return a session bound to its graph.

        Args:
            model_path: Path to the model file. A path ending in "pb" is parsed
                as a binary GraphDef; any other path is parsed as text format.

        Returns:
            A session on a fresh graph into which the GraphDef was imported
            with an empty name prefix. The caller is responsible for closing it.
        """
        # TensorFlow 2.x no longer exposes GraphDef/Session at the top level;
        # use the v1 compatibility namespace when it exists, else `tf` itself.
        tf_v1 = getattr(tf.compat, "v1", tf)
        graph = tf.Graph()
        with graph.as_default():
            graph_def = tf_v1.GraphDef()
            if model_path.endswith("pb"):
                with open(model_path, "rb") as f:
                    graph_def.ParseFromString(f.read())
            else:
                with open(model_path, "r") as pf:
                    text_format.Parse(pf.read(), graph_def)
            tf.import_graph_def(graph_def, name="")
            sess = tf_v1.Session(graph=graph)
            # Print every operation name so the imported graph can be inspected.
            for op in graph.get_operations():
                print(op.name)
            return sess
    

    2 重写BatchNorm

    由于BatchNorm对精度比较敏感,需要保持float32类型,因此BatchNorm需要特殊处理。

    # Replace FusedBatchNorm with FusedBatchNormV2 so that the batch-norm
    # parameters (attr "U") stay float32 while the data type "T" is converted.
    def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
        """Append a FusedBatchNormV2 copy of `node` to `graph_def`.

        The new node keeps the name, inputs and attributes of `node`. Its "U"
        attribute is set to DT_FLOAT and, if `node` has a "T" attribute, the
        copy's "T" is set to the dtype selected by `target_type`. `node` itself
        is left untouched.

        Args:
            node: Source NodeDef (expected to be a FusedBatchNorm node).
            graph_def: GraphDef that receives the new node.
            target_type: 'fp16' selects DT_HALF, 'fp64' selects DT_DOUBLE, any
                other value selects DT_FLOAT.
        """
        # NOTE(review): FusedBatchNormV2 may not accept DT_DOUBLE for "T" --
        # confirm against the op definition before relying on 'fp64'.
        if target_type == 'fp16':
            dtype = types_pb2.DT_HALF
        elif target_type == 'fp64':
            dtype = types_pb2.DT_DOUBLE
        else:
            dtype = types_pb2.DT_FLOAT
        new_node = graph_def.node.add()
        new_node.op = "FusedBatchNormV2"
        new_node.name = node.name
        new_node.input.extend(node.input)
        new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
        has_data_type = False
        for attr in list(node.attr.keys()):
            new_node.attr[attr].CopyFrom(node.attr[attr])
            if attr == "T":
                has_data_type = True
        # Override the dtype on the copy only, so the caller's node is not mutated.
        if has_data_type:
            new_node.attr["T"].type = dtype
        print("rewrite fused_batch_norm done!")
    

    3 Graph转换

    重新构造graph,参数从原始pb的graph中拷贝,并转为float16

    
    def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
        """Copy a model's graph, converting DT_FLOAT attributes to another dtype.

        Every node of the source graph is copied into a new GraphDef:
          * FusedBatchNorm nodes are rewritten by rewrite_batch_norm_node_v2.
          * Nodes whose name contains "BatchNorm" or "batch_normalization", or
            is listed in the module-level `keep_fp32_node_name`, are copied
            unchanged.
          * For all other nodes, type attributes equal to DT_FLOAT become the
            target dtype and DT_FLOAT tensors stored in the "value" attribute
            are re-encoded in the target dtype.
        The result is optionally passed through the "strip_unused_nodes"
        transform and then written with tf.io.write_graph.

        Args:
            model_path: Path of the model to convert (loaded via load_graph).
            save_path: Directory the converted graph is written to.
            name: File name of the converted graph.
            as_text: Forwarded to tf.io.write_graph.
            target_type: 'fp16' selects DT_HALF, 'fp64' selects DT_DOUBLE, any
                other value selects DT_FLOAT.
            input_name: Input node name(s) for the graph transform; a single
                string or a list of strings. Only used when `output_names` is
                given.
            output_names: Output node name(s); a single string or a list of
                strings. When empty or None the graph transform is skipped.
        """
        # Pick the dtype that replaces DT_FLOAT in the new graph.
        if target_type == 'fp16':
            dtype = types_pb2.DT_HALF
        elif target_type == 'fp64':
            dtype = types_pb2.DT_DOUBLE
        else:
            dtype = types_pb2.DT_FLOAT

        # Only the GraphDef of the source model is needed, so release the
        # session as soon as the definition has been extracted.
        source_sess = load_graph(model_path)
        try:
            source_graph_def = source_sess.graph.as_graph_def()
        finally:
            source_sess.close()

        # Build the new graph definition node by node.
        target_graph_def = graph_pb2.GraphDef()
        target_graph_def.versions.CopyFrom(source_graph_def.versions)
        for node in source_graph_def.node:
            # FusedBatchNorm nodes are replaced by FusedBatchNormV2.
            if node.op == "FusedBatchNorm":
                rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
                continue
            new_node = target_graph_def.node.add()
            new_node.op = node.op
            new_node.name = node.name
            new_node.input.extend(node.input)

            attrs = list(node.attr.keys())
            # Batch-norm related nodes and explicitly listed nodes keep their
            # attributes exactly as they are.
            keep_unchanged = (
                ("BatchNorm" in node.name)
                or ('batch_normalization' in node.name)
                or (node.name in keep_fp32_node_name)
            )
            if keep_unchanged:
                for attr in attrs:
                    new_node.attr[attr].CopyFrom(node.attr[attr])
                continue

            for attr in attrs:
                source_attr = node.attr[attr]
                # Weights live in the "value" attribute as a tensor proto.
                if attr == "value" and source_attr.tensor.dtype == types_pb2.DT_FLOAT:
                    tensor = source_attr.tensor
                    if tensor.float_val or tensor.tensor_content:
                        # make_ndarray decodes either storage form into an
                        # array of the tensor's shape; re-encode it as `dtype`.
                        weights = tf.make_ndarray(tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(weights, dtype=dtype))
                    else:
                        # No payload to convert: relabel the tensor so it still
                        # agrees with the node's converted type attributes.
                        new_node.attr[attr].tensor.CopyFrom(tensor)
                        new_node.attr[attr].tensor.dtype = dtype
                    continue
                # Copy first and patch the copy, leaving the source node intact.
                new_node.attr[attr].CopyFrom(source_attr)
                if source_attr.type == types_pb2.DT_FLOAT:
                    new_node.attr[attr].type = dtype

        # Optionally strip nodes that are not needed to compute the outputs.
        if output_names:
            # TransformGraph joins each name list with commas, so a bare string
            # would be split into single characters; wrap it in a list.
            if isinstance(output_names, str):
                output_names = [output_names]
            if not input_name:
                input_name = []
            elif isinstance(input_name, str):
                input_name = [input_name]
            transforms = ["strip_unused_nodes"]
            target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
        # Write the converted graph definition to disk.
        tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
        print("Converting done ...")
    

    4 完整的代码

    import tensorflow as tf
    from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
    from tensorflow.tools.graph_transforms import TransformGraph
    from google.protobuf import text_format
    import numpy as np
    
    # object detection api input and output nodes.
    # Both are lists of names: TransformGraph comma-joins each argument, so a
    # bare string would be split into single characters.
    input_name = ["input_tf"]
    output_names = ["output:0"]
    # Names of nodes whose attributes are copied unchanged during conversion.
    keep_fp32_node_name = []
    
    def load_graph(model_path):
        """Load a GraphDef from disk and return a session bound to its graph.

        Args:
            model_path: Path to the model file. A path ending in "pb" is parsed
                as a binary GraphDef; any other path is parsed as text format.

        Returns:
            A session on a fresh graph into which the GraphDef was imported
            with an empty name prefix. The caller is responsible for closing it.
        """
        # TensorFlow 2.x no longer exposes GraphDef/Session at the top level;
        # use the v1 compatibility namespace when it exists, else `tf` itself.
        tf_v1 = getattr(tf.compat, "v1", tf)
        graph = tf.Graph()
        with graph.as_default():
            graph_def = tf_v1.GraphDef()
            if model_path.endswith("pb"):
                with open(model_path, "rb") as f:
                    graph_def.ParseFromString(f.read())
            else:
                with open(model_path, "r") as pf:
                    text_format.Parse(pf.read(), graph_def)
            tf.import_graph_def(graph_def, name="")
            sess = tf_v1.Session(graph=graph)
            # Print every operation name so the imported graph can be inspected.
            for op in graph.get_operations():
                print(op.name)
            return sess
    
    # Replace FusedBatchNorm with FusedBatchNormV2 so that the batch-norm
    # parameters (attr "U") stay float32 while the data type "T" is converted.
    def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
        """Append a FusedBatchNormV2 copy of `node` to `graph_def`.

        The new node keeps the name, inputs and attributes of `node`. Its "U"
        attribute is set to DT_FLOAT and, if `node` has a "T" attribute, the
        copy's "T" is set to the dtype selected by `target_type`. `node` itself
        is left untouched.

        Args:
            node: Source NodeDef (expected to be a FusedBatchNorm node).
            graph_def: GraphDef that receives the new node.
            target_type: 'fp16' selects DT_HALF, 'fp64' selects DT_DOUBLE, any
                other value selects DT_FLOAT.
        """
        # NOTE(review): FusedBatchNormV2 may not accept DT_DOUBLE for "T" --
        # confirm against the op definition before relying on 'fp64'.
        if target_type == 'fp16':
            dtype = types_pb2.DT_HALF
        elif target_type == 'fp64':
            dtype = types_pb2.DT_DOUBLE
        else:
            dtype = types_pb2.DT_FLOAT
        new_node = graph_def.node.add()
        new_node.op = "FusedBatchNormV2"
        new_node.name = node.name
        new_node.input.extend(node.input)
        new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
        has_data_type = False
        for attr in list(node.attr.keys()):
            new_node.attr[attr].CopyFrom(node.attr[attr])
            if attr == "T":
                has_data_type = True
        # Override the dtype on the copy only, so the caller's node is not mutated.
        if has_data_type:
            new_node.attr["T"].type = dtype
        print("rewrite fused_batch_norm done!")
    
    def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
        """Copy a model's graph, converting DT_FLOAT attributes to another dtype.

        Every node of the source graph is copied into a new GraphDef:
          * FusedBatchNorm nodes are rewritten by rewrite_batch_norm_node_v2.
          * Nodes whose name contains "BatchNorm" or "batch_normalization", or
            is listed in the module-level `keep_fp32_node_name`, are copied
            unchanged.
          * For all other nodes, type attributes equal to DT_FLOAT become the
            target dtype and DT_FLOAT tensors stored in the "value" attribute
            are re-encoded in the target dtype.
        The result is optionally passed through the "strip_unused_nodes"
        transform and then written with tf.io.write_graph.

        Args:
            model_path: Path of the model to convert (loaded via load_graph).
            save_path: Directory the converted graph is written to.
            name: File name of the converted graph.
            as_text: Forwarded to tf.io.write_graph.
            target_type: 'fp16' selects DT_HALF, 'fp64' selects DT_DOUBLE, any
                other value selects DT_FLOAT.
            input_name: Input node name(s) for the graph transform; a single
                string or a list of strings. Only used when `output_names` is
                given.
            output_names: Output node name(s); a single string or a list of
                strings. When empty or None the graph transform is skipped.
        """
        # Pick the dtype that replaces DT_FLOAT in the new graph.
        if target_type == 'fp16':
            dtype = types_pb2.DT_HALF
        elif target_type == 'fp64':
            dtype = types_pb2.DT_DOUBLE
        else:
            dtype = types_pb2.DT_FLOAT

        # Only the GraphDef of the source model is needed, so release the
        # session as soon as the definition has been extracted.
        source_sess = load_graph(model_path)
        try:
            source_graph_def = source_sess.graph.as_graph_def()
        finally:
            source_sess.close()

        # Build the new graph definition node by node.
        target_graph_def = graph_pb2.GraphDef()
        target_graph_def.versions.CopyFrom(source_graph_def.versions)
        for node in source_graph_def.node:
            # FusedBatchNorm nodes are replaced by FusedBatchNormV2.
            if node.op == "FusedBatchNorm":
                rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
                continue
            new_node = target_graph_def.node.add()
            new_node.op = node.op
            new_node.name = node.name
            new_node.input.extend(node.input)

            attrs = list(node.attr.keys())
            # Batch-norm related nodes and explicitly listed nodes keep their
            # attributes exactly as they are.
            keep_unchanged = (
                ("BatchNorm" in node.name)
                or ('batch_normalization' in node.name)
                or (node.name in keep_fp32_node_name)
            )
            if keep_unchanged:
                for attr in attrs:
                    new_node.attr[attr].CopyFrom(node.attr[attr])
                continue

            for attr in attrs:
                source_attr = node.attr[attr]
                # Weights live in the "value" attribute as a tensor proto.
                if attr == "value" and source_attr.tensor.dtype == types_pb2.DT_FLOAT:
                    tensor = source_attr.tensor
                    if tensor.float_val or tensor.tensor_content:
                        # make_ndarray decodes either storage form into an
                        # array of the tensor's shape; re-encode it as `dtype`.
                        weights = tf.make_ndarray(tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(weights, dtype=dtype))
                    else:
                        # No payload to convert: relabel the tensor so it still
                        # agrees with the node's converted type attributes.
                        new_node.attr[attr].tensor.CopyFrom(tensor)
                        new_node.attr[attr].tensor.dtype = dtype
                    continue
                # Copy first and patch the copy, leaving the source node intact.
                new_node.attr[attr].CopyFrom(source_attr)
                if source_attr.type == types_pb2.DT_FLOAT:
                    new_node.attr[attr].type = dtype

        # Optionally strip nodes that are not needed to compute the outputs.
        if output_names:
            # TransformGraph joins each name list with commas, so a bare string
            # would be split into single characters; wrap it in a list.
            if isinstance(output_names, str):
                output_names = [output_names]
            if not input_name:
                input_name = []
            elif isinstance(input_name, str):
                input_name = [input_name]
            transforms = ["strip_unused_nodes"]
            target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
        # Write the converted graph definition to disk.
        tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
        print("Converting done ...")
    
    # Conversion settings: source model, output location and target precision.
    model_path = "test.pb"
    save_path = "test"
    name = "output_fp16.pb"
    target_type = 'fp16'
    as_text = False
    convert_graph_to_fp16(
        model_path,
        save_path,
        name,
        as_text=as_text,
        target_type=target_type,
        input_name=input_name,
        output_names=output_names,
    )
    # Check that the converted model can be loaded again.
    converted_model_path = save_path + "/" + name
    sess = load_graph(converted_model_path)
    

    相关文章

      网友评论

          本文标题:Tensorflow中float32模型强制转为float16半

          本文链接:https://www.haomeiwen.com/subject/izyrgktx.html