TensorFlow如何恢复参数?30秒掌握模型权重加载技巧

99ANYc3cd6
预计阅读时长 18 分钟
位置: 首页 参数 正文
  1. 恢复完整的模型:包括模型的结构(架构)和权重,这是最简单、最常用的方法,特别是当你使用 Keras 高级 API 时。
  2. 恢复预训练的权重:当你已经定义好了模型结构,只是想加载别人训练好的或者之前保存的权重时使用。

下面我将详细介绍这两种情况,并提供清晰的代码示例。

tensorflow 恢复参数
(图片来源网络,侵删)

恢复完整的模型(推荐)

当你使用 model.save() 保存一个 Keras 模型时,它会创建一个包含模型架构、权重、训练配置(如优化器、损失函数)和状态(如优化器状态)的文件,恢复起来非常简单。

保存模型

我们创建并训练一个简单的模型,然后保存它。

import tensorflow as tf
import numpy as np
# 1. 准备数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 2. 创建模型
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 3. 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# 4. 训练模型 (为了演示,只训练一个epoch)
print("训练模型...")
model.fit(x_train, y_train, epochs=1)
print("训练完成。")
# 5. 保存完整模型
# 这会创建一个包含模型所有信息的文件夹
model.save('my_complete_model')
print("模型已保存到 'my_complete_model' 目录。")
# 你可以用 model.save('my_complete_model.h5') 来保存为单个 HDF5 文件
# model.save('my_complete_model.h5')

恢复模型

我们可以在一个新的 Python 脚本或同一个脚本的后续部分中加载这个模型。

import tensorflow as tf
# 1. 恢复模型
# TensorFlow 会自动从目录中重建模型结构、权重和配置
loaded_model = tf.keras.models.load_model('my_complete_model')
# 如果你保存的是 .h5 文件,用法也一样
# loaded_model = tf.keras.models.load_model('my_complete_model.h5')
print("模型已成功加载!")
# 2. 验证模型
# 你可以像使用原始模型一样使用加载的模型
# 进行评估
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test / 255.0
loss, accuracy = loaded_model.evaluate(x_test, y_test, verbose=2)
print(f"加载的模型的准确率: {accuracy:.4f}")
# 也可以用它来做预测
predictions = loaded_model.predict(x_test[:1])
print(f"对第一张测试图片的预测结果: {np.argmax(predictions)}")

这种方法最简单,因为它保存和恢复的是“模型对象”本身,你无需重新定义模型结构。


恢复预训练的权重

当你已经有了模型的结构定义,但想加载一组预先训练好的权重时,这种方法就派上用场了。

保存权重

我们创建一个模型,训练它,然后只保存其权重。

import tensorflow as tf
import numpy as np
# 1. 准备数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 2. 创建模型 (注意:这里的定义必须和保存权重的模型结构完全一致)
model_to_save = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 3. 编译并训练
model_to_save.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
print("训练模型以保存权重...")
model_to_save.fit(x_train, y_train, epochs=1)
print("训练完成。")
# 4. 只保存模型的权重 (保存为单个文件)
model_to_save.save_weights('my_model_weights.h5')
print("模型权重已保存到 'my_model_weights.h5'。")

加载权重

我们创建一个新的、但结构完全相同的模型,然后将之前保存的权重加载进去。

import tensorflow as tf
import numpy as np
# 1. 创建一个**新的**模型,其结构必须与保存权重的模型完全相同
model_to_load = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 2. 编译模型 (加载权重后通常需要重新编译,特别是如果你要继续训练)
model_to_load.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
# 3. 加载权重
# 在加载权重之前,模型必须被构建好(即至少运行一次,以创建权重变量)
# compile() 通常就足够了
model_to_load.load_weights('my_model_weights.h5')
print("权重已成功加载到新模型中!")
# 4. 验证模型
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test / 255.0
loss, accuracy = model_to_load.evaluate(x_test, y_test, verbose=2)
print(f"加载权重后的模型准确率: {accuracy:.4f}")

关键点

  • 结构必须一致model_to_load 的层数量、层类型、层名称(如果未指定,Keras会自动生成)和输入/输出形状必须与保存权重的模型完全相同,否则,load_weights 会失败。
  • 先构建,后加载:在调用 load_weights 之前,模型必须已经被实例化并构建好(通过 compile() 或运行一次 model.predict())。

进阶:TensorFlow 2.x 的 SavedModel 格式

model.save() 默认使用的是 TensorFlow 的 SavedModel 格式,这是一个更通用、更强大的格式,不仅包含 Keras 模型,还包含 TensorFlow 的计算图。

  • 保存: model.save('path_to_dir')
  • 加载: tf.keras.models.load_model('path_to_dir')

这种格式是 TensorFlow Serving 和 TensorFlow.js 等工具部署模型时的标准格式。

总结与对比

特性 model.save() / load_model() (恢复完整模型) model.save_weights() / load_weights() (恢复权重)
模型架构 + 权重 + 优化器状态 + 训练配置 只有模型的权重
使用场景 保存和恢复整个训练流程,用于模型部署、迁移学习或中断后继续训练 微调预训练模型、在不同实验间复用权重
模型要求 无需重新定义模型结构 必须重新定义一个结构完全相同的新模型
便捷性 非常方便,一步搞定 需要手动定义模型结构,稍显繁琐
文件格式 默认为目录 (SavedModel),也可为 .h5 单个文件,通常为 .h5

给你的建议

  • 对于大多数应用,使用 model.save()tf.keras.models.load_model() 是最简单、最不容易出错的选择。
  • 当你进行迁移学习或使用预训练模型(如 VGG16, ResNet)时,通常使用 load_weights()model.load_weights() 来加载预训练层的权重,然后只训练你自己的顶层。
-- 展开阅读全文 --
头像
Bose Revolve参数有哪些关键点值得注意?
« 上一篇 前天
15款MacBook Pro拆机,散热升级还是挤牙膏?
下一篇 » 前天

相关文章

取消
微信二维码
支付宝二维码

最近发表

标签列表

目录[+]