TensorFlow保存参数的步骤与代码是什么?

99ANYc3cd6
预计阅读时长 22 分钟
位置: 首页 参数 正文

Checkpoints (检查点)

在 TensorFlow 中,模型参数通常以 Checkpoints 的形式保存,一个 Checkpoint 文件(通常是 .ckpt 文件)并不是一个完整的、可独立运行的模型,它只包含了模型的权重(Weights)和偏置(Biases)

tensorflow 保存参数
(图片来源网络,侵删)

它通常会生成三个文件:

  • my_model.ckpt.index: 一个索引文件,用于查找变量。
  • my_model.ckpt.data-00000-of-00001: 一个二进制文件,包含了所有变量的实际值。
  • my_model.ckpt.meta: 一个协议缓冲区文件,包含了模型的计算图结构(Graph Definition),这个文件是可选的,但通常和 Checkpoint 一起保存,以便完整恢复模型。

tf.train.Checkpoint (推荐,现代且灵活)

这是目前 TensorFlow 2.x 中最推荐的方式,它不依赖于计算图的静态结构,而是通过一个对象的状态来追踪变量,这使得代码更简单、更健壮,尤其是在处理模型结构变化或使用 tf.function 时。

工作原理

tf.train.Checkpoint 会追踪任何 tf.Variable 对象,以及任何包含 tf.Variable 的对象(如 tf.keras.Model)。

示例代码

import tensorflow as tf
import os
# 1. 创建一个简单的模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])
# 2. 创建一个检查点对象
# 我们将模型和优化器都传入,这样它们的参数和状态都会被保存
checkpoint = tf.train.Checkpoint(model=model, optimizer=tf.keras.optimizers.Adam())
# 3. 定义保存路径
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt-{epoch}")
# 4. 保存参数
# 假设模型已经训练过一段时间
print("保存模型参数...")
# save() 方法会保存所有被追踪的对象的状态
checkpoint.save(checkpoint_prefix) 
print("参数已保存。")
# 5. 恢复参数
# 创建一个新的模型实例(结构必须和之前一样)
new_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])
# 创建一个新的检查点对象,指向同一个目录
new_checkpoint = tf.train.Checkpoint(model=new_model)
# 从最新的检查点恢复
# 注意:这里我们使用 checkpoint.restore,而不是 new_checkpoint.restore
# 因为 checkpoint 对象在保存时记录了所有被追踪对象的路径
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
print("尝试恢复参数...")
# status.assert_existing_objects_matched() 是一个好习惯,确保所有变量都被正确恢复
status.assert_existing_objects_matched() 
print("参数恢复成功!")
# 验证恢复后的模型权重是否和原始模型一样
print("原始模型第一层权重:", model.layers[0].get_weights()[0][0, 0])
print("恢复后模型第一层权重:", new_model.layers[0].get_weights()[0][0, 0])
# 你会发现它们的值是相同的

tf.keras.callbacks.ModelCheckpoint (在训练中自动保存)

这是在训练过程中最常用的方法,你可以将它作为回调函数传递给 model.fit(),模型会在每个 epoch 结束时自动保存参数。

示例代码

import tensorflow as tf
import numpy as np
# 1. 创建模型和编译
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# 2. 准备虚拟数据
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
# 3. 定义 ModelCheckpoint 回调
# filepath: 保存文件的路径
# save_weights_only=True: 只保存权重,这是最轻量的方式
# save_best_only=True: 只保存验证集上表现最好的模型
# monitor='val_loss': 监控验证集的损失值
# mode='min': 损失值越小越好
checkpoint_path = "best_model.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='val_loss',
                                                 mode='min')
# 4. 训练模型 (为了演示,我们使用一个小的训练集和验证集)
print("开始训练,并使用 ModelCheckpoint 回调...")
history = model.fit(x_train[:1000], y_train[:1000],
                    epochs=5,
                    batch_size=32,
                    validation_data=(x_train[1000:2000], y_train[1000:2000]),
                    callbacks=[cp_callback])
print("训练完成。")
# 5. 恢复模型参数
# 创建一个结构相同的新模型
new_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 恢复时,模型的编译状态(优化器、损失函数等)需要重新设置
new_model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
# 加载保存的权重
new_model.load_weights(checkpoint_path)
print("已从 'best_model.ckpt' 加载权重。")
# 验证
loss, acc = new_model.evaluate(x_train[1000:2000], y_train[1000:2000], verbose=2)
print(f"恢复后模型的准确率: {acc:.4f}")

model.save()model.load_weights() (保存完整模型或仅权重)

tf.keras.Model 对象本身提供了便捷的方法来保存和加载。

保存整个模型 (不推荐,除非需要部署)

model.save() 会保存:模型架构、权重、训练配置(优化器、损失等)和状态(如优化器的状态),这会生成一个包含所有信息的文件(通常是 .h5SavedModel 格式)。

# 保存整个模型
model.save('my_complete_model.h5') 
# 加载整个模型
loaded_model = tf.keras.models.load_model('my_complete_model.h5')

注意:这种方式虽然方便,但文件体积较大,且可能会因为 TensorFlow 版本不同而出现兼容性问题,它适用于需要将模型完整保存以便后续直接部署或继续训练的场景。

仅保存和加载权重 (常用)

如果你只需要模型的参数,save_weights()load_weights() 是更好的选择,它们只处理权重文件。

# 保存权重到文件
model.save_weights('my_model_weights.h5')
# 创建一个结构相同的新模型
new_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 加载权重
# 注意:新模型在加载权重前**不需要编译**
new_model.load_weights('my_model_weights.h5')
# 只有在使用模型进行预测或评估前,才需要编译
new_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

总结与选择建议

方法 优点 缺点 适用场景
tf.train.Checkpoint 权重、优化器状态、自定义对象 现代、灵活、健壮,不依赖静态图,是 TF2.x 的标准方式 相对较新,社区教程可能较少 新项目、训练脚本、复杂模型结构、需要保存优化器状态
ModelCheckpoint 权重(或完整模型) 训练时自动化,可配置保存最佳模型、按 epoch 保存等 主要用于训练过程 在训练过程中自动保存模型,特别是用于保存验证集上的最佳模型
model.save_weights() 权重 轻量、简单,与 Keras 模型无缝集成 无法直接保存优化器状态 保存/恢复训练过程,模型部署前的准备
model.save() 完整模型 方便、完整,一键保存所有信息 文件体积大,版本兼容性风险 保存最终训练好的模型,用于后续的预测、部署或分析

给你的建议:

  • 如果你正在编写一个从头到尾的训练脚本:使用 tf.train.Checkpoint 是最稳妥、最现代的选择。
  • 如果你想在训练过程中自动保存模型:使用 tf.keras.callbacks.ModelCheckpoint 回调函数,这是 Keras 的标准实践。
  • 如果你只是想快速保存和加载一个已经训练好的模型的参数:使用 model.save_weights()model.load_weights() 非常直接。
-- 展开阅读全文 --
头像
宏基Aspire V15拆机后内部有何不同?
« 上一篇 今天
智能手机的移动支付案例
下一篇 » 8分钟前

相关文章

取消
微信二维码
支付宝二维码

最近发表

标签列表

目录[+]