TF2 Keras (1): Building a Model with Sequential

作者 (Author): 数科每日 | Published 2021-01-06 08:55

    This article is a set of study notes on the official Keras documentation.


    Keras supports two ways of defining a Model: the Sequential model and the Functional API. This article focuses on the Sequential style. Its advantage is simplicity: a model is defined from a plain list of layers. Convenient as it is, Sequential also has limitations.

    Scenarios where a Sequential model is not appropriate:
    • The model has multiple inputs or multiple outputs
    • Any single layer has multiple inputs or multiple outputs
    • Layers are shared between parts of the model (layer sharing)
    • The topology is non-linear (e.g. ResNet, multi-branch models)
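For contrast, here is a minimal sketch (names are illustrative) of a two-input model, which the first bullet rules out for Sequential; it requires the Functional API:

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Two inputs merged into one output: not expressible as a Sequential model.
input_a = keras.Input(shape=(8,), name="input_a")
input_b = keras.Input(shape=(4,), name="input_b")
concat = layers.Concatenate()([input_a, input_b])
hidden = layers.Dense(16, activation="relu")(concat)
output = layers.Dense(1)(hidden)
model = keras.Model(inputs=[input_a, input_b], outputs=output)

# Call the model with a list of two inputs
y = model([tf.ones((2, 8)), tf.ones((2, 4))])
print(y.shape)  # (2, 1)
```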
    Notes on building a model with Sequential:
    • The activation argument sets each layer's activation function
    • Each layer can be given a name (optional)
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    
    # Define Sequential model with 3 layers
    model = keras.Sequential(
        [
            layers.Dense(2, activation="relu", name="layer1"),
            layers.Dense(3, activation="relu", name="layer2"),
            layers.Dense(4, name="layer3"),
        ]
    )
    # Call model on a test input
    x = tf.ones((3, 3))
    y = model(x)
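The Sequential model above is equivalent to chaining the three layers by hand, which makes explicit what Sequential does on each call:

```python
import tensorflow as tf
from tensorflow.keras import layers

layer1 = layers.Dense(2, activation="relu", name="layer1")
layer2 = layers.Dense(3, activation="relu", name="layer2")
layer3 = layers.Dense(4, name="layer3")

# Passing the input through each layer in order
x = tf.ones((3, 3))
y = layer3(layer2(layer1(x)))
print(y.shape)  # (3, 4)
```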
    
    Adding / removing layers

    Layers can be appended to or removed from a Sequential model with the add() and pop() methods, much like a Python list.

    add

    model = keras.Sequential()
    model.add(layers.Dense(2, activation="relu"))
    model.add(layers.Dense(3, activation="relu"))
    model.add(layers.Dense(4))
    

    pop

    model.pop()
    print(len(model.layers))  # 2
    
    Debugging with add() + summary()

    summary() prints the structure of the whole model, which is handy for debugging while stacking layers incrementally.

    model = keras.Sequential()
    model.add(keras.Input(shape=(250, 250, 3)))  # 250x250 RGB images
    model.add(layers.Conv2D(32, 5, strides=2, activation="relu"))
    model.add(layers.Conv2D(32, 3, activation="relu"))
    model.add(layers.MaxPooling2D(3))
    
    # Can you guess what the current output shape is at this point? Probably not.
    # Let's just print it:
    model.summary()
    
    # The answer was: (40, 40, 32), so we can keep downsampling...
    
    model.add(layers.Conv2D(32, 3, activation="relu"))
    model.add(layers.Conv2D(32, 3, activation="relu"))
    model.add(layers.MaxPooling2D(3))
    model.add(layers.Conv2D(32, 3, activation="relu"))
    model.add(layers.Conv2D(32, 3, activation="relu"))
    model.add(layers.MaxPooling2D(2))
    
    # And now?
    model.summary()
    
    # Now that we have 4x4 feature maps, time to apply global max pooling.
    model.add(layers.GlobalMaxPooling2D())
    
    # Finally, we add a classification layer.
    model.add(layers.Dense(10))
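As a check on the shapes claimed in the comments above, the same stack can be rebuilt in one go and called on a dummy input (a sketch):

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    keras.Input(shape=(250, 250, 3)),
    layers.Conv2D(32, 5, strides=2, activation="relu"),
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPooling2D(3),          # feature maps are (40, 40, 32) here
    layers.Conv2D(32, 3, activation="relu"),
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPooling2D(3),
    layers.Conv2D(32, 3, activation="relu"),
    layers.Conv2D(32, 3, activation="relu"),
    layers.MaxPooling2D(2),          # feature maps are (4, 4, 32) here
    layers.GlobalMaxPooling2D(),
    layers.Dense(10),
])

y = model(tf.ones((1, 250, 250, 3)))
print(y.shape)  # (1, 10)
```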
    
    Feature extraction

    The output of any individual layer can be extracted on its own.

    initial_model = keras.Sequential(
        [
            keras.Input(shape=(250, 250, 3)),
            layers.Conv2D(32, 5, strides=2, activation="relu"),
            layers.Conv2D(32, 3, activation="relu", name="my_intermediate_layer"),
            layers.Conv2D(32, 3, activation="relu"),
        ]
    )
    feature_extractor = keras.Model(
        inputs=initial_model.inputs,
        outputs=initial_model.get_layer(name="my_intermediate_layer").output,
    )
    # Call feature extractor on test input.
    x = tf.ones((1, 250, 250, 3))
    features = feature_extractor(x)
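A variation of the same idea (a sketch): build an extractor that returns the output of every layer at once, by passing a list of outputs to keras.Model.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

initial_model = keras.Sequential(
    [
        keras.Input(shape=(250, 250, 3)),
        layers.Conv2D(32, 5, strides=2, activation="relu"),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Conv2D(32, 3, activation="relu"),
    ]
)
# One output per layer: the extractor returns a list of feature maps
feature_extractor = keras.Model(
    inputs=initial_model.inputs,
    outputs=[layer.output for layer in initial_model.layers],
)
features = feature_extractor(tf.ones((1, 250, 250, 3)))
print([f.shape for f in features])
```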
    
    Transfer learning with Sequential

    Two common approaches to transfer learning:

    1. Freeze every layer except the last (output) layer
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        layers.Dense(32, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(32, activation='relu'),
        layers.Dense(10),
    ])
    
    # Presumably you would want to first load pre-trained weights.
    model.load_weights(...)
    
    # Freeze all layers except the last one.
    for layer in model.layers[:-1]:
      layer.trainable = False
    
    # Recompile and train (this will only update the weights of the last layer).
    model.compile(...)
    model.fit(...)
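To confirm the freeze worked as intended, one can inspect trainable_weights afterwards (a sketch; the `...` placeholders above are left to the actual training setup):

```python
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    keras.Input(shape=(784,)),
    layers.Dense(32, activation="relu"),
    layers.Dense(32, activation="relu"),
    layers.Dense(32, activation="relu"),
    layers.Dense(10),
])

# Freeze all layers except the last one
for layer in model.layers[:-1]:
    layer.trainable = False

# Only the last Dense layer's kernel and bias remain trainable
print(len(model.trainable_weights))  # 2
```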
    
    2. Stack new layers on top of an existing model
    # Load a convolutional base with pre-trained weights
    base_model = keras.applications.Xception(
        weights='imagenet',
        include_top=False,
        pooling='avg')
    
    # Freeze the base model
    base_model.trainable = False
    
    # Use a Sequential model to add a trainable classifier on top
    model = keras.Sequential([
        base_model,
        layers.Dense(1000),
    ])
    
    # Compile & train
    model.compile(...)
    model.fit(...)
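The same trainable_weights check applies here; the sketch below uses weights=None so it runs without downloading the ImageNet weights (in practice you would keep weights='imagenet' as above):

```python
from tensorflow import keras
from tensorflow.keras import layers

base_model = keras.applications.Xception(
    weights=None,  # weights='imagenet' in real use; None avoids the download
    include_top=False,
    pooling="avg")
base_model.trainable = False

model = keras.Sequential([
    base_model,
    layers.Dense(1000),
])

# Only the new classifier head (Dense kernel + bias) is trainable
print(len(model.trainable_weights))  # 2
```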
    


          Original link: https://www.haomeiwen.com/subject/tbqgoktx.html