美文网首页
Unsupervised NMT 代码理解

Unsupervised NMT 代码理解

作者: 润秀94 | 来源:发表于2018-10-11 15:32 被阅读0次

    对应论文: Phrase-based & Neural Unsupervised Machine Translation

    TrainerMT 包含encoder decoder discriminator lm

    针对于Attention Model

    encoder 是 TransformerEncoder(); decoder 是 TransformerDecoder(encoder)

    discriminator 是 Discriminator(), 没有把encoder 或者decoder 进行传参

    decoder.loss 是 CrossEntropyLoss 或者 LabelSmoothedCrossEntropy

    lm=LM(encoder, decoder)

    discriminator_step, lm_step

    1 对于parallel data: enc_dec_step(params.lambda_xe_para) 训练de>en

    2 对于back-parallel data: enc_dec_step(params.lambda_xe_back, back=True) 训练en>de

    3 对于monolingual data: enc_dec_step(params.lambda_xe_mono)每个语言 都训练 Autoencoder loss

    4 xe_otfd: on-the-fly back-translation parallel data

    xe_otfa: on-the-fly back-translation autoencoding data

    batches 来自 otf_iterator

    当lang3=lang1 时:是 back-translation 用于 autoencoding

    当lang3=lang2时: 是 back-translation 用于 parallel

    当互不相等时:

    ---------------------------------

    lm_step: loss也是 decoder.loss

    LM 里面有 scores_enc; scores_dec; scores_enc_rev 这三个是 subLMs

    loss=loss_enc+loss_dec+loss_enc_rev

    lm(sent, length, lang_id, is_encoder, reverse)

    lang_id = source 有多个languge, 代表不同的source

    对于不同的subLM(sent, length, source), 包含了 embedding, lstm_layer, proj_layer: output_shape: (slen, bs, n_words) representing the score for each output word of being the next word

    — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — —

    Encoder 有多少语言就有多少 Embedding,LSTM,projection_layer

    projection_layer dim_output 就是 emb_dim

    输出 LatentState: 包含 dec_input(slen, bs, emb_dim), dis_input(slen, bs, 2*hidden_dim)

    Decoder:

    att_proj (hidden_dim, emb_dim)

    lstm_proj(hidden_dim, emb_dim)

    proj (lstm_proj_dim, n_words)

    return (ylen, bs, n_words)

    问题

    1. back parallel training 是怎么进行的 shared encoder decoder?

    2. Decoder lstm1 lstm2 的作用, 为什么 input_feeding 时候要那么设置?

    input_feeding=True: lstm1 dim: 2*emb_dim, lstm2 dim: hidden_dim

    input_feeding=False: lstm1 dim: emb_dim, lstm2: hidden_dim+emb_dim

    相关文章

      网友评论

          本文标题:Unsupervised NMT 代码理解

          本文链接:https://www.haomeiwen.com/subject/ptmbaftx.html