Keras Exercise: Linear Regression


Author: 囧书 | Published 2018-12-30 17:23

    Keras is a high-level neural network API that can run on top of a TensorFlow backend.

    Here we create a simple dataset and fit a linear regression to it, to get a feel for how convenient Keras is.
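    In equations, the setup is roughly the following. The constants 0.5, 2 and 0.05 come from `create_data()` in the script below; the learning rate η is an assumption here: Keras 2's `SGD` optimizer (what `optimizer='sgd'` selects) defaults to η = 0.01, but the script never sets it explicitly.

```latex
\begin{aligned}
y_i &= 0.5\,x_i + 2 + \varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0,\, 0.05^2) \\
\hat{y}_i &= w\,x_i + b \\
L(w, b) &= \frac{1}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i)^2 \\
\frac{\partial L}{\partial w} &= \frac{2}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i)\,x_i, \qquad
\frac{\partial L}{\partial b} = \frac{2}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i) \\
w &\leftarrow w - \eta\,\frac{\partial L}{\partial w}, \qquad
b \leftarrow b - \eta\,\frac{\partial L}{\partial b}
\end{aligned}
```

    A single `Dense(units=1)` layer with one input computes exactly $\hat{y}_i = w x_i + b$, and `loss='mse'` is $L(w, b)$ above.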

    Regression
    import numpy as np
    import matplotlib.pyplot as plt
    from keras.models import Sequential
    from keras.layers import Dense
    
    
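    # Fix NumPy's random seed so the generated data is reproducible
    np.random.seed(1327)
    
    
    def create_data():
        # 200 evenly spaced points on [-1, 1], shuffled so that slicing
        # into train/test later gives a random split
        x = np.linspace(-1, 1, 200)
        np.random.shuffle(x)
        # Ground truth: y = 0.5 * x + 2, plus Gaussian noise with std 0.05
        y = 0.5 * x + 2 + np.random.normal(0, 0.05, (200,))
        return x, y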
    
    
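    def build_model():
        # One Dense unit with one input computes y = w * x + b,
        # i.e. a linear regression model
        model = Sequential()
        model.add(Dense(input_dim=1, units=1))
        # Mean squared error loss, optimized with stochastic gradient descent
        model.compile(loss='mse', optimizer='sgd')
        return model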
    
    
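    def train(model, x, y):
        print('Training........')
        for step in range(1001):
            # Each call runs one gradient update on the whole training set
            # (passed as a single batch) and returns the loss on that batch
            cost = model.train_on_batch(x, y)
            if step % 100 == 0:
                print('COST:', cost)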
    
    
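    def test(model, x, y):
        print('\nTesting.......')
        # Evaluate the MSE on the held-out data passed in as arguments
        cost = model.evaluate(x, y, batch_size=40)
        print('TEST COST: ', cost)
        # The Dense layer's parameters: kernel w (shape (1, 1)) and bias b (shape (1,))
        w, b = model.layers[0].get_weights()
        print('Weights=', w, 'biases=', b)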
    
    
    if __name__ == '__main__':
        x, y = create_data()
        # Uncomment to inspect the raw data
        # plt.scatter(x, y)
        # plt.show()
    
        # First 160 points for training, last 40 for testing
        x_train = x[:160]
        y_train = y[:160]
        x_test = x[160:]
        y_test = y[160:]
    
        model = build_model()
        train(model, x_train, y_train)
        test(model, x_test, y_test)
    
        # Plot the test points together with the fitted line
        y_predict = model.predict(x_test)
        plt.scatter(x_test, y_test)
        plt.plot(x_test, y_predict)
        plt.show()
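    Because each `train_on_batch(x, y)` call in the script feeds all 160 training points at once, the 1001-step loop amounts to full-batch gradient descent on the MSE. The plain-Python sketch below spells that out for one weight and one bias. It is an illustration under assumptions, not what Keras runs internally: it uses Python's `random` module instead of NumPy (so the samples differ from `create_data()`), starts from w = b = 0 instead of Keras's random kernel initialization, and assumes a learning rate of 0.01 (Keras 2's default for SGD). The helpers `make_data`, `mse` and `train_gd` are hypothetical names.

```python
import random


def make_data(n=200, seed=1327):
    # Hypothetical stand-in for create_data(), using Python's random module,
    # so the samples differ from the NumPy version.
    rng = random.Random(seed)
    x = [-1 + 2 * i / (n - 1) for i in range(n)]  # like np.linspace(-1, 1, n)
    rng.shuffle(x)
    y = [0.5 * xi + 2 + rng.gauss(0, 0.05) for xi in x]
    return x, y


def mse(w, b, x, y):
    # Mean squared error of the line y = w * x + b
    return sum((w * xi + b - yi) ** 2 for xi, yi in zip(x, y)) / len(x)


def train_gd(x, y, lr=0.01, steps=1001):
    # Full-batch gradient descent on the MSE; lr=0.01 is an assumed value.
    w, b = 0.0, 0.0
    n = len(x)
    for step in range(steps):
        err = [w * xi + b - yi for xi, yi in zip(x, y)]
        cost = sum(e * e for e in err) / n  # loss before this step's update
        grad_w = 2 * sum(e * xi for e, xi in zip(err, x)) / n
        grad_b = 2 * sum(err) / n
        w -= lr * grad_w
        b -= lr * grad_b
        if step % 100 == 0:
            print('COST:', cost)
    return w, b


x, y = make_data()
x_train, y_train = x[:160], y[:160]
x_test, y_test = x[160:], y[160:]
w, b = train_gd(x_train, y_train)
print('TEST COST:', mse(w, b, x_test, y_test))
print('Weights=', w, 'biases=', b)
```

    The printed cost should fall steadily and w, b should end near 0.5 and 2, the same pattern as the Keras run, although the exact numbers will not match.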
    
    

    Output of the Keras script above:

    Training........
    COST: 4.0496254
    COST: 0.08321373
    COST: 0.0063530593
    COST: 0.0031990125
    COST: 0.0026882463
    COST: 0.002564282
    COST: 0.0025330638
    COST: 0.0025251824
    COST: 0.002523192
    COST: 0.0025226888
    COST: 0.0025225622
    
    Testing.......
    40/40 [==============================] - 0s 1ms/step
    TEST COST:  0.0018893185770139098
    Weights= [[0.5145617]] biases= [1.9962281]
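    The learned `Weights= [[0.5145617]]` and `biases= [1.9962281]` are close to the 0.5 and 2 used in `create_data()`, and the final training cost (about 0.00252) is close to the noise variance 0.05² = 0.0025, which is roughly what is expected once the only remaining error is the noise. With a single input, linear regression also has a closed-form least-squares solution, w = cov(x, y) / var(x) and b = mean(y) - w * mean(x), which gradient descent on the training MSE should approach. A plain-Python sketch of that check follows; it regenerates similar data with Python's `random` module through a hypothetical `make_data` helper, so its values differ from the NumPy run.

```python
import random


def make_data(n=200, seed=1327):
    # Hypothetical stand-in for create_data(): y = 0.5 * x + 2 + N(0, 0.05^2)
    rng = random.Random(seed)
    x = [-1 + 2 * i / (n - 1) for i in range(n)]
    rng.shuffle(x)
    y = [0.5 * xi + 2 + rng.gauss(0, 0.05) for xi in x]
    return x, y


def least_squares(x, y):
    # Closed-form ordinary least squares for y ~ w * x + b
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    w = sxy / sxx
    b = mean_y - w * mean_x
    return w, b


x, y = make_data()
x_train, y_train = x[:160], y[:160]
w, b = least_squares(x_train, y_train)
# Training MSE of the best-fit line; expected to be near the noise variance 0.0025
train_mse = sum((w * xi + b - yi) ** 2 for xi, yi in zip(x_train, y_train)) / len(x_train)
print('w =', w, 'b =', b, 'train MSE =', train_mse)
```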
    


          Article link: https://www.haomeiwen.com/subject/pvwvlqtx.html