美文网首页
二分类问题

二分类问题

作者: 抑郁质 | 来源:发表于2018-09-20 10:29 被阅读0次
    # -*- coding: utf-8 -*-
    """Train a tiny two-layer network on a synthetic binary-classification task.

    Originally a Spyder temporary script. It builds a TensorFlow 1.x graph
    (placeholders + Session) and generates 128 random points in [0, 1)^2,
    labelled 1 when x1 + x2 < 1. It trains with Adam for 5000 mini-batch
    steps, printing the full-dataset cross entropy every 1000 steps and the
    weights before and after training.

    NOTE: this uses the TF 1.x graph API (tf.placeholder, tf.Session, tf.log,
    tf.random_normal, tf.train.AdamOptimizer). Under TensorFlow 2.x these are
    only available via tf.compat.v1 with eager execution disabled.
    """

    import tensorflow as tf
    from numpy.random import RandomState

    batch_size = 8

    # Weights of the two layers: 2 inputs -> 3 hidden units -> 1 output.
    w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
    w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

    # The leading None dimension lets the same graph be fed either a
    # mini-batch or the whole dataset.
    x = tf.placeholder(tf.float32, shape=[None, 2], name='x-input')
    y_ = tf.placeholder(tf.float32, shape=[None, 1], name='y-input')

    # Forward pass. There is no bias and no activation between the layers, so
    # the model is sigmoid(x @ (w1 @ w2)): a linear classifier whose decision
    # boundary passes through the origin. It therefore cannot exactly fit the
    # x1 + x2 < 1 labels generated below; add bias terms if a better fit is
    # wanted.
    a = tf.matmul(x, w1)
    logits = tf.matmul(a, w2)
    y = tf.sigmoid(logits)

    # Binary cross entropy between the prediction y and the label y_:
    #     -mean(y_ * log(y) + (1 - y_) * log(1 - y))
    # The second term must be weighted by the label (1 - y_), not by the
    # prediction (1 - y). Values are clipped so that log(0) is never taken.
    cross_entropy = -tf.reduce_mean(
        y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))
        + (1 - y_) * tf.log(tf.clip_by_value(1 - y, 1e-10, 1.0)))
    train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

    # Alternative: TensorFlow's fused, numerically stable form of the same loss
    # is tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=logits),
    # applied to the pre-sigmoid logits. softmax_cross_entropy_with_logits
    # does not fit this model: with a single output column the softmax is
    # always 1, so the loss would be constant.

    # Synthetic dataset; the fixed seed makes the points reproducible.
    rdm = RandomState(1)
    dataset_size = 128
    X = rdm.rand(dataset_size, 2)
    print(X)
    # Label 1 for points below the line x1 + x2 = 1, otherwise 0.
    Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]
    print(Y)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Weights before training.
        print(sess.run(w1))
        print(sess.run(w2))

        STEPS = 5000
        for i in range(STEPS):
            # Walk through the dataset in consecutive mini-batches, wrapping
            # around at the end.
            start = (i * batch_size) % dataset_size
            end = min(start + batch_size, dataset_size)
            sess.run(train_step,
                     feed_dict={x: X[start:end], y_: Y[start:end]})
            if i % 1000 == 0:
                # Periodically report the loss over the full dataset.
                total_cross_entropy = sess.run(
                    cross_entropy, feed_dict={x: X, y_: Y})
                print("after %d trainingstep(s),cross entropy on all data is %g"
                      % (i, total_cross_entropy))

        # Weights after training.
        print(sess.run(w1))
        print(sess.run(w2))
    
    

    相关文章

      网友评论

          本文标题:二分类问题

          本文链接:https://www.haomeiwen.com/subject/xtugnftx.html