美文网首页
Conditional Generation by RNN &

Conditional Generation by RNN &

作者: 没日没夜醉心科研的九天 | 来源:发表于2020-11-26 20:14 被阅读0次

    Outline

            Generation

            Attention

            Tips for Generation

            Pointer Network

    Generation

            generate a structured object component by component

            1. examples of generation

            逐次用RNN生成字符/单词

    句子的生成

            逐个像素地生成图像

    图像的生成 1 图像的生成 2

            2. conditional generation

            不仅限于生成随机的句子,而是根据条件生成相应的sentence。

            如图像标题生成、聊天机器人。

    conditional generation

            图像标题生成(Conditional Caption generation)

            将图像通过CNN转换成一个vector,再将vector输入到RNN中生成标题。

    Caption generation

            有条件的句子生成(Conditional Sentence Generation)

            如机器翻译/聊天机器人。

            将中文输入到一个RNN(1)中,得到一个包含所有信息的vector,再将vector输入到RNN(2)中。前者是encoder,后者是decoder。RNN(1)和RNN(2)的参数既可以相同也可以不同,具体分析。

    Conditional Sentence Generation

    Attention

            Dynamic conditional generation

            1.Dynamic conditional generation

            有时候data过大,encoder生成的vector不足以包含所有的信息;或者只有特定的信息对某个输出有用,其他都没有必要,因此需要设计一个dynamic的生成方案。

            2. examples

            Machine Translation

                    首先有一个初始化参数Z0(decoder产生的key),然后每个component依次通过一个RNN得到对应的h,然后通过一个match(Z0和hi为输入,a为输出)得到ai,ai即代表着目前的重点关注范围。

                    match的具体算法由designer自己决定,可以是small network等。

    Machine Translation 1

                    对每一个hi做同样的运算,通过softmax就得到对应的概率。将其加权和作为输入,输入到decoder的RNN中,得到对应的transcript,而Z1是隐藏层的输出,进行下一步运算,直至输出结束。

    Machine Translation 2 Machine Translation 3

                    通过component与key的match运算,最后得出每一个component的attention weight,代表着当前对某些具有较大weight值的component的关注程度。

                    每次decoder的input为attention weight与component的加权和(概率分布)。

            Speech Recognition

    Speech Recognition

            Image Caption Generation

                    与machine translation相似。

    Image Caption Generation 1 Image Caption Generation 2 Image Caption Generation 3

            Memory Network

                    在memory上做attention。

                    首先回顾传统的attention-based RNN模型,与上述介绍的一样。

    传统的RNN模型

    Tips for generation

          1.Scheduled Sampling

                    训练和测试不匹配(mismatch between train and test)的问题可以用scheduled sampling来解决。

                    Training:在训练时,预测的结果总是会与reference(label,ground truth)来计算损失,并且下一个component的input会是上一个component对应的reference

    training

                    Generation(test):在生成时,下一个component的input只能是上一个component对应的output;此时没有reference来参考。

    genertaion

                    这样就会导致一个mismatch的问题。因为generation没有reference,如果一个output出现了错误的预测,那么就会将接下来的结果带入到没有经过训练的方向上,会导致很多问题,如下图所示。

    mismatch

                    如果我们考虑改变train的方法呢?

                     为了保持train和test一致,我们应当保持前一个component的output始终是后一个component的input,即使它与reference不一致。这种训练方式看起来很合理,但是很难train,注意:第一个component训练的目标是输出A、第二个要输出B,因此第一个component的预测输出最终会变成A,此时第二个component的输入也会变成A,但是他已经按照B训练好了。因此到最后网络的训练结果不一定好,也有可能更差。

    modifying

                    因此可以采用scheduled sampling来改善这个问题。

                    Scheduled sampling其实就是以一定的概率函数选择下一个input的来源是上一个output还是reference。如图,三种decay衰减函数。实验证明,采用scheduled sampling确实在效果上有一定的改善。

    scheduled sampling

          2.beam search

                    绿色的路分数最高,但是我们并不能提前知道最终的结果是怎样。因此我们可以设置beam size的大小,最次都挑选最优的路径。

    beam search 1

                            beam search的思想就是每一步都挑选最好的路径。在每一次只有两种结果的前提下,设置beam size=2,也就是说每次考虑2个component,从4条路径中挑选最好的路径作为最终训练结果。

    beam search 2

          3.better idea?

                    前一个output应当是把distribution还是选择的结果(如:非黑即白)作为input送给下一个component呢?显然把选择的结果送给下一个component比较好。如下图,我们想要最终输出“高兴想笑”或“难过想哭”,但是如果传送的是distrubution,可能会得到“高兴想哭”等不好的结果。

    better idea?

          4.Object level

                    如下图所示,采用component的loss训练到中间结果“The dog is is fast”后就train不动了,结果改善得不明显。但是如果换成object level的loss,则还能继续train下去直到得到目标结果。

    Object level

          5.Reinforcement learning

    Reinforcement learning

    Pointer network

            Pointer Network可以应用在给一堆点找边界上:

     Pointer Network 1

            相对于上述的attention-based网络,pointer network在计算出每一个component的attention weight后,不是计算概率和,而是直接将对应的component输出。如下图,(x4,y4)对应的attention weight是0.7最高,因此直接输出(x4,y4)。

     Pointer Network 2

            其他方面的应用,如machine translation和chat-bot,pointer network可以直接输出相应的word。

    application

    相关文章

      网友评论

          本文标题:Conditional Generation by RNN &

          本文链接:https://www.haomeiwen.com/subject/zsbqiktx.html