美文网首页
[tf]Estimator

[tf]Estimator

作者: VanJordan | 来源:发表于2018-12-17 09:55 被阅读1次

    感觉网络定义这一块使用layers比较方便,其他的好像都比较鸡肋,写一个for训练就能实现的非要封装成一个函数。

    def lenet(x , is_training):
        x = tf.reshape(x, shape = [-1,28,28,1])
        net = tf.layers.conv2d(x,32,5,activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(net,2,2)
        net = tf.contrib.layers.flatten(net)
        net = tf.layers.dense(net,1024)
        net = tf.layers.dropout(net, rate =0.4, training = is_training)
        return tf.layers.dense(net,10)
    
    # 自定义Estimator中是用的模型。定义的函数中有4个输入,features给出了在输入函数中会提供的输入覅个的张量,
    # 注意这是一个字典,字典里面的内容通过tf.estimator.inputs.numpy_input_fn中的x参数的内容指定
    # mode的取值有三种可能性,分别对应Estimator类中的train,evaluate和 predict这三个函数。通过这个参数可以判断当前
    # 是否是训练过程,最后一个params是一个字典,这个字典中可以给出模型的任何超参数,比如学习率
    
    def model_fn(features, labels, mode, params):
        predict =  lenet(features["image"], mode = tf.estimator.ModeKeys.TRAIN)
        if mode = tf.estimaotr.ModeKeys.PREDICT:
            # 使用EstimatorSpec传递返回值,并通过predictions参数指定返回徐建国
            return tf.estimator.EstimatorSpec(mode, predictions=['result':tf.argmax(predict,1)])
    
        loss = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predict, labels = labels))
        optimizer = tf.train.GradientDescentOptimizer(lr=params("lr"))
        train_op = optimizer.minimize(loss = loss, global_step= tf.train.get_global_step())
        # 定义评测标准,在运行evaluate时会计算这里定义的所有评测标准
        eval_metric_ops = {'my_metric': tf.metrics.accuracy(tf.argmax(predict,1), labels)}
        return  tf.estimator.EstimatorSpec( mode = mode,
                                            loss = loss,
                                            train_op = train_op,
                                            eval_metric_ops = eval_metric_ops)
    
    estimator = tf.estimator.Estimator(model_fn = model_fn, params = model_params)
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={"image": mnist.train.images},
        y = mnist.train.labels.astype(np.int32),
        num_epochs = None,
        batch_size = 128,
        shuffle = True
    )
    estimator.train(input_fn = train_input, steps = 300000)
    

    相关文章

      网友评论

          本文标题:[tf]Estimator

          本文链接:https://www.haomeiwen.com/subject/eeqqkqtx.html