美文网首页
Random Forest

Random Forest

作者: scpy | 来源:发表于2018-12-28 14:51 被阅读0次
    """ Random Forest.
    
    Implement Random Forest algorithm with TensorFlow, and apply it to classify 
    handwritten digit images. This example is using the MNIST database of 
    handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).
    
    Author: Aymeric Damien
    Project: https://github.com/aymericdamien/TensorFlow-Examples/
    """
    
    from __future__ import print_function
    
    import tensorflow as tf
    from tensorflow.contrib.tensor_forest.python import tensor_forest
    from tensorflow.python.ops import resources
    
    # Ignore all GPUs, tf random forest does not benefit from it.
    # Setting CUDA_VISIBLE_DEVICES to an empty string is intended to hide every
    # CUDA device from this process, so the graph below runs on CPU.
    # NOTE(review): this is set after `import tensorflow`; presumably it still
    # takes effect because devices are enumerated when the session is created
    # -- confirm for the TensorFlow version in use.
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    
    # Import MNIST data
    from tensorflow.examples.tutorials.mnist import input_data
    # one_hot=False requests labels as integer class ids rather than one-hot
    # vectors, matching the int32 label placeholder defined further down.
    # "/tmp/data/" is the directory handed to the loader for the dataset files.
    mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)
    
    # ---- Training schedule ----
    # Total number of training steps to run
    num_steps = 500
    # Number of samples fed per training step
    batch_size = 1024

    # ---- Data description ----
    # One class per digit, 0 through 9
    num_classes = 10
    # Each image is 28x28 pixels, flattened into a single feature vector
    num_features = 28 * 28

    # ---- Forest size ----
    # Number of trees in the forest
    num_trees = 10
    # Value passed to the forest as its max_nodes setting
    max_nodes = 1000
    
    # Placeholders: X receives flattened images, Y receives their labels.
    X = tf.placeholder(dtype=tf.float32, shape=[None, num_features])
    # Random forest labels are integer class ids, hence int32 and rank 1.
    Y = tf.placeholder(dtype=tf.int32, shape=[None])

    # Collect the forest hyper-parameters, then let fill() complete them.
    hparams = tensor_forest.ForestHParams(
        num_classes=num_classes,
        num_features=num_features,
        num_trees=num_trees,
        max_nodes=max_nodes)
    hparams = hparams.fill()

    # Build the forest itself from those hyper-parameters.
    forest_graph = tensor_forest.RandomForestGraphs(hparams)

    # Training op and the loss reported alongside it.
    train_op = forest_graph.training_graph(X, Y)
    loss_op = forest_graph.training_loss(X, Y)

    # Accuracy: fraction of samples whose highest-scoring class equals the label.
    # Only the first value returned by inference_graph is used here.
    infer_op, _, _ = forest_graph.inference_graph(X)
    predicted_class = tf.argmax(infer_op, 1)
    # Cast the int32 labels to int64 so both sides of tf.equal share a dtype.
    expected_class = tf.cast(Y, tf.int64)
    correct_prediction = tf.equal(predicted_class, expected_class)
    accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # A single op that initializes both the global variables and the forest's
    # shared resources.
    init_vars = tf.group(
        tf.global_variables_initializer(),
        resources.initialize_resources(resources.shared_resources()),
    )
    
    # Train the forest, then report accuracy on the MNIST test split.
    # The session is opened as a context manager so that it is always closed,
    # even if one of the steps below raises.
    with tf.Session() as sess:
        # Run the initializer (global variables and forest resources)
        sess.run(init_vars)

        # Training
        for i in range(1, num_steps + 1):
            # Get the next batch of MNIST data. Both the images and their
            # labels are fed to the training op.
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            train_feed = {X: batch_x, Y: batch_y}
            _, loss = sess.run([train_op, loss_op], feed_dict=train_feed)
            # Log on the first step and then every 50 steps.
            if i % 50 == 0 or i == 1:
                # Accuracy here is measured on the batch that was just used
                # for training, not on held-out data.
                acc = sess.run(accuracy_op, feed_dict=train_feed)
                print('Step %i, Loss: %f, Acc: %f' % (i, loss, acc))

        # Test Model
        test_x, test_y = mnist.test.images, mnist.test.labels
        print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
    

    相关文章

      网友评论

          本文标题:Random Forest

          本文链接:https://www.haomeiwen.com/subject/onnslqtx.html