The R Interface to Keras

Author: Liam_ml | Published 2018-02-17 18:51

    Getting Started

    First, install keras from GitHub:

    devtools::install_github("rstudio/keras")
    
    

    The Keras R interface uses the TensorFlow backend engine by default. To install both the core Keras library and the TensorFlow backend, use the install_keras() function:

    library(keras)
    install_keras()
    
    

    This will provide you with default CPU-based installations of Keras and TensorFlow. If you want a more customized installation, see the documentation for install_keras().
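
    For example, a more customized install might request the GPU build of TensorFlow. A minimal sketch, assuming a machine with a compatible NVIDIA GPU (the tensorflow argument is described in the install_keras() help):

    # assumes GPU prerequisites (CUDA/cuDNN) are already set up
    install_keras(tensorflow = "gpu")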

    MNIST

    We can learn the basics of Keras by walking through a simple example: recognizing handwritten digits from the MNIST dataset. MNIST consists of 28 x 28 grayscale images of handwritten digits like these:


    [Figure: sample 28 x 28 grayscale MNIST digit images]

    Preparing the Data

    The MNIST dataset is included with the keras package and can be loaded with dataset_mnist():

    library(keras)
    mnist <- dataset_mnist()
    x_train <- mnist$train$x
    y_train <- mnist$train$y
    x_test <- mnist$test$x
    y_test <- mnist$test$y
    
    

    The x data is a 3-d array (images, width, height) of grayscale values. To prepare the data for training, we convert the 3-d arrays into matrices by reshaping width and height into a single dimension (28 x 28 images are flattened into vectors of length 784). Then we convert the grayscale values from integers ranging between 0 and 255 into floating-point values ranging between 0 and 1:

    # reshape
    x_train <- array_reshape(x_train, c(nrow(x_train), 784))
    x_test <- array_reshape(x_test, c(nrow(x_test), 784))
    # rescale
    x_train <- x_train / 255
    x_test <- x_test / 255
    

    Note that we use the array_reshape() function rather than the dim<-() function to reshape the array. This is so that the data is re-interpreted using row-major semantics (as opposed to R's default column-major semantics), which is in turn compatible with the way the numerical libraries called by Keras interpret array dimensions.
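
    A minimal sketch of the difference, using a hypothetical 2 x 2 matrix rather than the MNIST data:

    m <- matrix(1:4, nrow = 2, byrow = TRUE)  # rows: (1 2) and (3 4)
    array_reshape(m, c(1, 4))  # row-major: fills by row -> 1 2 3 4
    dim(m) <- c(1, 4)          # column-major: reuses R's internal storage order
    m                          # -> 1 3 2 4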

    The y data is an integer vector with values ranging from 0 to 9. To prepare this data for training, we one-hot encode the vector into a binary class matrix using the Keras to_categorical() function:

    y_train <- to_categorical(y_train, 10)
    y_test <- to_categorical(y_test, 10)
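
    To see what the encoding looks like, here is a small sketch on toy labels (not the MNIST data): each integer becomes a row with a 1 in the corresponding column.

    to_categorical(c(0, 1, 2), 3)
    #      [,1] [,2] [,3]
    # [1,]    1    0    0
    # [2,]    0    1    0
    # [3,]    0    0    1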
    

    Defining the Model

    The core data structure of Keras is a model, a way to organize layers. The simplest type of model is the Sequential model, a linear stack of layers.

    We begin by creating a sequential model and then adding layers using the pipe (%>%) operator:

    model <- keras_model_sequential() 
    model %>% 
      layer_dense(units = 256, activation = 'relu', input_shape = c(784)) %>% 
      layer_dropout(rate = 0.4) %>% 
      layer_dense(units = 128, activation = 'relu') %>%
      layer_dropout(rate = 0.3) %>%
      layer_dense(units = 10, activation = 'softmax')
    
    

    The input_shape argument to the first layer specifies the shape of the input data (a numeric vector of length 784 representing a grayscale image). The final layer outputs a numeric vector of length 10 (the probability of each digit) using a softmax activation function.

    Use the summary() function to print the details of the model:

    summary(model)
    _________________________________________________________________________
    Layer (type)                    Output Shape                  Param #    
    =========================================================================
    dense_1 (Dense)                 (None, 256)                   200960     
    _________________________________________________________________________
    dropout_1 (Dropout)             (None, 256)                   0          
    _________________________________________________________________________
    dense_2 (Dense)                 (None, 128)                   32896      
    _________________________________________________________________________
    dropout_2 (Dropout)             (None, 128)                   0          
    _________________________________________________________________________
    dense_3 (Dense)                 (None, 10)                    1290       
    =========================================================================
    Total params: 235,146
    Trainable params: 235,146
    Non-trainable params: 0
    _________________________________________________________________________
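
    The parameter counts follow from weights plus biases at each dense layer: dense_1 has 784 × 256 + 256 = 200,960 parameters, dense_2 has 256 × 128 + 128 = 32,896, and dense_3 has 128 × 10 + 10 = 1,290. The dropout layers add no parameters.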
    

    Next, compile the model with an appropriate loss function, optimizer, and metrics:

    model %>% compile(
      loss = 'categorical_crossentropy',
      optimizer = optimizer_rmsprop(),
      metrics = c('accuracy')
    )
    
    

    Training and Evaluating the Model

    Use the fit() function to train the model for 30 epochs using batches of 128 images:

    history <- model %>% fit(
      x_train, y_train, 
      epochs = 30, batch_size = 128, 
      validation_split = 0.2
    )
    Train on 48000 samples, validate on 12000 samples
    Epoch 1/30
    48000/48000 [==============================] - 4s 91us/step - loss: 0.4249 - acc: 0.8717 - val_loss: 0.1666 - val_acc: 0.9490
    Epoch 2/30
    48000/48000 [==============================] - 4s 81us/step - loss: 0.2023 - acc: 0.9399 - val_loss: 0.1278 - val_acc: 0.9634
    Epoch 3/30
    48000/48000 [==============================] - 4s 79us/step - loss: 0.1552 - acc: 0.9534 - val_loss: 0.1148 - val_acc: 0.9681
    Epoch 4/30
    48000/48000 [==============================] - 4s 81us/step - loss: 0.1320 - acc: 0.9609 - val_loss: 0.1008 - val_acc: 0.9716
    Epoch 5/30
    48000/48000 [==============================] - 4s 76us/step - loss: 0.1148 - acc: 0.9658 - val_loss: 0.0933 - val_acc: 0.9738
    Epoch 6/30
    48000/48000 [==============================] - 4s 77us/step - loss: 0.1048 - acc: 0.9684 - val_loss: 0.0914 - val_acc: 0.9752
    Epoch 7/30
    48000/48000 [==============================] - 4s 78us/step - loss: 0.0979 - acc: 0.9715 - val_loss: 0.0901 - val_acc: 0.9752
    Epoch 8/30
    48000/48000 [==============================] - 4s 77us/step - loss: 0.0887 - acc: 0.9745 - val_loss: 0.0919 - val_acc: 0.9758
    Epoch 9/30
    48000/48000 [==============================] - 4s 77us/step - loss: 0.0858 - acc: 0.9748 - val_loss: 0.0904 - val_acc: 0.9779
    Epoch 10/30
    48000/48000 [==============================] - 4s 78us/step - loss: 0.0807 - acc: 0.9769 - val_loss: 0.0903 - val_acc: 0.9783
    Epoch 11/30
    48000/48000 [==============================] - 4s 77us/step - loss: 0.0781 - acc: 0.9781 - val_loss: 0.0956 - val_acc: 0.9771
    Epoch 12/30
    48000/48000 [==============================] - 4s 78us/step - loss: 0.0768 - acc: 0.9788 - val_loss: 0.0917 - val_acc: 0.9787
    Epoch 13/30
    48000/48000 [==============================] - 4s 77us/step - loss: 0.0706 - acc: 0.9794 - val_loss: 0.0909 - val_acc: 0.9784
    Epoch 14/30
    48000/48000 [==============================] - 4s 85us/step - loss: 0.0684 - acc: 0.9804 - val_loss: 0.0933 - val_acc: 0.9787
    Epoch 15/30
    48000/48000 [==============================] - 4s 84us/step - loss: 0.0682 - acc: 0.9810 - val_loss: 0.1013 - val_acc: 0.9785
    Epoch 16/30
    48000/48000 [==============================] - 4s 82us/step - loss: 0.0647 - acc: 0.9812 - val_loss: 0.0951 - val_acc: 0.9795
    Epoch 17/30
    48000/48000 [==============================] - 4s 78us/step - loss: 0.0627 - acc: 0.9829 - val_loss: 0.1004 - val_acc: 0.9792
    Epoch 18/30
    48000/48000 [==============================] - 4s 79us/step - loss: 0.0671 - acc: 0.9823 - val_loss: 0.0959 - val_acc: 0.9803
    Epoch 19/30
    48000/48000 [==============================] - 4s 77us/step - loss: 0.0602 - acc: 0.9831 - val_loss: 0.0976 - val_acc: 0.9797
    Epoch 20/30
    48000/48000 [==============================] - 4s 76us/step - loss: 0.0593 - acc: 0.9835 - val_loss: 0.1051 - val_acc: 0.9786
    Epoch 21/30
    48000/48000 [==============================] - 4s 78us/step - loss: 0.0592 - acc: 0.9840 - val_loss: 0.1008 - val_acc: 0.9799
    Epoch 22/30
    48000/48000 [==============================] - 4s 76us/step - loss: 0.0561 - acc: 0.9846 - val_loss: 0.1023 - val_acc: 0.9800
    Epoch 23/30
    48000/48000 [==============================] - 4s 78us/step - loss: 0.0592 - acc: 0.9844 - val_loss: 0.1100 - val_acc: 0.9787
    Epoch 24/30
    48000/48000 [==============================] - 4s 83us/step - loss: 0.0566 - acc: 0.9848 - val_loss: 0.1048 - val_acc: 0.9790
    Epoch 25/30
    48000/48000 [==============================] - 4s 79us/step - loss: 0.0531 - acc: 0.9852 - val_loss: 0.1091 - val_acc: 0.9802
    Epoch 26/30
    48000/48000 [==============================] - 4s 79us/step - loss: 0.0570 - acc: 0.9850 - val_loss: 0.1055 - val_acc: 0.9803
    Epoch 27/30
    48000/48000 [==============================] - 4s 84us/step - loss: 0.0515 - acc: 0.9868 - val_loss: 0.1114 - val_acc: 0.9798
    Epoch 28/30
    48000/48000 [==============================] - 4s 78us/step - loss: 0.0532 - acc: 0.9861 - val_loss: 0.1148 - val_acc: 0.9799
    Epoch 29/30
    48000/48000 [==============================] - 4s 76us/step - loss: 0.0532 - acc: 0.9860 - val_loss: 0.1105 - val_acc: 0.9796
    Epoch 30/30
    48000/48000 [==============================] - 4s 77us/step - loss: 0.0519 - acc: 0.9869 - val_loss: 0.1179 - val_acc: 0.9796
    
    
    [Figure: training and validation loss and accuracy by epoch]
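
    The history object returned by fit() records the loss and metric values for each epoch, and keras provides a plot() method for it, so a figure like the one above can be drawn with:

    plot(history)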

    Evaluate the model's performance on the test set:

    model %>% evaluate(x_test, y_test)
    10000/10000 [==============================] - 1s 55us/step
    $loss
    [1] 0.1040304
    
    $acc
    [1] 0.9815
    

    Generate predictions on new data:

    model %>% predict_classes(x_test)
    [1] 7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3
      [34] 4 7 2 7 1 2 1 1 7 4 2 3 5 1 2 4 4 6 3 5 5 6 0 4 1 9 5 7 8 9 3 7 4
      [67] 6 4 3 0 7 0 2 9 1 7 3 2 9 7 7 6 2 7 8 4 7 3 6 1 3 6 9 3 1 4 1 7 6
     [100] 9 6 0 5 4 9 9 2 1 9 4 8 7 3 9 7 4 4 4 9 2 5 4 7 6 7 9 0 5 8 5 6 6
     [133] 5 7 8 1 0 1 6 4 6 7 3 1 7 1 8 2 0 2 9 9 5 5 1 5 6 0 3 4 4 6 5 4 6
     [166] 5 4 5 1 4 4 7 2 3 2 7 1 8 1 8 1 8 5 0 8 9 2 5 0 1 1 1 0 9 0 3 1 6
     [199] 4 2 3 6 1 1 1 3 9 5 2 9 4 5 9 3 9 0 3 6 5 5 7 2 2 7 1 2 8 4 1 7 3
     [232] 3 8 8 7 9 2 2 4 1 5 9 8 7 2 3 0 2 4 2 4 1 9 5 7 7 2 8 2 0 8 5 7 7
     [265] 9 1 8 1 8 0 3 0 1 9 9 4 1 8 2 1 2 9 7 5 9 2 6 4 1 5 8 2 9 2 0 4 0
     [298] 0 2 8 4 7 1 2 4 0 2 7 4 3 3 0 0 3 1 9 6 5 2 5 9 7 9 3 0 4 2 0 7 1
     [331] 1 2 1 5 3 3 9 7 8 6 5 6 1 3 8 1 0 5 1 3 1 5 5 6 1 8 5 1 7 9 4 6 2
     [364] 2 5 0 6 5 6 3 7 2 0 8 8 5 4 1 1 4 0 7 3 7 6 1 6 2 1 9 2 8 6 1 9 5
     [397] 2 5 4 4 2 8 3 8 2 4 5 0 3 1 7 7 5 7 9 7 1 9 2 1 4 2 9 2 0 4 9 1 4
     [430] 8 1 8 4 5 9 8 8 3 7 6 0 0 3 0 2 0 6 9 9 3 3 3 2 3 9 1 2 6 8 0 5 6
     [463] 6 6 3 8 8 2 7 5 8 9 6 1 8 4 1 2 5 9 1 9 7 5 4 0 8 9 9 1 0 5 2 3 7
     [496] 0 9 4 0 6 3 9 5 2 1 3 1 3 6 5 7 4 2 2 6 3 2 6 5 4 8 9 7 1 3 0 3 8
     [529] 3 1 9 3 4 4 6 4 2 1 8 2 5 4 8 8 4 0 0 2 3 2 7 7 0 8 7 4 4 7 9 6 9
     [562] 0 9 8 0 4 6 0 6 3 5 4 8 3 3 9 3 3 3 7 8 0 2 2 1 7 0 6 5 4 3 8 0 9
     [595] 6 3 8 0 9 9 6 8 6 8 5 7 8 6 0 2 4 0 2 2 3 1 9 7 5 8 0 8 4 6 2 6 7
     [628] 9 3 2 9 8 2 2 9 2 7 3 5 9 1 8 0 2 0 5 2 1 3 7 6 7 1 2 5 8 0 3 7 2
     [661] 4 0 9 1 8 6 7 7 4 3 4 9 1 9 5 1 7 3 9 7 6 9 1 3 3 8 3 3 6 7 2 4 5
     [694] 8 5 1 1 4 4 3 1 0 7 7 0 7 9 4 4 8 5 5 4 0 8 2 1 0 8 4 5 0 4 0 6 1
     [727] 9 3 2 6 7 2 6 9 3 1 4 6 2 5 9 2 0 6 2 1 7 3 4 1 0 5 4 3 1 1 7 4 9
     [760] 9 4 8 4 0 2 4 5 1 1 6 4 7 1 9 4 2 4 1 5 5 3 8 3 1 4 5 6 8 9 4 1 5
     [793] 3 8 0 3 2 5 1 2 8 3 4 4 0 8 8 3 3 1 7 3 5 9 6 3 2 6 1 3 6 0 7 2 1
     [826] 7 1 4 2 4 2 1 7 9 6 1 1 2 4 8 1 7 7 4 8 0 7 3 1 3 1 0 7 7 0 3 5 5
     [859] 2 7 6 6 9 2 8 3 5 2 2 5 6 0 8 2 9 2 8 8 8 8 7 4 9 3 0 6 6 3 2 1 3
     [892] 2 2 9 3 0 0 5 7 8 3 4 4 6 0 2 9 1 4 7 4 7 3 9 8 8 4 7 1 2 1 2 2 3
     [925] 2 3 2 3 9 1 7 4 0 3 5 5 8 6 3 2 6 7 6 6 3 2 7 9 1 1 7 5 6 4 9 5 1
     [958] 3 3 4 7 8 9 1 1 0 9 1 4 4 5 4 0 6 2 2 3 1 5 1 2 0 3 8 1 2 6 7 1 6
     [991] 2 3 9 0 1 2 2 0 8 9
     [ reached getOption("max.print") -- omitted 9000 entries ]
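
    As a quick sanity check (a minimal sketch reusing the objects loaded earlier), the predicted classes can be compared against the raw, un-encoded test labels:

    # fraction of test digits classified correctly; should match evaluate()
    preds <- model %>% predict_classes(x_test)
    mean(preds == mnist$test$y)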
    

    Pretty cool, right?
