美文网首页ML/DL深度学习
Keras Flatten的input_shape问题

Keras Flatten的input_shape问题

作者: Aspirinrin | 来源:发表于2018-01-31 10:19 被阅读3290次

    在fine tune Keras Applications中给出的分类CNN Model的时候,如果在Model的top层之上加入Flatten层就会出现错误。可能的报错信息类似下面的内容:

    $ python3 ./train.py
    Using TensorFlow backend.
    Found 60000 images belonging to 200 classes.
    Found 20000 images belonging to 200 classes.
    # 略过一些信息...
    Creating TensorFlow device (/device:GPU:0) ->
    (device: 0, name: GeForce GTX 1080, pci bus id: 0000:02:00.0, compute capability: 6.1)
    
    # ↓↓↓ 错误出现 ↓↓↓
    Traceback (most recent call last):
      File "./train.py", line 51, in <module>
        x = Flatten()(x)
      File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 636, in __call__
        output_shape = self.compute_output_shape(input_shape)
      File "/usr/local/lib/python3.5/dist-packages/keras/layers/core.py", line 490, in
        compute_output_shape
        '(got ' + str(input_shape[1:]) + '. '
    ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536).
    Make sure to pass a complete "input_shape" or "batch_input_shape" argument
    to the first layer in your model.
    # ↑↑↑ 错误结束 ↑↑↑
    

    出错的代码行是x = Flatten()(x),错误提示为ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 1536). Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model.

    Flatten()(x)希望参数拥有确定shape属性,实际得到的参数xshape属性是(None, None, 1536),很明显不符合要求。同时,错误提示信息中也给出了修正错误的方法Make sure to pass a complete "input_shape" or "batch_input_shape" argument to the first layer in your model。即,在Model的第一层给出确定的input_shapebatch_input_shape。那么,如何在Keras中解决该问题呢?

    以Keras Applications中的VGG16为例,我们只需要在其初始化的时候,给出具体的input_shape就可以了。例如,Keras给出的VGG16模型输入层图像尺寸是(224, 224)的,所以如果使用TensorFlow的channels_last数据格式,则初始化代码为:

    vgg16 = keras.applications.vgg16.VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
    x = vgg16.output
    x = Flatten()(x)
    ...
    

    注意,因为要fine tune模型,对模型分类的种类和类别数进行重新定义,所以include_top=False,这样返回的模型不包括VGG16的全连接层和输出层。

    参考:

    1. Unable to fine tune Keras vgg16 model - input shape issue
    2. Keras Applications

    相关文章

      网友评论

        本文标题:Keras Flatten的input_shape问题

        本文链接:https://www.haomeiwen.com/subject/xmlkzxtx.html