(八)sequence to sequence —1

作者: 天生smile | 来源:发表于2018-12-12 16:01 被阅读0次

    基于tensorflow1.4 Seq2seq的实现

    import helpers
    import tensorflow as tf
    from tensorflow.contrib import seq2seq,rnn
    # NOTE: the InteractiveSession that used to be opened here was dropped: a
    # second one is created right before the graph is built, and opening two
    # leaves the first one active and unreleased.

    # Special token ids. EOS marks the start/end of decoder sequences.
    PAD = 0
    EOS = 1
    # Size of the embedding table and of the output projection.
    vocab_size = 10
    input_embedding_size = 20
    encoder_hidden_units = 25
    # The decoder is initialised from the encoder's final state, so both
    # sides use the same hidden size.
    decoder_hidden_units = encoder_hidden_units
    import helpers as data_helpers
    # Number of sequences requested per minibatch.
    batch_size = 10
    # Generator that yields one minibatch of random sequences per `next()` call.
    # NOTE(review): the `batch_size=` argument and the loop body were missing
    # from this listing; they are restored to match the message printed below
    # and the per-sequence output that follows it.
    batches = data_helpers.random_sequences(length_from=3, length_to=8,
                                            vocab_lower=2, vocab_upper=10,
                                            batch_size=batch_size)
    print('产生%d个长度不一(最短3,最长8)的sequences, 其中前十个是:' % batch_size)
    # Preview at most ten sequences from the first minibatch.
    for seq in next(batches)[:min(batch_size, 10)]:
        print(seq)
    产生10个长度不一(最短3,最长8)的sequences, 其中前十个是:
    [9, 4, 4, 6]
    [4, 3, 3, 2]
    [5, 7, 4, 4]
    [5, 6, 6, 4, 6, 7, 3]
    [6, 7, 5, 2, 8, 6, 8]
    [5, 6, 9, 4, 6, 9, 6, 9]
    [3, 5, 2, 2, 9]
    [5, 6, 5, 8, 9, 8]
    [6, 8, 2, 4, 3]
    [9, 6, 8, 3, 5, 2]


    # Interactive session used by the `sess.run` calls in the training loop.
    sess = tf.InteractiveSession()
    # Training-mode key; it is not referenced again in the code shown here.
    mode = tf.contrib.learn.ModeKeys.TRAIN

    1. 定义计算图输入数据的placeholder

    # Graph inputs. Rank-2 placeholders hold token ids and rank-1 placeholders
    # hold per-sequence lengths; every dimension is left dynamic.
    with tf.name_scope('minibatch'):
        encoder_inputs = tf.placeholder(tf.int32, [None, None], name='encoder_inputs')
        encoder_inputs_length = tf.placeholder(tf.int32, [None], name='encoder_inputs_length')
        decoder_targets = tf.placeholder(tf.int32, [None, None], name='decoder_targets')
        decoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_inputs')
        # NOTE(review): this call was cut off after `shape=(None,),`; dtype and
        # name are completed to mirror `encoder_inputs_length`.
        decoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32,
                                               name='decoder_inputs_length')

    2. 定义LSTM cell(这里使用1层LSTM)

    def _create_rnn_cell(num_units=None, num_layers=1, output_keep_prob=0.5):
        """Build a stack of dropout-wrapped LSTM cells.

        Calling it with no arguments reproduces the original behaviour: one
        layer of ``encoder_hidden_units`` units with an output keep
        probability of 0.5.

        Args:
            num_units: Hidden size of every LSTM layer. When ``None`` the
                module-level ``encoder_hidden_units`` is used.
            num_layers: Number of LSTM layers to stack.
            output_keep_prob: Keep probability handed to ``DropoutWrapper``
                for each layer's output.

        Returns:
            An ``rnn.MultiRNNCell`` wrapping ``num_layers`` freshly built cells.
        """
        if num_units is None:
            num_units = encoder_hidden_units

        def single_rnn_cell():
            # Every layer needs its own cell object, so cells are built through
            # this factory; putting one shared cell instance into the
            # MultiRNNCell list makes the model fail.
            lstm = rnn.LSTMCell(num_units)
            return rnn.DropoutWrapper(lstm, output_keep_prob=output_keep_prob)

        return rnn.MultiRNNCell([single_rnn_cell() for _ in range(num_layers)])



    1.定义encoder 部分

    with tf.variable_scope('encoder'):
        # LSTM stack used to encode the source sequence.
        encoder_cell = _create_rnn_cell()
        # Embedding matrix; the decoder looks tokens up in this same variable.
        embedding = tf.get_variable('embedding', [vocab_size, input_embedding_size])
        encoder_inputs_embedded = tf.nn.embedding_lookup(embedding, encoder_inputs)
        # Encode the embedded inputs with dynamic_rnn.
        #   encoder_outputs: batch_size * encoder_inputs_length * rnn_size
        #                    (what an attention mechanism would consume)
        #   encoder_state:   final state, used to initialise the decoder
        # NOTE(review): the call was cut off after its second argument.
        # `dtype` is required because no initial state is passed;
        # `sequence_length` is restored so padded steps are not processed.
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
            encoder_cell, encoder_inputs_embedded,
            sequence_length=encoder_inputs_length,
            dtype=tf.float32)

    2.定义decoder 部分(暂时不添加attention部分)

    with tf.variable_scope('decoder'):
        decoder_cell = _create_rnn_cell()
        # No attention yet: the decoder starts from the encoder's final state.
        decoder_initial_state = encoder_state
        # Projects every decoder output onto vocabulary logits.
        output_layer = tf.layers.Dense(
            vocab_size,
            kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
        decoder_inputs_embedded = tf.nn.embedding_lookup(embedding, decoder_inputs)
        # Longest decoder sequence in the batch. Built once (it used to be
        # created twice under the same op name) and shared by the decode
        # bound and the loss mask.
        max_target_sequence_length = tf.reduce_max(decoder_inputs_length, name='max_target_len')
        # Training uses the usual TrainingHelper + BasicDecoder pairing; a custom
        # Helper class can be substituted for different behaviour.
        # NOTE(review): `sequence_length`, the trailing BasicDecoder arguments and
        # the trailing dynamic_decode arguments were missing from this listing
        # and are reconstructed from the surrounding definitions.
        training_helper = seq2seq.TrainingHelper(inputs=decoder_inputs_embedded,
                                                 sequence_length=decoder_inputs_length,
                                                 time_major=False, name='training_helper')
        training_decoder = seq2seq.BasicDecoder(cell=decoder_cell, helper=training_helper,
                                                initial_state=decoder_initial_state,
                                                output_layer=output_layer)
        # dynamic_decode returns a namedtuple with two fields:
        #   rnn_output: [batch_size, decoder_targets_length, vocab_size], the
        #               per-step scores over the vocabulary used for the loss
        #   sample_id:  [batch_size, decoder_targets_length], tf.int32, the
        #               token id chosen at every step
        decoder_outputs, _, _ = seq2seq.dynamic_decode(decoder=training_decoder,
                                                       impute_finished=True,
                                                       maximum_iterations=max_target_sequence_length)
        decoder_logits_train = tf.identity(decoder_outputs.rnn_output)
        sample_id = decoder_outputs.sample_id
        # Mask out padded positions so they do not contribute to the loss.
        # Example of sequence_mask:
        #  tf.sequence_mask([1, 3, 2], 5)
        #  [[True, False, False, False, False],
        #  [True, True, True, False, False],
        #  [True, True, False, False, False]]
        mask = tf.sequence_mask(decoder_inputs_length, max_target_sequence_length,
                                dtype=tf.float32, name='masks')
        print('\t%s' % repr(decoder_logits_train))
        print('\t%s' % repr(decoder_targets))
        print('\t%s' % repr(sample_id))
        loss = seq2seq.sequence_loss(logits=decoder_logits_train, targets=decoder_targets, weights=mask)
        <tf.Tensor 'decoder/Identity:0' shape=(?, ?, 10) dtype=float32>
        <tf.Tensor 'minibatch/decoder_targets:0' shape=(?, ?) dtype=int32>
        <tf.Tensor 'decoder/decoder/transpose_1:0' shape=(?, ?) dtype=int32>
    # One Adam step (learning rate 0.001) that minimises `loss`.
    train_op = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(loss)
    def next_feed():
        """Draw the next minibatch and build the feed dict for one step.

        The decoder input is each sequence prefixed with ``EOS``; the decoder
        target is the same sequence suffixed with ``EOS``.

        Returns:
            dict mapping the graph's placeholders to the arrays for this batch.
        """
        batch = next(batches)
        encoder_inputs_, encoder_inputs_length_ = data_helpers.batch(batch)
        # Only the decoder *input* lengths are fed, so the target lengths
        # returned here are discarded.
        decoder_targets_, _ = data_helpers.batch(
            [sequence + [EOS] for sequence in batch])
        decoder_inputs_, decoder_inputs_length_ = data_helpers.batch(
            [[EOS] + sequence for sequence in batch])
        # feed_dict keys may be the placeholder tensors themselves.
        # The arrays are transposed before feeding — presumably
        # `data_helpers.batch` returns them time-major; TODO confirm.
        return {
            encoder_inputs: encoder_inputs_.T,
            decoder_inputs: decoder_inputs_.T,
            decoder_targets: decoder_targets_.T,
            encoder_inputs_length: encoder_inputs_length_,
            decoder_inputs_length: decoder_inputs_length_,
        }
    # Build one feed dict as a quick sanity check.
    # NOTE(review): the three rows shown right after this line look like the
    # first encoder input, decoder input and decoder target of the batch, but
    # the statements that printed them are not part of this listing — confirm.
    x = next_feed()
    [9 4 3 3 2 6 0 0]
    [1 9 4 3 3 2 6 0 0]
    [9 4 3 3 2 6 1 0 0]
    loss_track = []
    max_batches = 3001
    batches_in_epoch = 100
    # NOTE(review): no variable initialisation is visible anywhere before the
    # first training step, so it is done here, after the optimizer has created
    # its slot variables.
    sess.run(tf.global_variables_initializer())
    try:
        # Train for `max_batches` steps, reporting every `batches_in_epoch` steps.
        for batch in range(max_batches):
            fd = next_feed()
            _, batch_loss = sess.run([train_op, loss], fd)
            # Keep the per-step loss so it can be inspected after training.
            loss_track.append(batch_loss)
            if batch == 0 or batch % batches_in_epoch == 0:
                print('batch {}'.format(batch))
                print('  minibatch loss: {}'.format(sess.run(loss, fd)))
                predict_ = sess.run(decoder_outputs.sample_id, fd)
                for i, (inp, pred) in enumerate(zip(fd[encoder_inputs], predict_)):
                    print('  sample {}:'.format(i + 1))
                    print('    input     > {}'.format(inp))
                    print('    predicted > {}'.format(pred))
                    # Show only the first three samples of the batch.
                    if i >= 2:
                        break
    except KeyboardInterrupt:
        print('training interrupted')
    batch 0
      minibatch loss: 2.2938551902770996
      sample 1:
        input     > [8 5 3 9 3 5 0 0]
        predicted > [4 4 4 4 4 1 4 0 0]
      sample 2:
        input     > [9 5 8 4 4 6 4 7]
        predicted > [9 3 4 4 4 9 9 4 9]
      sample 3:
        input     > [6 6 5 7 6 8 0 0]
        predicted > [1 4 4 3 3 3 4 0 0]
    batch 100
      minibatch loss: 2.1440541744232178
      sample 1:
        input     > [5 5 3 7 2 5 0 0]
        predicted > [7 5 5 5 5 7 1 0 0]
      sample 2:
        input     > [3 2 7 2 4 9 6 8]
        predicted > [2 2 2 2 2 1 1 1 1]
      sample 3:
        input     > [6 8 6 2 0 0 0 0]
        predicted > [2 9 2 1 1 0 0 0 0]
    batch 200
      minibatch loss: 1.7902907133102417
      sample 1:
        input     > [8 5 6 9 6 6 7 0]
        predicted > [7 5 7 9 5 7 5 1 0]
      sample 2:
        input     > [5 3 4 0 0 0 0 0]
        predicted > [5 3 1 1 0 0 0 0 0]
      sample 3:
        input     > [8 9 3 6 6 4 6 2]
        predicted > [6 9 8 4 4 4 2 1 1]
    batch 300
      minibatch loss: 1.6711502075195312
      sample 1:
        input     > [6 5 6 5 7 0 0 0]
        predicted > [7 7 7 7 5 1 0 0 0]
      sample 2:
        input     > [7 8 6 9 7 2 7 0]
        predicted > [5 7 7 5 7 7 7 1 0]
      sample 3:
        input     > [7 3 8 2 2 0 0 0]
        predicted > [2 2 2 2 1 1 0 0 0]
    batch 400
      minibatch loss: 1.4671175479888916
      sample 1:
        input     > [3 4 8 8 9 0 0 0]
        predicted > [4 8 8 4 2 1 0 0 0]
      sample 2:
        input     > [5 2 6 2 5 4 3 8]
        predicted > [8 8 2 6 2 9 8 8 1]
      sample 3:
        input     > [2 9 6 0 0 0 0 0]
        predicted > [8 6 5 1 0 0 0 0 0]
    batch 500
      minibatch loss: 1.3590279817581177
      sample 1:
        input     > [3 2 2 3 8 8 5 5]
        predicted > [8 8 8 9 5 5 5 1 1]
      sample 2:
        input     > [8 4 6 3 8 2 0 0]
        predicted > [4 4 2 8 2 4 1 0 0]
      sample 3:
        input     > [2 2 6 3 9 9 0 0]
        predicted > [4 2 2 9 6 6 1 0 0]
    batch 600
      minibatch loss: 1.292779564857483
      sample 1:
        input     > [7 9 6 5 0 0 0 0]
        predicted > [5 9 5 5 1 0 0 0 0]
      sample 2:
        input     > [5 9 3 0 0 0 0 0]
        predicted > [5 9 3 1 0 0 0 0 0]
      sample 3:
        input     > [3 3 8 5 6 3 0 0]
        predicted > [3 3 3 3 9 3 1 0 0]
    batch 700
      minibatch loss: 1.2727009057998657
      sample 1:
        input     > [4 4 7 7 8 6 5 7]
        predicted > [3 7 7 7 9 7 5 1 1]
      sample 2:
        input     > [5 4 2 2 7 7 0 0]
        predicted > [2 2 8 7 7 7 1 0 0]
      sample 3:
        input     > [7 3 9 7 8 0 0 0]
        predicted > [3 7 5 7 8 1 0 0 0]
    batch 800
      minibatch loss: 1.1580817699432373
      sample 1:
        input     > [8 3 2 7 8 5 7 0]
        predicted > [4 3 7 7 7 7 7 1 0]
      sample 2:
        input     > [2 8 7 6 7 2 0 0]
        predicted > [2 2 7 7 7 2 1 0 0]
      sample 3:
        input     > [8 7 8 4 3 2 5 8]
        predicted > [8 7 4 3 5 5 5 8 1]
    batch 900
      minibatch loss: 1.1622250080108643
      sample 1:
        input     > [6 8 2 5 5 0 0 0]
        predicted > [8 8 5 5 5 1 0 0 0]
      sample 2:
        input     > [5 9 4 5 7 0 0 0]
        predicted > [5 6 7 7 7 1 0 0 0]
      sample 3:
        input     > [6 2 3 4 9 5 3 9]
        predicted > [4 3 3 4 9 3 9 6 1]
    batch 1000
      minibatch loss: 1.2378357648849487
      sample 1:
        input     > [4 3 2 3 8 7 4 8]
        predicted > [3 4 8 3 2 4 2 2 1]
      sample 2:
        input     > [5 6 5 4 5 8 5 6]
        predicted > [5 5 5 5 5 5 6 6 1]
      sample 3:
        input     > [3 8 4 3 4 3 6 0]
        predicted > [2 4 4 4 4 3 6 1 0]
    batch 1100
      minibatch loss: 1.1085090637207031
      sample 1:
        input     > [4 7 2 0 0 0 0 0]
        predicted > [4 2 8 1 0 0 0 0 0]
      sample 2:
        input     > [6 2 3 5 7 7 2 4]
        predicted > [6 7 7 7 7 7 2 4 1]
      sample 3:
        input     > [9 7 7 3 5 2 4 0]
        predicted > [7 7 7 3 5 8 3 1 0]
    batch 1200
      minibatch loss: 1.1771703958511353
      sample 1:
        input     > [8 2 7 8 9 7 0 0]
        predicted > [8 8 5 8 7 7 1 0 0]
      sample 2:
        input     > [8 8 4 7 2 8 0 0]
        predicted > [8 8 2 2 2 2 1 0 0]
      sample 3:
        input     > [2 9 7 9 4 9 3 2]
        predicted > [9 9 7 8 4 8 3 2 1]
    batch 1300
      minibatch loss: 0.9447832107543945
      sample 1:
        input     > [4 3 2 3 9 6 0 0]
        predicted > [4 3 4 6 9 9 1 0 0]
      sample 2:
        input     > [5 9 4 0 0 0 0 0]
        predicted > [5 6 4 1 0 0 0 0 0]
      sample 3:
        input     > [8 8 8 2 7 8 0 0]
        predicted > [8 8 8 2 5 8 1 0 0]
    batch 1400
      minibatch loss: 1.0269840955734253
      sample 1:
        input     > [5 6 3 5 7 5 6 4]
        predicted > [2 6 5 5 7 6 6 4 1]
      sample 2:
        input     > [2 6 2 4 2 6 0 0]
        predicted > [2 4 2 6 6 6 1 0 0]
      sample 3:
        input     > [2 3 8 4 0 0 0 0]
        predicted > [4 3 8 4 1 0 0 0 0]
    batch 1500
      minibatch loss: 0.8967496752738953
      sample 1:
        input     > [7 7 8 6 4 7 0 0]
        predicted > [7 7 2 4 4 7 1 0 0]
      sample 2:
        input     > [7 8 4 6 0 0 0 0]
        predicted > [7 4 4 6 1 0 0 0 0]
      sample 3:
        input     > [6 7 5 6 8 7 7 6]
        predicted > [7 7 5 6 7 7 7 1 1]
    batch 1600
      minibatch loss: 0.9586960077285767
      sample 1:
        input     > [6 5 8 3 2 4 9 0]
        predicted > [5 5 8 4 2 4 5 1 0]
      sample 2:
        input     > [4 9 6 9 0 0 0 0]
        predicted > [3 9 9 9 1 0 0 0 0]
      sample 3:
        input     > [7 7 9 9 5 2 0 0]
        predicted > [7 5 9 5 5 2 1 0 0]
    batch 1700
      minibatch loss: 1.0395662784576416
      sample 1:
        input     > [5 7 4 5 0 0 0 0]
        predicted > [5 7 4 7 1 0 0 0 0]
      sample 2:
        input     > [3 3 2 8 0 0 0 0]
        predicted > [3 4 2 8 1 0 0 0 0]
      sample 3:
        input     > [6 8 2 7 8 5 0 0]
        predicted > [8 8 2 7 8 7 1 0 0]
    batch 1800
      minibatch loss: 0.9203397035598755
      sample 1:
        input     > [4 5 4 2 5 8 0 0]
        predicted > [4 5 4 5 5 1 1 0 0]
      sample 2:
        input     > [2 7 4 8 8 4 0 0]
        predicted > [7 7 4 8 4 4 1 0 0]
      sample 3:
        input     > [6 6 4 0 0 0 0 0]
        predicted > [6 6 4 1 0 0 0 0 0]
    batch 1900
      minibatch loss: 0.7155815362930298
      sample 1:
        input     > [6 5 2 2 9 7 9 0]
        predicted > [6 2 2 8 9 7 9 1 0]
      sample 2:
        input     > [5 6 2 9 9 4 8 0]
        predicted > [5 9 9 6 9 4 8 1 0]
      sample 3:
        input     > [6 8 2 9 0 0 0 0]
        predicted > [2 8 2 9 1 0 0 0 0]
    batch 2000
      minibatch loss: 0.7423955202102661
      sample 1:
        input     > [3 5 2 9 8 5 3 2]
        predicted > [5 5 2 3 5 2 3 2 1]
      sample 2:
        input     > [8 5 5 9 6 0 0 0]
        predicted > [5 9 5 6 6 1 0 0 0]
      sample 3:
        input     > [6 8 8 0 0 0 0 0]
        predicted > [8 8 8 1 0 0 0 0 0]
    batch 2100
      minibatch loss: 0.8510919213294983
      sample 1:
        input     > [7 7 9 0 0 0 0 0]
        predicted > [7 7 9 1 0 0 0 0 0]
      sample 2:
        input     > [4 2 9 2 5 6 2 6]
        predicted > [2 2 9 2 6 6 6 6 1]
      sample 3:
        input     > [4 6 8 2 5 5 0 0]
        predicted > [6 9 8 2 5 5 1 0 0]
    batch 2200
      minibatch loss: 0.6667694449424744
      sample 1:
        input     > [9 8 8 4 0 0 0 0]
        predicted > [8 8 8 4 1 0 0 0 0]
      sample 2:
        input     > [5 8 7 0 0 0 0 0]
        predicted > [2 8 1 1 0 0 0 0 0]
      sample 3:
        input     > [9 3 4 0 0 0 0 0]
        predicted > [3 3 4 1 0 0 0 0 0]
    batch 2300
      minibatch loss: 0.7337868809700012
      sample 1:
        input     > [2 4 7 6 6 9 0 0]
        predicted > [2 6 6 6 6 9 1 0 0]
      sample 2:
        input     > [3 5 2 8 0 0 0 0]
        predicted > [3 5 2 8 1 0 0 0 0]
      sample 3:
        input     > [5 5 8 4 8 9 4 3]
        predicted > [3 5 8 4 4 3 3 3 1]
    batch 2400
      minibatch loss: 0.8720135688781738
      sample 1:
        input     > [8 7 5 7 2 7 2 0]
        predicted > [2 5 5 7 2 7 8 1 0]
      sample 2:
        input     > [7 7 9 4 3 6 8 0]
        predicted > [7 7 9 4 3 9 7 1 0]
      sample 3:
        input     > [8 6 3 2 6 0 0 0]
        predicted > [2 6 2 2 6 1 0 0 0]
    batch 2500
      minibatch loss: 0.6776264309883118
      sample 1:
        input     > [7 7 8 8 8 3 2 0]
        predicted > [7 7 8 8 8 3 1 1 0]
      sample 2:
        input     > [6 7 7 9 3 7 9 8]
        predicted > [7 7 7 3 9 7 9 8 1]
      sample 3:
        input     > [8 6 6 7 0 0 0 0]
        predicted > [6 6 6 7 1 0 0 0 0]
    batch 2600
      minibatch loss: 0.7246588468551636
      sample 1:
        input     > [3 6 7 0 0 0 0 0]
        predicted > [6 6 7 1 0 0 0 0 0]
      sample 2:
        input     > [9 6 8 4 6 6 8 0]
        predicted > [6 6 8 6 6 6 8 1 0]
      sample 3:
        input     > [6 5 9 6 9 2 7 0]
        predicted > [6 9 9 4 6 2 6 1 0]
    batch 2700
      minibatch loss: 0.6910533308982849
      sample 1:
        input     > [3 7 4 0 0 0 0 0]
        predicted > [3 7 4 1 0 0 0 0 0]
      sample 2:
        input     > [2 6 9 9 7 3 2 5]
        predicted > [6 6 9 3 3 3 2 5 1]
      sample 3:
        input     > [9 6 5 0 0 0 0 0]
        predicted > [9 6 5 1 0 0 0 0 0]
    batch 2800
      minibatch loss: 0.6767545342445374
      sample 1:
        input     > [9 8 5 0 0 0 0]
        predicted > [9 8 5 1 0 0 0 0]
      sample 2:
        input     > [2 6 6 4 9 8 2]
        predicted > [2 6 6 4 8 8 2 9]
      sample 3:
        input     > [3 8 7 0 0 0 0]
        predicted > [3 8 7 1 0 0 0 0]
    batch 2900
      minibatch loss: 0.6852056980133057
      sample 1:
        input     > [6 4 7 0 0 0 0 0]
        predicted > [6 4 7 1 0 0 0 0 0]
      sample 2:
        input     > [9 3 9 9 0 0 0 0]
        predicted > [3 9 9 9 1 0 0 0 0]
      sample 3:
        input     > [3 5 8 0 0 0 0 0]
        predicted > [3 5 8 1 0 0 0 0 0]
    batch 3000
      minibatch loss: 0.6660669445991516
      sample 1:
        input     > [7 2 6 9 5 2 8 7]
        predicted > [7 2 9 5 5 2 7 5 1]
      sample 2:
        input     > [6 9 9 3 2 0 0 0]
        predicted > [9 9 9 3 5 1 0 0 0]
      sample 3:
        input     > [8 4 6 6 0 0 0 0]
        predicted > [8 4 6 6 1 0 0 0 0]



