
2018-12-20 STNE Code Study

Author: HollyMeng | Published 2018-12-20 16:34

    Architecture:


    [Figure: STNE model architecture]

    Parameters:


    [Figure: parameter settings]

    Below, using cora as an example, we analyze the main code that builds the neural network. The model is a sequence autoencoder: a (stacked) bidirectional LSTM encoder reads the content features of the nodes along a sampled sequence, and an LSTM decoder reconstructs the node ids of that same sequence.

    # TensorFlow 1.x imports required by the snippet below
    import tensorflow as tf
    from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple
    from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn as bi_rnn

    class STNE(object):
        def __init__(self, hidden_dim, node_num, fea_dim, seq_len, depth=1, node_fea=None, node_fea_trainable=False):
            self.node_num, self.fea_dim, self.seq_len = node_num, fea_dim, seq_len
    
            self.input_seqs = tf.placeholder(tf.int32, shape=(None, self.seq_len), name='input_seq')
            self.dropout = tf.placeholder(tf.float32, name='dropout')
            if node_fea is not None:
                assert self.node_num == node_fea.shape[0] and self.fea_dim == node_fea.shape[1]
                self.embedding_W = tf.Variable(initial_value=node_fea, name='encoder_embed', trainable=node_fea_trainable)   
            else:
                self.embedding_W = tf.Variable(initial_value=tf.random_uniform(shape=(self.node_num, self.fea_dim)),
                                               name='encoder_embed', trainable=node_fea_trainable)
            input_seq_embed = tf.nn.embedding_lookup(self.embedding_W, self.input_seqs, name='input_embed_lookup')
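            # Shape sketch (hypothetical cora-like sizes, not from the post): with
            # node_num=2708, fea_dim=1433 and seq_len=10, input_seqs holds a
            # (batch_size, 10) matrix of node ids, and input_seq_embed comes out as
            # (batch_size, 10, 1433) -- one row of content features per node.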
           
    
            # encoder
            encoder_cell_fw_0 = tf.contrib.rnn.DropoutWrapper(LSTMCell(hidden_dim), output_keep_prob=1 - self.dropout)
            encoder_cell_bw_0 = tf.contrib.rnn.DropoutWrapper(LSTMCell(hidden_dim), output_keep_prob=1 - self.dropout)
            if depth == 1:  # for cora and wiki, cell num=1
                encoder_cell_fw_all = tf.contrib.rnn.MultiRNNCell([encoder_cell_fw_0])
                encoder_cell_bw_all = tf.contrib.rnn.MultiRNNCell([encoder_cell_bw_0])
            else:   # for citeseer, cell num=2
                encoder_cell_fw_1 = tf.contrib.rnn.DropoutWrapper(LSTMCell(hidden_dim), output_keep_prob=1 - self.dropout)
                encoder_cell_bw_1 = tf.contrib.rnn.DropoutWrapper(LSTMCell(hidden_dim), output_keep_prob=1 - self.dropout)
    
                encoder_cell_fw_all = tf.contrib.rnn.MultiRNNCell([encoder_cell_fw_0] + [encoder_cell_fw_1] * (depth - 1))
                encoder_cell_bw_all = tf.contrib.rnn.MultiRNNCell([encoder_cell_bw_0] + [encoder_cell_bw_1] * (depth - 1))
    
            encoder_outputs, encoder_final = bi_rnn(encoder_cell_fw_all, encoder_cell_bw_all, inputs=input_seq_embed, dtype=tf.float32)
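            # bi_rnn (tf.nn.bidirectional_dynamic_rnn) returns two tuples:
            # encoder_outputs = (output_fw, output_bw), each of shape
            # (batch_size, seq_len, hidden_dim), and encoder_final =
            # (state_fw, state_bw); with MultiRNNCell each state is a tuple of
            # `depth` LSTMStateTuple(c, h) pairs, which the loop below unpacks.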
            c_fw_list, h_fw_list, c_bw_list, h_bw_list = [], [], [], []
            for d in range(depth): 
                (c_fw, h_fw) = encoder_final[0][d]
                (c_bw, h_bw) = encoder_final[1][d]
                c_fw_list.append(c_fw)
                h_fw_list.append(h_fw)
                c_bw_list.append(c_bw)
                h_bw_list.append(h_bw)
    
            decoder_init_state = tf.concat(c_fw_list + c_bw_list, axis=-1), tf.concat(h_fw_list + h_bw_list, axis=-1)
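            # The concatenated c and h vectors each have width depth * 2 * hidden_dim,
            # which matches the decoder cell's state size (2 * hidden_dim) only when
            # depth == 1; the two dense layers below project them to the right size.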
            decoder_cell = tf.contrib.rnn.DropoutWrapper(LSTMCell(hidden_dim * 2), output_keep_prob=1 - self.dropout)
            decoder_init_state = LSTMStateTuple(
                tf.layers.dense(decoder_init_state[0], units=hidden_dim * 2, activation=None),
                tf.layers.dense(decoder_init_state[1], units=hidden_dim * 2, activation=None))
    
            self.encoder_output = tf.concat(encoder_outputs, axis=-1)
            encoder_output_T = tf.transpose(self.encoder_output, [1, 0, 2])  # time-major: seq_len * batch_size * (2 * hidden_dim)
    
            new_state = decoder_init_state
            outputs_list = []
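            # Manually unrolled decoder: each step feeds a zero vector as input, so
            # the output sequence is generated purely from the recurrent state that
            # was initialized from the encoder (an input-less, state-driven decoder).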
            for i in range(seq_len):
                # Zero input with a statically known width so the LSTMCell can build
                # its weights; semantically the decoder receives no input at all.
                new_output, new_state = decoder_cell(tf.zeros(shape=(tf.shape(encoder_output_T)[1], hidden_dim * 2)), new_state)
                outputs_list.append(new_output)
    
            decoder_outputs = tf.stack(outputs_list, axis=0)  # seq_len * batch_size * hidden_dim
            decoder_outputs = tf.transpose(decoder_outputs, [1, 0, 2])  # batch_size * seq_len * hidden_dim
            self.decoder_outputs = decoder_outputs
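            # Reconstruction loss: project every decoder step to logits over all
            # node_num nodes and score them against the input node ids, so the model
            # must re-predict the identity of each node in the sampled sequence.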
            output_preds = tf.layers.dense(decoder_outputs, units=self.node_num, activation=None)
            loss_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.input_seqs, logits=output_preds)
            self.loss_ce = tf.reduce_mean(loss_ce, name='loss_ce')
    
            self.global_step = tf.Variable(1, name="global_step", trainable=False)
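
    To make the walkthrough concrete, here is a minimal sketch of how the class might be driven, assuming cora-sized inputs, random-walk node sequences, and an Adam optimizer. The sizes, learning rate, and dropout value are illustrative assumptions, not taken from the original code:

    # Hypothetical driver code -- all sizes and hyperparameters are assumptions.
    import numpy as np

    features = np.random.rand(2708, 1433).astype(np.float32)  # stand-in for cora node features
    batch_seqs = np.random.randint(0, 2708, size=(32, 10))    # stand-in for sampled walks

    model = STNE(hidden_dim=500, node_num=2708, fea_dim=1433, seq_len=10, node_fea=features)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(model.loss_ce, global_step=model.global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _, loss = sess.run([train_op, model.loss_ce],
                           feed_dict={model.input_seqs: batch_seqs, model.dropout: 0.2})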
    
