- 恢复完整的模型:包括模型的结构(架构)和权重,这是最简单、最常用的方法,特别是当你使用 Keras 高级 API 时。
- 恢复预训练的权重:当你已经定义好了模型结构,只是想加载别人训练好的或者之前保存的权重时使用。
下面我将详细介绍这两种情况,并提供清晰的代码示例。

(图片来源网络,侵删)
恢复完整的模型(推荐)
当你使用 model.save() 保存一个 Keras 模型时,它会创建一个包含模型架构、权重、训练配置(如优化器、损失函数)和状态(如优化器状态)的文件,恢复起来非常简单。
保存模型
我们创建并训练一个简单的模型,然后保存它。
import tensorflow as tf
import numpy as np
# 1. 准备数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 2. 创建模型
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
# 3. 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 4. 训练模型 (为了演示,只训练一个epoch)
print("训练模型...")
model.fit(x_train, y_train, epochs=1)
print("训练完成。")
# 5. 保存完整模型
# 这会创建一个包含模型所有信息的文件夹
model.save('my_complete_model')
print("模型已保存到 'my_complete_model' 目录。")
# 你可以用 model.save('my_complete_model.h5') 来保存为单个 HDF5 文件
# model.save('my_complete_model.h5')
恢复模型
我们可以在一个新的 Python 脚本或同一个脚本的后续部分中加载这个模型。
import tensorflow as tf
# 1. 恢复模型
# TensorFlow 会自动从目录中重建模型结构、权重和配置
loaded_model = tf.keras.models.load_model('my_complete_model')
# 如果你保存的是 .h5 文件,用法也一样
# loaded_model = tf.keras.models.load_model('my_complete_model.h5')
print("模型已成功加载!")
# 2. 验证模型
# 你可以像使用原始模型一样使用加载的模型
# 进行评估
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test / 255.0
loss, accuracy = loaded_model.evaluate(x_test, y_test, verbose=2)
print(f"加载的模型的准确率: {accuracy:.4f}")
# 也可以用它来做预测
predictions = loaded_model.predict(x_test[:1])
print(f"对第一张测试图片的预测结果: {np.argmax(predictions)}")
这种方法最简单,因为它保存和恢复的是“模型对象”本身,你无需重新定义模型结构。
恢复预训练的权重
当你已经有了模型的结构定义,但想加载一组预先训练好的权重时,这种方法就派上用场了。
保存权重
我们创建一个模型,训练它,然后只保存其权重。
import tensorflow as tf
import numpy as np
# 1. 准备数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 2. 创建模型 (注意:这里的定义必须和保存权重的模型结构完全一致)
model_to_save = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
# 3. 编译并训练
model_to_save.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
print("训练模型以保存权重...")
model_to_save.fit(x_train, y_train, epochs=1)
print("训练完成。")
# 4. 只保存模型的权重 (保存为单个文件)
model_to_save.save_weights('my_model_weights.h5')
print("模型权重已保存到 'my_model_weights.h5'。")
加载权重
我们创建一个新的、但结构完全相同的模型,然后将之前保存的权重加载进去。
import tensorflow as tf
import numpy as np
# 1. 创建一个**新的**模型,其结构必须与保存权重的模型完全相同
model_to_load = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
# 2. 编译模型 (加载权重后通常需要重新编译,特别是如果你要继续训练)
model_to_load.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 3. 加载权重
# 在加载权重之前,模型必须被构建好(即至少运行一次,以创建权重变量)
# compile() 通常就足够了
model_to_load.load_weights('my_model_weights.h5')
print("权重已成功加载到新模型中!")
# 4. 验证模型
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test / 255.0
loss, accuracy = model_to_load.evaluate(x_test, y_test, verbose=2)
print(f"加载权重后的模型准确率: {accuracy:.4f}")
关键点:
- 结构必须一致:
model_to_load的层数量、层类型、层名称(如果未指定,Keras会自动生成)和输入/输出形状必须与保存权重的模型完全相同,否则,load_weights会失败。 - 先构建,后加载:在调用
load_weights之前,模型必须已经被实例化并构建好(通过compile()或运行一次model.predict())。
进阶:TensorFlow 2.x 的 SavedModel 格式
model.save() 默认使用的是 TensorFlow 的 SavedModel 格式,这是一个更通用、更强大的格式,不仅包含 Keras 模型,还包含 TensorFlow 的计算图。
- 保存:
model.save('path_to_dir') - 加载:
tf.keras.models.load_model('path_to_dir')
这种格式是 TensorFlow Serving 和 TensorFlow.js 等工具部署模型时的标准格式。
总结与对比
| 特性 | model.save() / load_model() (恢复完整模型) |
model.save_weights() / load_weights() (恢复权重) |
|---|---|---|
| 模型架构 + 权重 + 优化器状态 + 训练配置 | 只有模型的权重 | |
| 使用场景 | 保存和恢复整个训练流程,用于模型部署、迁移学习或中断后继续训练 | 微调预训练模型、在不同实验间复用权重 |
| 模型要求 | 无需重新定义模型结构 | 必须重新定义一个结构完全相同的新模型 |
| 便捷性 | 非常方便,一步搞定 | 需要手动定义模型结构,稍显繁琐 |
| 文件格式 | 默认为目录 (SavedModel),也可为 .h5 |
单个文件,通常为 .h5 |
给你的建议:
- 对于大多数应用,使用
model.save()和tf.keras.models.load_model()是最简单、最不容易出错的选择。 - 当你进行迁移学习或使用预训练模型(如 VGG16, ResNet)时,通常使用
load_weights()或model.load_weights()来加载预训练层的权重,然后只训练你自己的顶层。
