TensorFlow如何高效提取模型参数?

99ANYc3cd6
预计阅读时长 16 分钟
位置: 首页 参数 正文

我会从最简单、最推荐的方法讲起,然后逐步深入到更底层的 API。

tensorflow 提取参数
(图片来源网络,侵删)

使用 Keras 内置方法(推荐)

对于使用 Keras API(tf.keras)构建的模型,这是最简单、最安全、最推荐的方法,Keras 模型提供了专门的属性和方法来访问参数。

获取所有参数的列表

你可以使用 model.trainable_variablesmodel.weights

  • model.trainable_variables: 返回一个包含所有可训练参数的列表。
  • model.weights: 返回一个包含所有参数(包括可训练和不可训练的,如 BatchNormalization 的移动平均)的列表。
import tensorflow as tf
# 1. 创建一个简单的 Keras 模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 2. 获取所有可训练参数的列表
trainable_params = model.trainable_variables
print(f"可训练参数数量: {len(trainable_params)}")
# 获取所有参数(包括非训练参数)的列表
all_params = model.weights
print(f"所有参数数量: {len(all_params)}")
# 3. 遍历并打印每个参数的名称和形状
for param in trainable_params:
    print(f"参数名称: {param.name}")
    print(f"参数形状: {param.shape}")
    print("-" * 30)

输出示例:

可训练参数数量: 4
所有参数数量: 4
参数名称: dense/kernel:0
参数形状: (784, 128)
------------------------------
参数名称: dense/bias:0
参数形状: (128,)
------------------------------
参数名称: dense_1/kernel:0
参数形状: (128, 10)
------------------------------
参数名称: dense_1/bias:0
参数形状: (10,)
------------------------------
  • kernel:0 通常指权重矩阵。
  • bias:0 指偏置向量。

按层获取参数

你也可以直接访问每一层的参数。

tensorflow 提取参数
(图片来源网络,侵删)
# 获取模型中第一个 Dense 层的权重和偏置
first_layer = model.layers[0]
kernel = first_layer.kernel
bias = first_layer.bias
print(f"第一层权重名称: {kernel.name}, 形状: {kernel.shape}")
print(f"第一层偏置名称: {bias.name}, 形状: {bias.shape}")

获取参数的总数量和总大小

# 获取参数总数量
total_params = model.count_params()
print(f"模型总参数数量: {total_params}")
# 获取参数的总大小(字节)
total_size = sum([p.numpy().nbytes for p in model.trainable_variables])
print(f"模型可训练参数总大小: {total_size / (1024 * 1024):.2f} MB")

使用 TensorFlow 变量 API

当你构建一个更底层的模型,或者想直接操作 TensorFlow 变量时,可以使用 tf.Variable 相关的 API。

假设你已经有一个模型,并且已经通过 model.build() 或调用一次数据来创建了变量。

# 假设我们有一个和上面一样的模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 必须先调用一次模型以创建变量
# 用虚拟数据调用一次
_ = model(tf.random.normal((1, 784)))
# 1. 遍历模型的所有变量
print("使用 tf.Variable API:")
for var in model.variables:
    print(f"变量名称: {var.name}, 形状: {var.shape}, 是否可训练: {var.trainable}")
# 2. 通过名称获取特定变量
# 你需要知道变量的确切名称
kernel = model.get_layer('dense').get_variable('kernel')
print(f"\n通过名称获取的权重: {kernel.name}, 形状: {kernel.shape}")

注意: 这种方法需要你了解变量在模型中的确切名称,不如 Keras 内置方法直观。


使用检查点和 SavedModel

如果你已经训练好了一个模型,并希望提取其参数,最好的方式是使用 TensorFlow 的检查点或 SavedModel 格式,这不仅能保存参数,还能保存模型的结构。

保存模型

# 训练模型(这里用虚拟数据代替)
import numpy as np
x_train = np.random.rand(1000, 784).astype(np.float32)
y_train = np.random.randint(0, 10, (1000,)).astype(np.int32)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1, batch_size=32)
# 保存为 SavedModel 格式(推荐)
model.save('my_model.keras')
# 或者只保存权重为检查点文件
model.save_weights('my_model_weights.h5')

加载模型并提取参数

# 方式 A: 加载整个 SavedModel
loaded_model = tf.keras.models.load_model('my_model.keras')
# 现在你可以像处理原始模型一样提取参数
print("从 SavedModel 加载后提取参数:")
for param in loaded_model.trainable_variables:
    print(f"参数名称: {param.name}, 形状: {param.shape}")
# 方式 B: 从检查点文件加载权重到新模型
# 首先创建一个结构完全相同的新模型
new_model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 加载权重
new_model.load_weights('my_model_weights.h5')
# 然后提取参数
print("\n从检查点加载后提取参数:")
for param in new_model.trainable_variables:
    print(f"参数名称: {param.name}, 形状: {param.shape}")

总结与对比

方法 适用场景 优点 缺点
Keras 内置方法 tf.keras 模型,开发阶段 最简单、最安全,与模型结构高度耦合,无需关心底层细节 仅适用于 Keras 模型
TensorFlow 变量 API 底层模型构建,或直接操作变量 灵活,可以与任何 tf.Variable 交互 需要知道变量名称,代码不够直观
检查点 / SavedModel 模型训练完成后,持久化和恢复 最可靠,保存了完整的模型状态和结构,便于部署和迁移 需要先保存再加载,不适合在训练循环中频繁使用

最佳实践建议:

  • 在开发和训练过程中,始终使用 方法一(Keras 内置方法) 来检查和操作参数。
  • 在需要保存或加载模型时,使用 方法三(检查点 / SavedModel).keras.h5 格式是首选,因为它们包含了模型结构。
  • 仅在构建非常底层的自定义训练逻辑时,才考虑使用 方法二(TensorFlow 变量 API)
-- 展开阅读全文 --
头像
佳明vivosmart hr拆解,内部结构有何玄机?
« 上一篇 今天
ThinkPad T470拆机视频有哪些关键步骤?
下一篇 » 今天

相关文章

取消
微信二维码
支付宝二维码

最近发表

标签列表

目录[+]