美文网首页
tf 实现 KNN

用 TensorFlow 实现 KNN（1-最近邻分类，L1 距离）

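作者: cookyo | 来源:发表于2019-08-19 18:53 被阅读0次

说明：以下代码使用 TensorFlow 1.x 的 API（tf.placeholder、tf.Session），且省略了导入语句。运行前需要先执行 import numpy as np、import tensorflow as tf，以及 from tensorflow.examples.tutorials.mnist import input_data。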
    # 1-nearest-neighbour (KNN with k=1) MNIST classifier using L1 distance.
    # Written against the TensorFlow 1.x graph API (tf.placeholder / tf.Session).
    # NOTE(review): relies on `tf`, `np` and `input_data` being imported elsewhere.

    # Download (if needed) and load MNIST; one_hot=True means labels are
    # one-hot vectors, so class ids are recovered with np.argmax below.
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # 5000 training images form the nearest-neighbour candidate pool;
    # 100 test images are classified against them.
    train_X, train_Y = mnist.train.next_batch(5000)
    test_X, test_Y = mnist.test.next_batch(100)

    # Graph inputs: the whole candidate pool (rows of 784 pixels) and a single
    # test image of 784 pixels.
    tra_X = tf.placeholder("float", [None, 784])
    te_X = tf.placeholder("float", [784])

    # L1 distance: te_X broadcasts against every row of tra_X, and summing over
    # axis 1 yields one distance per training image.
    distance = tf.reduce_sum(tf.abs(tra_X - te_X), axis=1)
    # Prediction: index of the training image with the smallest distance.
    pred = tf.argmin(distance, 0)

    # The graph defines no variables, so running this initializer is a no-op;
    # global_variables_initializer replaces the deprecated
    # initialize_all_variables.
    init = tf.global_variables_initializer()

    # Count hits as an integer and divide once at the end; repeatedly adding
    # 1/len(test_X) to a float accumulates rounding error.
    correct = 0
    num_test = len(test_X)

    with tf.Session() as sess:
        sess.run(init)

        # Label each test image with the class of its nearest training image.
        for i, (test_image, test_label) in enumerate(zip(test_X, test_Y)):
            nn_index = sess.run(pred, feed_dict={tra_X: train_X, te_X: test_image})
            predicted = np.argmax(train_Y[nn_index])
            actual = np.argmax(test_label)
            print("Test", i, "Prediction:", predicted, "True Class:", actual)
            if predicted == actual:
                correct += 1

        # Guard against an empty test batch instead of dividing by zero.
        accuracy = correct / float(num_test) if num_test else 0.0
        print("Done!")
        print("Accuracy:", accuracy)
    

    相关文章

      网友评论

          本文标题:tf 实现 KNN

          本文链接:https://www.haomeiwen.com/subject/mhiesctx.html