美文网首页
第八章 Keras高层接口

第八章 Keras高层接口

作者: 晨光523152 | 来源:发表于2020-02-24 22:33 被阅读0次

8.3 模型保存与加载

在 Keras 中,有三种常用的模型保存与加载方法

8.3.1 张量方式

网络的状态主要体现在网络的结构以及网络层内部张量参数上,因此在拥有网络结构的源文件的条件下,直接保存网络张量参数到文件上是最轻量级的一种方式

Model.save_weights(path)

将当前的网络参数保存到 path 文件上。

network.save_weights('weights.ckpt')

这种保存与加载网络的方式最为轻量,文件中保存的仅仅是参数的张量的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能恢复网络状态
创建完一个名为 network 的网络之后,使用

network.load_weights('weights.ckpt')

加载参数。

8.3.2 网络方式

Model.save(path)

可以将模型的结构以及模型的参数保存到一个 path 文件上,在不需要网络源文件的条件下,通过

keras.models.load_model(path) 

即可恢复网络结构和网络参数。

8.3.3 SaveModel 方式

当需要将模型部署到其他平台时,采用 TensorFlow 提出的 SavedModel 方式更具有平台无关性。

tf.keras.experimental.export_saved_model(network, path)

可以将模型以 SavedModel 方式保存到 path 目录中。

然后可以通过函数

tf.keras.experimental.load_from_save_model(path)

即可恢复成网络结构和参数,方便各个平台能够无缝对接训练好的网络模型。

8.4 自定义类

  • 对于需要创建自定义逻辑的网络层,可以通过自定义类来实现,需要继承自 layers.Layer 基类;

  • 对于需要创建自定义网络类,可以通过自定义类来实现,需要继承自 keras.Model 基类。

这样产生的自定义类才能够方便的利用 Layer / Model 基类提供的参数管理功能,同时也能与其他的标准网络层类交互使用。

8.4.1 自定义网络层

对于自定义的网络层,需要实现初始化

__init__

方法和前向传播逻辑

call

方法。

首先创建并继承自 Layer 基类,创建初始化方法函数,并调用母类的初始化函数。

以全连接层为例;需要设置特征的长度 inp_dim 和输出特征的长度 outp_dim,并通过 self.add_variable(name,shape) 创建 shape 大小,名字为 name 的张量,并设置为需要优化:

class MyDense(layers.Layer):
    def __init__(self, inp_dim, outp_dim):
        super(MyDense, self).__init__()
        self.kernel = self.add_variable('w',shape = [inp_dim, outp_dim],
                                       trainable = True)
        
    def call(self,inputs):
        out = inputs @ self.kernel
        out = tf.nn.relu(out)
        return out

8.4.2 自定义网络

1 . 自定义的类可以和其他标准类一样,通过 Sequential 容器方便地包裹成一个网络模型:

network = Sequential([
    MyDense(784, 256),
    MyDense(256, 128),
    MyDense(128, 64),
    MyDense(64, 32),
    MyDense(32, 10)
])
network.build(input_shape = (None, 28 * 28))
network.summary()
  1. 可以继承基类来实现任意逻辑的自定义网络类。
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = MyDense(28*28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)
    def call(self,inputs):
        x = self.fc1(inputs)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.fc5(x)
        
        return x

model = MyModel()
model.build(input_shape = (None, 28 * 28))
model.summary()

第二种方法比第一种方法灵活,第二种的前向逻辑可以任意定制(而第一种是依次调用每个网络层的前向传播函数)。


个人写法:可以不写个 Class,而用 tf.keras.layers.Input,这样写,也能任意定制前向逻辑。


8.5 模型乐园

对于常用的网络模型,如 ResNet,VGG 等,不需要手动创建网络,可以直接从 keras.appication 子模块下一行代码即可创建并使用这些经典模型,还可以通过设置 weights 参数加载预训练的网络参数。

8.5.1 加载模型

以 ResNet50 迁移学习为例,一般将 ResNet50 去掉最后一层后的网络作为新任务的特征提取子网络。

resnet = keras.applications.ResNet50(weights='imagenet',include_top=False)
resnet.summary()

参考资料:https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book

相关文章

  • TensorFlow深度学习-第八章

    Char8-Keras高层接口 第八章中讲解的是高层接口Keras的使用。Keras的几个特点 Python语言开...

  • PyTorch&TensorFlow对比

    同类框架 Googletheano ->TensorFlow(高层接口Keras) FacebookCaffe -...

  • 第八章 Keras高层接口

    8.3 模型保存与加载 在 Keras 中,有三种常用的模型保存与加载方法 8.3.1 张量方式 网络的状态主要体...

  • tf2.0学习(六)——过拟合

    前边介绍了TensorFlow的基本操作和Keras的高层接口:tf2.0学习(一)——基础知识[https://...

  • Keras学习记录

    Keras学习笔记 keras.io keras.io-zh keras-cn Keras是一个高层神经网络API...

  • Tensorflow-keras与keras的区别

    keras的背景 keras是开源高层深度学习API。所谓“高层”,是相对于“底层”运算而言(例如add,mat...

  • Keras调研

    Keras调研 关于Keras Keras基于Python编写,是一个高层神经网络API,基于TensorFlow...

  • 2018-12-05--Keras

    Keras:基于Theano和TensorFlow的深度学习库 Keras是一个高层神经网络API,Keras由纯...

  • Keras开发概述

    什么是Keras? Keras是一个高层神经网络API,Keras由纯Python编写而成并基Tensorflow...

  • tf2.0学习(五)——Keras高层接口

    前边介绍了TensorFlow的基本操作:tf2.0学习(一)——基础知识[https://www.jianshu...

网友评论

      本文标题:第八章 Keras高层接口

      本文链接:https://www.haomeiwen.com/subject/dbzvqhtx.html