美文网首页NLP学习
Transformer系列:残差连接原理详细解析和代码论证

Transformer系列:残差连接原理详细解析和代码论证

作者: xiaogp | 来源:发表于2023-07-29 07:43 被阅读0次

    关键词:Transformer残差连接

    内容目录

    • 残差连接的历史由来
    • Transformer中的残差连接
    • 深层网络的问题代码复现
    • 深层网络的问题分析
    • 残差连接的作用通俗理解
    • 残差连接和GBDT类比
    • 残差连接的作用公式理解
    • 深层网络运用残差连接代码实践

    残差连接的历史由来

    残差连接可以追溯到2015年何凯明等人正式提出的ResNet,使得残差连接/网络成为一种基准模型结构,残差连接解决了神经网络随着层数的增多变得难以训练的问题,主要是出现梯度消失,梯度爆炸和网络退化的情况,而残差连接的引入可以有效缓解这些这些问题从而使得网络可以拓展到更深的层数。


    Transformer中的残差连接

    Transformer也使用了残差连接(residual connection)这种标准结构,在Transformer中的Encoder和Decoder中层和层中间加入了ADD & Norm操作,其中ADD就是残差连接,如图所示

    Transformer中的残差连接

    Add具体的含义是将本层的输出和本层的输入对应位置相加(本层的输出和本层的输入维度相等)作为最终的输出,在Transformer实现的代码中以Encoder为例是这么实现的

    output, slf_attn = self.self_att_layer(enc_input, enc_input, enc_input, mask=mask)
    output = self.norm_layer(Add()([enc_input, output]))
    

    其中self_att_layer是Encoder中的多头注意力,enc_input是多头注意力的输入,output是多头注意力的输出,将enc_input和output通过Keras的Add()算子进行相加得到最终的output。


    深层网络的问题代码复现

    下面通过一个简单的全连接结构测试一下深度网络导致的模型训练问题,首先构建一个可以可以传入层数参数的Dense网络

    class Model(object):
        def __init__(self, num_class, feature_size, layer_num=100, learning_rate=0.001, weight_decay=0.01, decay_learning_rate=1):
            self.input_x = tf.placeholder(tf.float32, [None, feature_size], name="input_x")
            self.input_y = tf.placeholder(tf.float32, [None, num_class], name="input_y")
            self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
            self.batch_normalization = tf.placeholder(tf.bool, name="batch_normalization")
            self.global_step = tf.Variable(0, name="global_step", trainable=False)
    
            tmp_tensor = self.input_x
            for i in range(layer_num):
                with tf.variable_scope('layer_{}'.format(i + 1)):
                    dense_out_1 = tf.layers.dense(tmp_tensor, 32)
                    # bn
                    dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
                    tmp_tensor = tf.nn.relu(dense_out_1)
    
            with tf.variable_scope('layer_out'):
                self.output = tf.layers.dense(tmp_tensor, 2)
                self.probs = tf.nn.softmax(self.output, dim=1, name="probs")
    
            with tf.variable_scope('loss'):
                self.loss = tf.reduce_mean(
                    tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.output, labels=self.input_y))
                vars = tf.trainable_variables()
                loss_l2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if
                                    v.name not in ['bias', 'gamma', 'b', 'g', 'beta']]) * weight_decay
                self.loss += loss_l2
    
            with tf.variable_scope("optimizer"):
                if decay_learning_rate:
                    learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 100, decay_learning_rate)
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    self.train_step = optimizer.minimize(self.loss, global_step=self.global_step)
    
            with tf.variable_scope("metrics"):
                self.accuracy = tf.reduce_mean(
                    tf.cast(tf.equal(tf.arg_max(self.probs, 1), tf.arg_max(self.input_y, 1)), dtype=tf.float32))
    

    可以传入任意layer_num构建,下面测试下layer_num在[3, 10, 20, 35, 50, 100]下,迭代最大500轮之后,训练集和测试集的准确率在每轮迭代下的结果

    train_loss = {}
    test_loss = {}
    layer_num = [3, 10, 20, 35, 50]
    for i in layer_num:
        tf.reset_default_graph()
        model = Model(num_class=2, feature_size=15, layer_num=i, weight_decay=0)
        with tf.Session() as sess:
            init_op = tf.group(tf.global_variables_initializer())
            sess.run(init_op)
    
            train_batch = get_batch(3, 64, train_x, train_y)
            val_feed_dict = {model.input_x: test_x, model.input_y: test_y, model.dropout_keep_prob: 1,
                             model.batch_normalization: False}
            for batch in train_batch:
                epoch, batch_x, batch_y = batch
                feed_dict = {model.input_x: batch_x, model.input_y: batch_y, model.dropout_keep_prob: 1,
                             model.batch_normalization: True}
                _, step, loss_train, acc_train = sess.run([model.train_step, model.global_step, model.loss, model.accuracy], feed_dict=feed_dict)
                if step % 1 == 0:
                    loss_val, acc_val, probs = sess.run([model.loss, model.accuracy, model.probs], feed_dict=val_feed_dict)
                    train_loss.setdefault(i, []).append(acc_train)
                    test_loss.setdefault(i, []).append(acc_val)
    

    对结果进行画图如下,每10个点做了移动平均处理


    训练集accuracy 验证集accuracy

    随着网络层数的加深,模型在训练和测试的准确率都在下降,layer=3是其中的最佳效果,当网络达到35层以上时已经很难进行学习收敛,准确率在50%左右。进一步穷举每一种层数可能下的测试集准确率


    不同层数下的测试集accuracy

    结论是随着层数的增多模型效果下降,在这个数据上2层已经够用了。


    深层网络的问题分析

    上面的测试结果表明网络的深度过深导致模型训练困难无法收敛,主要存在三个问题

    • 梯度消失:神经网络采用串联式结构和反向传播优化方法,反向传播中梯度的计算存在模型参数w的累乘,w接近0累乘导致梯度接近0梯度消失
    • 梯度爆炸:同理梯度消失,若w较大,累成导致梯度爆炸
    • 网络退化:理论上网络存在一个最优层数,超过这个层数带来的冗余结构的效果并不超过该最优层数下的模型效果,这些冗余层数会带来网络退化

    对于梯度消失和梯度爆炸,举例如下网络结构


    网络结构 out计算

    此时要对第一层的b进行迭代计算梯度,根据链式求导计算过程如下

    链式求导

    其中存在中间每个网络层的随机初始化参数w,随着网络深度越大w越多,爆炸和消失的可能性越大。


    残差连接的作用通俗理解

    本质上残差连接类似一种兜底策略,目的是使得就算模型的深度已经达到最优解,后面再增加冗余层也至少不会导致之前的效果下降
    残差连接的思路是,举例模型一共50层,若16层时模型已经充分学习达到测试集最佳效果,则让从17层开始到第50层学习一种恒等变换在最后一层将第16层的输出恒等映射出来
    残差连接的做法是将上一层的输出直接连接到下一层的输出,及上一层的输出直接和下一层的原始输出对应位置相加形成最终输出,如图

    残差连接示意图

    上一层的输出是X,下一层的原始输出是F(x),relu(F(x)+X)是最终残差连接的结果,X输入下一层的同时直接连接到下一层的输出,如同构建了桥梁一般。
    其中X代表一个逐渐逼近最优结果的上层输出,而F(x)代表残差,表示还可以再逼近最优效果的网络结构,当模型深度已经达到最优值的时候,残差连接可以自适应的将F(x)学习为全0,由于有relu的存在残差网络很容易将F(x)全部置为0,此时relu(F(x)+X)转化为relu(X),而由于relu的性质得知,relu(X)=X,因为X已经经过上一层的rule变换,再经过一次relu还是X,从而实现了恒等变换,同样如果X还离最优的效果差距很远,残差连接也自适应地让下一层的F(x)充分学习。
    我们并不会知道哪几层就能达到很好的效果,因此可以在每一层或者每隔几层就加入残差网络结构,相当于在每一层/几层就有一个兜底策略,使得网络不会由于已经得到最优层数而相比于上一层有退化。


    残差连接和GBDT类比

    残差连接这种上一层作为基线,下一层拟合残差不断逼近最优结果的思想和GBDT很类似。GBDT用一个个基学习器拟合之前所有基学习器剩下的残差,而残差连接以X为基线,F(x)聚焦于还可学习的微小部分,差异在于GBDT每个基分类结果相加做logit即可得到预测结果,残差连接在网络中间层,最后还要套一层全连接进行任务分类。


    残差连接的作用公式理解

    在深层网络的问题分析那一段中有普通网络从x到y4的计算过程,加入残差连接之后y1到y4的计算如下

    y1 = σ(w1 * x + b1) + x
    y2 = σ(w2 * y1 + b2) + y1
    y3 = σ(w3 * y2 + b3) + y2
    y4 = σ(w4 * y3 + b4) + y3
    

    残差连接每一层的输出公式

    其中X代表某层的输出,某个高层I的输出等于某个低层i输入加上两层之间所有残差F的结果,此时若要对低层的i求梯度,结果如下

    残差连接对低层求梯度

    括号展开第一项直接就是高层I的梯度,直接作为一个因子直接作用到低层的i梯度,而不是像普通网络经过各种累乘放大或者缩小,等式右侧是一个累加,相比于原来的累成一定程度上降低了梯度爆炸和弥散的概率。


    深层网络运用残差连接代码实践

    修改深层网络的问题代码复现一段中的代码,使其在每一层的输出都加上该层的输入

    for i in range(layer_num):
        with tf.variable_scope('layer_{}'.format(i + 1)):
            dense_out_1 = tf.layers.dense(tmp_tensor, 32)
            dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
            if i != 0:
                # 残差连接
                dense_out_1 = tf.nn.relu(dense_out_1)
                dense_out_1 = tf.add(dense_out_1, tmp_tensor)
            # bn
            tmp_tensor = tf.nn.relu(dense_out_1)
    

    顺序采用input => dense => bn => relu => add => relu,参考这个

    残差连接块

    同样是运行[3, 10, 20, 35, 50, 100]层,训练集的accuracy随iter的变化如下


    残差连接训练集准确率

    从训练集来看,35层和100层明显低于其他层但是差距并不大,50层和3,10,20没有明显差异,不同层下整体训练都能收敛,再看测试集

    残差连接测试集准确率

    3,10,20三者没有明显差异,35,50,100随着层数越来越大测试效果逐渐变差,但是也能平均保持在0.7的准确率,相比普通网络只有0.5出头已经有很大改观,从而验证了在深层网络加入残差连接的有效性。

    相关文章

      网友评论

        本文标题:Transformer系列:残差连接原理详细解析和代码论证

        本文链接:https://www.haomeiwen.com/subject/lemkpdtx.html