美文网首页嵌牛IT观察
Tensorflow常见模型及工程化方法

Tensorflow常见模型及工程化方法

作者: 2464d55f8c99 | 来源:发表于2018-12-09 08:37 被阅读0次

    姓名 李林涛 学号 16020199032

    转自:http://mp.weixin.qq.com/s?__biz=MzA5NDM4MjMxOQ==&mid=2447579217&idx=1&sn=93eede10b5c027917b91f9cc819e74e7&chksm=8458c9d1b32f40c75c6c3a0ec0434842e1c9efa92c7d0625d7760042acf7cf7ae13c05b1db58&mpshare=1&scene=23&srcid=1209EwoWaCSocVcEitXFbtTc#rd

    【嵌牛导读】:作者介绍了Tensorflow中的常见模型和工程化方法

    【嵌牛鼻子】:Tensorflow

    【嵌牛提问】:Tensorflow中有哪些常见模型和工程化方法?

    【嵌牛正文】:

    Tensorflow在深度学习模型研究中起到了很大的促进作用,灵活的框架免去了研究人员、开发者大量的自动求导代码工作。本文总结一下常用的模型代码和工程化需要的代码。有需求的同学收藏一下,以便日后查阅。

    Tensorflow常见模型

    A. LSTM模型结构

    import tensorflow as tf

    import tensorflow.contrib as contrib

    from tensorflow.python.ops import array_ops

    class lstm(object):
        """Single-layer LSTM encoder; `self.out` is selected by `flag`.

        Args:
            in_data: input tensor passed to tf.nn.dynamic_rnn (time_major
                defaults to False there, so presumably (batch, time, features)
                -- confirm against callers).
            hidden_dim: number of units in the LSTM cell.
            batch_seqlen: optional per-example sequence lengths; None treats
                every sequence as full length.
            flag: which hidden states to expose as `self.out`:
                'all_ht'   -- outputs at every time step,
                'first_ht' -- output at the first time step,
                'last_ht'  -- hidden state at each sequence's last valid step,
                'concat'   -- 'first_ht' and 'last_ht' joined on axis 1.

        Raises:
            ValueError: if `flag` is not one of the values listed above.
        """

        def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
            # Reject unknown flags up front; otherwise self.out would silently
            # never be assigned and callers would hit an AttributeError later.
            if flag not in ('all_ht', 'first_ht', 'last_ht', 'concat'):
                raise ValueError("unknown flag: %r" % (flag,))

            self.in_data = in_data
            self.hidden_dim = hidden_dim
            self.batch_seqlen = batch_seqlen
            self.flag = flag

            lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)
            out, state = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=self.in_data,
                                           sequence_length=self.batch_seqlen,
                                           dtype=tf.float32)
            # When sequence_length is given, dynamic_rnn zero-fills `out` past
            # each sequence's end, so out[:, -1, :] is all zeros for shorter
            # sequences. The final state's h holds the last *valid* hidden
            # state, and equals out[:, -1, :] when no lengths are given.
            last_ht = state.h

            if flag == 'all_ht':
                self.out = out
            elif flag == 'first_ht':
                self.out = out[:, 0, :]
            elif flag == 'last_ht':
                self.out = last_ht
            else:  # 'concat'
                self.out = tf.concat([out[:, 0, :], last_ht], 1)

    B. Bi-LSTM模型结构

    import tensorflow as tf

    import tensorflow.contrib as contrib

    from tensorflow.python.ops import array_ops

    from tensorflow.python.framework import dtypes

    class bilstm(object):
        """Single-layer bidirectional LSTM encoder; `self.out` is chosen by `flag`.

        Args:
            in_data: input tensor passed to tf.nn.bidirectional_dynamic_rnn
                (time_major defaults to False there, so presumably
                (batch, time, features) -- confirm against callers).
            hidden_dim: number of units in each direction's LSTM cell.
            batch_seqlen: optional per-example sequence lengths; None treats
                every sequence as full length.
            flag: which hidden states to expose as `self.out`:
                'all_ht'   -- forward/backward outputs at every step,
                              concatenated on axis 2,
                'first_ht' -- the concatenated output at the first time step,
                'last_ht'  -- forward and backward final-state h concatenated
                              on axis 1,
                'concat'   -- 'first_ht' and 'last_ht' joined on axis 1.

        Raises:
            ValueError: if `flag` is not one of the values listed above.
        """

        def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
            # Reject unknown flags up front; otherwise self.out would silently
            # never be assigned.
            if flag not in ('all_ht', 'first_ht', 'last_ht', 'concat'):
                raise ValueError("unknown flag: %r" % (flag,))

            self.in_data = in_data
            self.hidden_dim = hidden_dim
            self.batch_seqlen = batch_seqlen
            self.flag = flag

            lstm_cell_fw = contrib.rnn.LSTMCell(self.hidden_dim)
            lstm_cell_bw = contrib.rnn.LSTMCell(self.hidden_dim)
            # `out` is (output_fw, output_bw); `state` is (state_fw, state_bw).
            out, state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=lstm_cell_fw,
                cell_bw=lstm_cell_bw,
                inputs=self.in_data,
                sequence_length=self.batch_seqlen,
                dtype=tf.float32)

            # Join the two directions on the feature axis.
            bi_out = tf.concat(out, 2)
            last_ht = tf.concat([state[0].h, state[1].h], 1)

            if flag == 'all_ht':
                self.out = bi_out
            elif flag == 'first_ht':
                self.out = bi_out[:, 0, :]
            elif flag == 'last_ht':
                self.out = last_ht
            else:  # 'concat'
                self.out = tf.concat([bi_out[:, 0, :], last_ht], 1)

    C. multi-channel CNN(注:原文此处给出的代码与A节的LSTM代码完全相同,疑为粘贴错误,并非多通道CNN的实现)

    import tensorflow as tf

    import tensorflow.contrib as contrib

    from tensorflow.python.ops import array_ops

    class lstm(object):
        """Single-layer LSTM encoder; `self.out` is selected by `flag`.

        Args:
            in_data: input tensor passed to tf.nn.dynamic_rnn (time_major
                defaults to False there, so presumably (batch, time, features)
                -- confirm against callers).
            hidden_dim: number of units in the LSTM cell.
            batch_seqlen: optional per-example sequence lengths; None treats
                every sequence as full length.
            flag: which hidden states to expose as `self.out`:
                'all_ht'   -- outputs at every time step,
                'first_ht' -- output at the first time step,
                'last_ht'  -- hidden state at each sequence's last valid step,
                'concat'   -- 'first_ht' and 'last_ht' joined on axis 1.

        Raises:
            ValueError: if `flag` is not one of the values listed above.
        """

        def __init__(self, in_data, hidden_dim, batch_seqlen=None, flag='concat'):
            # Reject unknown flags up front; otherwise self.out would silently
            # never be assigned and callers would hit an AttributeError later.
            if flag not in ('all_ht', 'first_ht', 'last_ht', 'concat'):
                raise ValueError("unknown flag: %r" % (flag,))

            self.in_data = in_data
            self.hidden_dim = hidden_dim
            self.batch_seqlen = batch_seqlen
            self.flag = flag

            lstm_cell = contrib.rnn.LSTMCell(self.hidden_dim)
            out, state = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=self.in_data,
                                           sequence_length=self.batch_seqlen,
                                           dtype=tf.float32)
            # When sequence_length is given, dynamic_rnn zero-fills `out` past
            # each sequence's end, so out[:, -1, :] is all zeros for shorter
            # sequences. The final state's h holds the last *valid* hidden
            # state, and equals out[:, -1, :] when no lengths are given.
            last_ht = state.h

            if flag == 'all_ht':
                self.out = out
            elif flag == 'first_ht':
                self.out = out[:, 0, :]
            elif flag == 'last_ht':
                self.out = last_ht
            else:  # 'concat'
                self.out = tf.concat([out[:, 0, :], last_ht], 1)

    D. depth-wise CNN

    import tensorflow as tf

    def depth_wise_conv(in_data, scope, kernel_size, dim):
        """Depthwise-separable 2-D convolution with its own variable scope.

        Args:
            in_data: 4-D input tensor. With separable_conv2d's default
                data_format this is NHWC (batch, height, width, channels). The
                static channel count (last axis) must be known.
            scope: name of the tf.variable_scope that owns the two filters.
            kernel_size: (height, width) of the depthwise kernel.
            dim: number of output channels produced by the pointwise conv.

        Returns:
            The separable_conv2d output with `dim` channels. Unit strides and
            "SAME" padding keep the spatial size unchanged.
        """
        with tf.variable_scope(scope):
            in_channels = in_data.shape.as_list()[-1]
            # Depthwise filter layout: [filter_height, filter_width,
            # in_channels, channel_multiplier]. A multiplier of 1 gives one
            # spatial filter per input channel.
            depthwise_filter = tf.get_variable(
                "depthwise_conv.weight",
                (kernel_size[0], kernel_size[1], in_channels, 1),
                dtype=tf.float32)
            # Pointwise filter layout: [1, 1, in_channels * channel_multiplier,
            # out_channels]. This is the 1x1 conv that mixes channels.
            pointwise_filter = tf.get_variable(
                "pointwise_conv.weight",
                (1, 1, in_channels, dim),
                dtype=tf.float32)
            return tf.nn.separable_conv2d(in_data,
                                          depthwise_filter,
                                          pointwise_filter,
                                          strides=(1, 1, 1, 1),
                                          padding="SAME")

    E. multi-layer depth-wise CNN(依赖上一节的depth_wise_conv以及一个layer norm函数norm,后者原文未给出)

    def multi_convs(input_x, dim, conv_number=2, k=5):
        """Stack of residual depthwise-separable convolutions over the sequence axis.

        Each layer applies norm -> depthwise-separable conv (window k x 1)
        -> ReLU -> residual add. A final norm is applied to the result.

        Args:
            input_x: input tensor of shape batch * seq * dim.
            dim: channel count of input_x. It is also the conv output size,
                which must match for the residual add.
            conv_number: number of conv layers, typically 2. With 0 the
                result is simply norm(input_x).
            k: window size of the convolution along the sequence axis.

        Returns:
            Tensor of shape batch * seq * dim.
        """
        res = input_x
        for index in range(conv_number):
            # `norm` is a layer norm defined elsewhere (not shown here);
            # presumably shape-preserving -- confirm.
            out = norm(res)
            out = tf.expand_dims(out, 2)  # batch * seq * 1 * dim
            out = depth_wise_conv(out, kernel_size=(k, 1), dim=dim,
                                  scope="convs.%d" % index)
            out = tf.squeeze(out, 2)  # batch * seq * dim
            out = tf.nn.relu(out)
            res = out + res  # residual connection
        # res is already batch * seq * dim because axis 2 was squeezed inside
        # the loop. Squeezing axis 2 again would fail, or drop the channel
        # axis when dim == 1.
        return norm(res)

    模型参数查看

    已知模型的ckpt文件时,可以通过pywrap_tensorflow读取模型中各参数的名称及其形状。

    import tensorflow as tf

    from tensorflow.python import pywrap_tensorflow

    model_dir = "./ckpt/"

    # get_checkpoint_state returns None when model_dir has no checkpoint
    # index; fail with a clear message instead of an AttributeError on None.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise IOError("No checkpoint found in %s" % model_dir)
    ckpt_path = ckpt.model_checkpoint_path

    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    # Maps each variable name stored in the checkpoint to its shape.
    param_dict = reader.get_variable_to_shape_map()

    # Sorted for a stable, readable listing. The format string works under
    # both Python 2 and Python 3.
    for key in sorted(param_dict):
        print("%s %s" % (key, param_dict[key]))

    工程化方法

    A. 将Tensorflow模型文件打包成PB文件

    import tensorflow as tf

    from tensorflow.python.tools import freeze_graph

    # Placeholders -- replace with your checkpoint path and output node name.
    ckpt_path = "/your/model/path"
    output_node_name = "output/node/name"

    with tf.Graph().as_default():
        with tf.device("/cpu:0"):
            config = tf.ConfigProto(allow_soft_placement=True)
            # Using the session as a context manager closes it on exit.
            with tf.Session(config=config) as sess:
                model = Your_Model_Name()  # placeholder for your model class
                model.build_graph()
                # global_variables_initializer replaces the deprecated
                # initialize_all_variables; saver.restore below overwrites the
                # saved variables anyway.
                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                saver.restore(sess, ckpt_path)
                graphdef = tf.get_default_graph().as_graph_def()

                # Way 1 (input): write the graph structure as a binary
                # GraphDef, for use by freeze_graph below.
                tf.train.write_graph(graphdef, "/your/save/path/", "save_name.pb",
                                     as_text=False)

                # Way 2: fold the restored variables into constants in memory,
                # strip training-only nodes, and persist the trimmed graph so
                # the result is actually usable.
                frozen_graph = tf.graph_util.convert_variables_to_constants(
                    sess, graphdef, [output_node_name])
                frozen_graph_trim = tf.graph_util.remove_training_nodes(frozen_graph)
                with tf.gfile.GFile("/your/save/path/frozen_trim_name.pb", "wb") as f:
                    f.write(frozen_graph_trim.SerializeToString())

    # Way 1: freeze_graph combines save_name.pb with the checkpoint and writes
    # frozen_name.pb. Positional args: input_graph, input_saver, input_binary,
    # input_checkpoint, output_node_names, restore_op_name,
    # filename_tensor_name, output_graph, clear_devices, initializer_nodes.
    freeze_graph.freeze_graph('/your/save/path/save_name.pb', '', True,
                              ckpt_path, output_node_name, 'save/restore_all',
                              'save/Const:0', 'frozen_name.pb', True, "")

    B. PB文件的读取与使用

    import tensorflow as tf

    # Load the frozen GraphDef into a dedicated graph.
    graph = tf.Graph()
    with graph.as_default():
        output_graph_def = tf.GraphDef()
        with open("your_name.pb", "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # name="" keeps the original node names (no "import/" prefix).
        tf.import_graph_def(output_graph_def, name="")

    # The snippet needs its own session bound to that graph.
    with tf.Session(graph=graph) as sess:
        # get_tensor_by_name expects "<op_name>:<output_index>"; a bare op
        # name raises ValueError.
        node_in = sess.graph.get_tensor_by_name("input_node_name:0")
        model_out = sess.graph.get_tensor_by_name("out_node_name:0")
        # in_data: your input batch (placeholder -- define before running).
        feed_dict = {node_in: in_data}
        pred = sess.run(model_out, feed_dict)

    相关文章

      网友评论

        本文标题:Tensorflow常见模型及工程化方法

        本文链接:https://www.haomeiwen.com/subject/asrhhqtx.html