美文网首页
从产业实践的角度来看:Keras顺序式API已过时

从产业实践的角度来看:Keras顺序式API已过时

作者: LabVIEW_Python | 来源:发表于2021-09-15 16:48 被阅读0次

    在机器学习领域的hello world程序如下:

    import tensorflow as tf 
    from tensorflow import keras 
    from tensorflow.keras import layers 
    
    # 载入MNIST数据
    mnist = keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train/255.0, x_test/255.0 #归一化数据
    
    # 创建模型
    model = keras.Sequential(
        [
            layers.Flatten(input_shape=(28,28)),
            layers.Dense(128, activation="relu"),
            layers.Dropout(0.2),
            layers.Dense(10)
        ]
    )
    # 创建损失函数
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # 将模型、优化器、损失函数和评价指标compile到一起
    model.compile(optimizer="adam", loss=loss_fn, metrics=['accuracy'])
    # 在训练集上训练模型
    model.fit(x_train, y_train, epochs=5)
    # 在验证数据集上评估模型
    model.evaluate(x_test, y_test, verbose=2)
    
    

    其好处是简单易懂,但从产业实践的角度来看,Keras顺序式API已过时。因为:
    顺序模型只能有一个输入和一个输出,不能做图层共享,只支持顺序叠加,不支持非线性拓扑结构(例如残差连接、多分支模型)
    SOTA模型都会有分支,例如:ResNet,所以Keras Sequential API的模型创建方式,不适合产业实践。
    建议:构架模型时,使用子类API(Subclassing API )来构建,即:

    # 创建模型
    class MyModel(Model):
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv1 = Conv2D(32, 3, activation='relu')
            self.flatten = Flatten()
            self.d1 = Dense(128, activation='relu')
            self.d2 = Dense(10)
    
        def call(self, x):
            x = self.conv1(x)
            x = self.flatten(x)
            x = self.d1(x)
            return self.d2(x)
    

    相关文章

      网友评论

          本文标题:从产业实践的角度来看:Keras顺序式API已过时

          本文链接:https://www.haomeiwen.com/subject/puxlgltx.html