最近一遍学,一遍尝试进行模型的简单应用,需求驱动也是一个好的学习动力。
那么问题来了,难道我们每次应用模型,都要从头到尾训练一遍,然后再去做识别任务吗?
当然不是,所以,记录一下简单的模型保存和模型加载过程。
只是抛砖引玉,和给自己记录一下。更多使用,请参考官方文档。
以识别手写数字模型为例。
1、保存模型
在识别手写数字模型训练之后,保存代码。
model.save('test.h5')
保存后,即可看到,代码同目录下,我们保存的模型文件。
2、加载模型
在复用模型的地方:
my_model = tf.keras.models.load_model('test.h5')
即可加载我们保存过的模型。
加载后,即可使用模型的一些方法了。
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
my_model = tf.keras.models.load_model('test.h5')
my_model.summary()
3、保存模型的更多姿势
上面这种保存方式,是完整保存模式,保存的模型文件包括:
- architecture 模型的结构
- weight values 模型的权值
- training config 模型的配置:即我们通过compile编译模型的一些信息,如优化器,损失函数等
- optimizer and its state 优化器的状态信息,我们可以接着之前的训练继续训练
也可以分别保存:
3.1 保存模型结构
with open('model_config.json', 'w') as json_file:
json_file.write(json_config)
with open('model_config.json') as json_file:
json_config = json_file.read()
3.2 保存模型权重
model.save_weights('my_weights.h5')
new_model.load_weights('my_weights.h5')
3.3 保存为SavedModel格式
另外,也可以将模型保存为tensorflow
标准的SavedModel
格式,这是对tensorflow对象标准的序列化格式,是官方推荐使用,不同的是,他不是将模型保存为一个单独的文件,而是有几个文件组成。
model.save('mymodel', save_format='tf')
new_model = keras.models.load_model('mymodel')
需要注意的是:这种方式依然会保存模型的所有信息,即“网络结构、权重、配置、优化器状态”四个信息,所以可以接着训练。