您当前的位置:首页 > IT编程 > 深度学习
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch |

自学教程:深度学习模型复用——模型的保存与加载

51自学网 2022-12-08 20:45:41
  深度学习
这篇教程深度学习模型复用——模型的保存与加载写得很实用,希望能帮到您。

深度学习模型复用——模型的保存与加载

 

最近一遍学,一遍尝试进行模型的简单应用,需求驱动也是一个好的学习动力。

那么问题来了,难道我们每次应用模型,都要从头到尾训练一遍,然后再去做识别任务吗?

当然不是,所以,记录一下简单的模型保存模型加载过程。

只是抛砖引玉,和给自己记录一下。更多使用,请参考官方文档。

以识别手写数字模型为例。

1、保存模型

在识别手写数字模型训练之后,保存代码。

# 保存模型
model.save('test.h5') #保存到与代码文件同目录,h5:模型文件后缀名

在这里插入图片描述

保存后,即可看到,代码同目录下,我们保存的模型文件。
在这里插入图片描述

2、加载模型

在复用模型的地方:

my_model = tf.keras.models.load_model('test.h5') 

即可加载我们保存过的模型。

加载后,即可使用模型的一些方法了。

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
# 加载并预处理数据
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
# 加载模型
my_model = tf.keras.models.load_model('test.h5') 
my_model.summary()

在这里插入图片描述

在这里插入图片描述

3、保存模型的更多姿势

上面这种保存方式,是完整保存模式,保存的模型文件包括:

  • architecture 模型的结构
  • weight values 模型的权值
  • training config 模型的配置:即我们通过compile编译模型的一些信息,如优化器,损失函数等
  • optimizer and its state 优化器的状态信息,我们可以接着之前的训练继续训练

也可以分别保存:

3.1 保存模型结构

# 保存模型结构
with open('model_config.json', 'w') as json_file:
    json_file.write(json_config)


# 加载
with open('model_config.json') as json_file:
    json_config = json_file.read()

3.2 保存模型权重

# 保存权重到磁盘,注意,这里是一个 h5 文件哦!
model.save_weights('my_weights.h5')
 
# 新模型从磁盘加载权重
new_model.load_weights('my_weights.h5')

3.3 保存为SavedModel格式

另外,也可以将模型保存为tensorflow标准的SavedModel格式,这是对tensorflow对象标准的序列化格式,是官方推荐使用,不同的是,他不是将模型保存为一个单独的文件,而是有几个文件组成。

# 模型保存,注意:仅仅是多了一个save_format的参数而已
# 注意:这里的'path_to_saved_model'不再是模型名称,仅仅是一个文件夹,模型会保存在这个文件夹之下
model.save('mymodel', save_format='tf')
 
# 加载模型,通过指定存放模型的文件夹来加载
new_model = keras.models.load_model('mymodel')

需要注意的是:这种方式依然会保存模型的所有信息,即“网络结构、权重、配置、优化器状态”四个信息,所以可以接着训练


Keras CAM实现-绘制CNN每层的类激活图(CAM)
全面解决样本不均衡(Python)的方法
51自学网,即我要自学网,自学EXCEL、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。
京ICP备13026421号-1