8.3 模型保存与加载
在 Keras 中,有三种常用的模型保存与加载方法
8.3.1 张量方式
网络的状态主要体现在网络的结构以及网络层内部张量参数上,因此在拥有网络结构的源文件的条件下,直接保存网络张量参数到文件上是最轻量级的一种方式
Model.save_weights(path)
将当前的网络参数保存到 path 文件上。
network.save_weights('weights.ckpt')
这种保存与加载网络的方式最为轻量,文件中保存的仅仅是参数的张量的数值,并没有其它额外的结构参数。但是它需要使用相同的网络结构才能恢复网络状态。
创建完一个名为 network 的网络之后,使用
network.load_weights('weights.ckpt')
加载参数。
8.3.2 网络方式
Model.save(path)
可以将模型的结构以及模型的参数保存到一个 path 文件上,在不需要网络源文件的条件下,通过
keras.models.load_model(path)
即可恢复网络结构和网络参数。
8.3.3 SaveModel 方式
当需要将模型部署到其他平台时,采用 TensorFlow 提出的 SavedModel 方式更具有平台无关性。
tf.keras.experimental.export_saved_model(network, path)
可以将模型以 SavedModel 方式保存到 path 目录中。
然后可以通过函数
tf.keras.experimental.load_from_save_model(path)
即可恢复成网络结构和参数,方便各个平台能够无缝对接训练好的网络模型。
8.4 自定义类
-
对于需要创建自定义逻辑的网络层,可以通过自定义类来实现,需要继承自 layers.Layer 基类;
-
对于需要创建自定义网络类,可以通过自定义类来实现,需要继承自 keras.Model 基类。
这样产生的自定义类才能够方便的利用 Layer / Model 基类提供的参数管理功能,同时也能与其他的标准网络层类交互使用。
8.4.1 自定义网络层
对于自定义的网络层,需要实现初始化
__init__
方法和前向传播逻辑
call
方法。
首先创建并继承自 Layer 基类,创建初始化方法函数,并调用母类的初始化函数。
以全连接层为例;需要设置特征的长度 inp_dim 和输出特征的长度 outp_dim,并通过 self.add_variable(name,shape) 创建 shape 大小,名字为 name 的张量,并设置为需要优化:
class MyDense(layers.Layer):
def __init__(self, inp_dim, outp_dim):
super(MyDense, self).__init__()
self.kernel = self.add_variable('w',shape = [inp_dim, outp_dim],
trainable = True)
def call(self,inputs):
out = inputs @ self.kernel
out = tf.nn.relu(out)
return out
8.4.2 自定义网络
1 . 自定义的类可以和其他标准类一样,通过 Sequential 容器方便地包裹成一个网络模型:
network = Sequential([
MyDense(784, 256),
MyDense(256, 128),
MyDense(128, 64),
MyDense(64, 32),
MyDense(32, 10)
])
network.build(input_shape = (None, 28 * 28))
network.summary()
- 可以继承基类来实现任意逻辑的自定义网络类。
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = MyDense(28*28, 256)
self.fc2 = MyDense(256, 128)
self.fc3 = MyDense(128, 64)
self.fc4 = MyDense(64, 32)
self.fc5 = MyDense(32, 10)
def call(self,inputs):
x = self.fc1(inputs)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
x = self.fc5(x)
return x
model = MyModel()
model.build(input_shape = (None, 28 * 28))
model.summary()
第二种方法比第一种方法灵活,第二种的前向逻辑可以任意定制(而第一种是依次调用每个网络层的前向传播函数)。
个人写法:可以不写个 Class,而用 tf.keras.layers.Input,这样写,也能任意定制前向逻辑。
8.5 模型乐园
对于常用的网络模型,如 ResNet,VGG 等,不需要手动创建网络,可以直接从 keras.appication 子模块下一行代码即可创建并使用这些经典模型,还可以通过设置 weights 参数加载预训练的网络参数。
8.5.1 加载模型
以 ResNet50 迁移学习为例,一般将 ResNet50 去掉最后一层后的网络作为新任务的特征提取子网络。
resnet = keras.applications.ResNet50(weights='imagenet',include_top=False)
resnet.summary()
参考资料:https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book
网友评论