美文网首页我爱编程
搭建模块化神经网络的八股Tensorflow

搭建模块化神经网络的八股Tensorflow

作者: 微雨旧时歌丶 | 来源:发表于2018-05-05 09:37 被阅读0次

1. 前向传播就是搭建网络,设计网络结构(forward.py)

def forward(x,regularizer):  #x是输入,regularizer是正则化权重
    #该函数完成网络结构的设计,给出输入到输出的通路
    w=
    b=
    y=
    return y

def get_weight(shape,regularizer):  #shape是w的形状,regularizer是正则化权重
    w = tf.Variable(  ) #括号里写赋初值的方法,
    tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):  #shape是b的形状(就是b的个数)
    b=tf.Variable(  ) #赋初值
    return b

2. 反向传播就是训练网络,优化网络参数(backward.py)

def backward():
    x=tf.placeholder(    )
    y_=tf.placeholder(    )
    y=forward.forward(x,REGULARIZER) #用forward复现网络结构
    global_step = tf.Variable(0,trainable=False) #轮数计数器定义
    loss=
## 正则化 ##
   loss可以是:
   y与y_的差距(均方误差)(loss_mse)=tf.reduce_mean(tf.square(y-y_))
   也可以是:
              (交叉熵)ce=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
    y与y_的差距(cem)=tf.reduce_mean(ce)
    加入正则化后:
    loss=y与y_的差距 + tf.add_n(tf.get_collection('losses'))

##指数衰减学习率 ##
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        数据集总样本数/BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    train_step=tf.train.GradientDescentOptimizer(learing_rate).minimize(loss,global_step=global_step) #训练过程

##滑动平均##
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step) #global_step与指数衰减学习率中的公用一个
    ema_op = ema.apply(tf.trainable_variable())
    with tf.control_dependencies([train_step,ema_op]):
        train_op = tf.no_op(name='train')

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        for i in range(STEPS) # 迭代轮数
            sess.run(train_step,feed_dict = {x:    ,y_:    })  #执行训练过程
            if i%轮数 ==0:  #每隔一定轮数,打印信息
                print

main函数

if __name__=='__main__':
    backward()

相关文章

网友评论

    本文标题:搭建模块化神经网络的八股Tensorflow

    本文链接:https://www.haomeiwen.com/subject/afdsrftx.html