Checkpoints (检查点)
在 TensorFlow 中,模型参数通常以 Checkpoints 的形式保存,一个 Checkpoint 文件(通常是 .ckpt 文件)并不是一个完整的、可独立运行的模型,它只包含了模型的权重(Weights)和偏置(Biases)。

(图片来源网络,侵删)
它通常会生成三个文件:
my_model.ckpt.index: 一个索引文件,用于查找变量。my_model.ckpt.data-00000-of-00001: 一个二进制文件,包含了所有变量的实际值。my_model.ckpt.meta: 一个协议缓冲区文件,包含了模型的计算图结构(Graph Definition),这个文件是可选的,但通常和 Checkpoint 一起保存,以便完整恢复模型。
tf.train.Checkpoint (推荐,现代且灵活)
这是目前 TensorFlow 2.x 中最推荐的方式,它不依赖于计算图的静态结构,而是通过一个对象的状态来追踪变量,这使得代码更简单、更健壮,尤其是在处理模型结构变化或使用 tf.function 时。
工作原理
tf.train.Checkpoint 会追踪任何 tf.Variable 对象,以及任何包含 tf.Variable 的对象(如 tf.keras.Model)。
示例代码
import tensorflow as tf
import os
# 1. 创建一个简单的模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10)
])
# 2. 创建一个检查点对象
# 我们将模型和优化器都传入,这样它们的参数和状态都会被保存
checkpoint = tf.train.Checkpoint(model=model, optimizer=tf.keras.optimizers.Adam())
# 3. 定义保存路径
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt-{epoch}")
# 4. 保存参数
# 假设模型已经训练过一段时间
print("保存模型参数...")
# save() 方法会保存所有被追踪的对象的状态
checkpoint.save(checkpoint_prefix)
print("参数已保存。")
# 5. 恢复参数
# 创建一个新的模型实例(结构必须和之前一样)
new_model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10)
])
# 创建一个新的检查点对象,指向同一个目录
new_checkpoint = tf.train.Checkpoint(model=new_model)
# 从最新的检查点恢复
# 注意:这里我们使用 checkpoint.restore,而不是 new_checkpoint.restore
# 因为 checkpoint 对象在保存时记录了所有被追踪对象的路径
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
print("尝试恢复参数...")
# status.assert_existing_objects_matched() 是一个好习惯,确保所有变量都被正确恢复
status.assert_existing_objects_matched()
print("参数恢复成功!")
# 验证恢复后的模型权重是否和原始模型一样
print("原始模型第一层权重:", model.layers[0].get_weights()[0][0, 0])
print("恢复后模型第一层权重:", new_model.layers[0].get_weights()[0][0, 0])
# 你会发现它们的值是相同的
tf.keras.callbacks.ModelCheckpoint (在训练中自动保存)
这是在训练过程中最常用的方法,你可以将它作为回调函数传递给 model.fit(),模型会在每个 epoch 结束时自动保存参数。
示例代码
import tensorflow as tf
import numpy as np
# 1. 创建模型和编译
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 2. 准备虚拟数据
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
# 3. 定义 ModelCheckpoint 回调
# filepath: 保存文件的路径
# save_weights_only=True: 只保存权重,这是最轻量的方式
# save_best_only=True: 只保存验证集上表现最好的模型
# monitor='val_loss': 监控验证集的损失值
# mode='min': 损失值越小越好
checkpoint_path = "best_model.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
save_best_only=True,
monitor='val_loss',
mode='min')
# 4. 训练模型 (为了演示,我们使用一个小的训练集和验证集)
print("开始训练,并使用 ModelCheckpoint 回调...")
history = model.fit(x_train[:1000], y_train[:1000],
epochs=5,
batch_size=32,
validation_data=(x_train[1000:2000], y_train[1000:2000]),
callbacks=[cp_callback])
print("训练完成。")
# 5. 恢复模型参数
# 创建一个结构相同的新模型
new_model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
# 恢复时,模型的编译状态(优化器、损失函数等)需要重新设置
new_model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 加载保存的权重
new_model.load_weights(checkpoint_path)
print("已从 'best_model.ckpt' 加载权重。")
# 验证
loss, acc = new_model.evaluate(x_train[1000:2000], y_train[1000:2000], verbose=2)
print(f"恢复后模型的准确率: {acc:.4f}")
model.save() 和 model.load_weights() (保存完整模型或仅权重)
tf.keras.Model 对象本身提供了便捷的方法来保存和加载。
保存整个模型 (不推荐,除非需要部署)
model.save() 会保存:模型架构、权重、训练配置(优化器、损失等)和状态(如优化器的状态),这会生成一个包含所有信息的文件(通常是 .h5 或 SavedModel 格式)。
# 保存整个模型
model.save('my_complete_model.h5')
# 加载整个模型
loaded_model = tf.keras.models.load_model('my_complete_model.h5')
注意:这种方式虽然方便,但文件体积较大,且可能会因为 TensorFlow 版本不同而出现兼容性问题,它适用于需要将模型完整保存以便后续直接部署或继续训练的场景。
仅保存和加载权重 (常用)
如果你只需要模型的参数,save_weights() 和 load_weights() 是更好的选择,它们只处理权重文件。
# 保存权重到文件
model.save_weights('my_model_weights.h5')
# 创建一个结构相同的新模型
new_model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
# 加载权重
# 注意:新模型在加载权重前**不需要编译**
new_model.load_weights('my_model_weights.h5')
# 只有在使用模型进行预测或评估前,才需要编译
new_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
总结与选择建议
| 方法 | 优点 | 缺点 | 适用场景 | |
|---|---|---|---|---|
tf.train.Checkpoint |
权重、优化器状态、自定义对象 | 现代、灵活、健壮,不依赖静态图,是 TF2.x 的标准方式 | 相对较新,社区教程可能较少 | 新项目、训练脚本、复杂模型结构、需要保存优化器状态 |
ModelCheckpoint |
权重(或完整模型) | 训练时自动化,可配置保存最佳模型、按 epoch 保存等 | 主要用于训练过程 | 在训练过程中自动保存模型,特别是用于保存验证集上的最佳模型 |
model.save_weights() |
权重 | 轻量、简单,与 Keras 模型无缝集成 | 无法直接保存优化器状态 | 保存/恢复训练过程,模型部署前的准备 |
model.save() |
完整模型 | 方便、完整,一键保存所有信息 | 文件体积大,版本兼容性风险 | 保存最终训练好的模型,用于后续的预测、部署或分析 |
给你的建议:
- 如果你正在编写一个从头到尾的训练脚本:使用
tf.train.Checkpoint是最稳妥、最现代的选择。 - 如果你想在训练过程中自动保存模型:使用
tf.keras.callbacks.ModelCheckpoint回调函数,这是 Keras 的标准实践。 - 如果你只是想快速保存和加载一个已经训练好的模型的参数:使用
model.save_weights()和model.load_weights()非常直接。
