TensorFlow如何高效统计模型参数数量?

99ANYc3cd6
预计阅读时长 20 分钟
位置: 首页 参数 正文

核心概念

在 TensorFlow 中,模型参数通常指的是权重(Weights)偏置,这些参数是模型在训练过程中需要学习并优化的变量,统计参数通常包括:

tensorflow 统计参数
(图片来源网络,侵删)
  1. 总参数量:模型中所有参数的总个数。
  2. 可训练参数量:在训练过程中会被梯度更新器(如 tf.keras.optimizers.Adam)更新的参数个数,绝大多数参数都是可训练的。
  3. 非可训练参数量:模型中存在但不会被训练更新的参数,批归一化层中的移动均值和移动方差,它们在训练时通过指数移动平均更新,但不受梯度下降器控制。
  4. 参数大小:所有参数占用的内存空间(通常以 MB 或 GB 为单位)。

使用 tf.keras.Model.summary() (最简单、最常用)

这是最直接、最推荐的方法,尤其适用于 tf.keras 构建的模型,它会打印出一个格式化的表格,清晰地展示每一层的类型、输出形状、参数量以及模型的总参数量。

示例代码

import tensorflow as tf
from tensorflow.keras import layers, models
# 创建一个简单的卷积神经网络模型
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10) # 输出层,10个类别
])
# 打印模型摘要
model.summary()

输出结果

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
 max_pooling2d (MaxPooling2D  (None, 13, 13, 32)        0         
 )                                                               
 conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496     
 max_pooling2d_1 (MaxPooling  (None, 5, 5, 64)          0         
 2D)                                                             
 conv2d_2 (Conv2D)           (None, 3, 3, 64)          36928     
 flatten (Flatten)           (None, 576)               0         
 dense (Dense)               (None, 64)                36928     
 dense_1 (Dense)             (None, 10)                650       
=================================================================
Total params: 92,722
Trainable params: 92,722
Non-trainable params: 0
_________________________________________________________________

输出解读

  • Layer (type): 层的名称和类型。
  • Output Shape: 该层输出的张量形状。None 通常代表批处理大小。
  • Param #: 该层拥有的参数数量。
    • Conv2D: 参数量 = (卷积核高度 * 卷积核宽度 * 输入通道数 + 1) * 输出通道数+1 是指偏置项。
      • 第一个 Conv2D: (3 * 3 * 1 + 1) * 32 = 320
    • Dense: 参数量 = (输入节点数 + 1) * 输出节点数
      • 第一个 Dense: (576 + 1) * 64 = 36928
  • Total params: 模型所有参数的总和。
  • Trainable params: 可训练参数的总和。
  • Non-trainable params: 非可训练参数的总和。

遍历模型变量并手动计算 (更灵活)

如果你想以编程方式获取这些数字,或者需要对参数进行更复杂的分析,可以手动遍历模型的变量。

示例代码

# 使用上面创建的同一个模型
# 1. 获取所有变量
all_vars = model.trainable_variables + model.non_trainable_variables
# 2. 初始化计数器
total_params = 0
trainable_params = 0
non_trainable_params = 0
# 3. 遍历变量并计算
for var in all_vars:
    # var.shape 是一个 TensorShape 对象,可以使用 .num_elements() 获取元素总数
    param_count = var.shape.num_elements()
    total_params += param_count
    # 检查变量是否可训练
    if var.trainable:
        trainable_params += param_count
    else:
        non_trainable_params += param_count
# 4. 打印结果
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")
print(f"Non-trainable params: {non_trainable_params:,}")
# 计算参数大小 (假设 float32 类型,每个参数占4字节)
param_size_bytes = total_params * 4
param_size_mb = param_size_bytes / (1024 * 1024)
print(f"Model size (params only): {param_size_mb:.2f} MB")

输出结果

Total params: 92,722
Trainable params: 92,722
Non-trainable params: 0
Model size (params only): 0.35 MB

说明:

  • model.trainable_variables 返回一个包含所有可训练变量的列表。
  • model.non_trainable_variables 返回一个包含所有非可训练变量的列表。
  • var.shape.num_elements() 是一个非常方便的函数,可以直接计算一个张量中所有元素的总数。

使用 tf.keras.utils.plot_model (可视化模型结构)

虽然这个函数主要用于绘制模型结构图,但它生成的图片中通常也包含了每一层的参数信息,这对于报告和演示非常有用。

tensorflow 统计参数
(图片来源网络,侵删)

示例代码

# 使用上面创建的同一个模型
# 绘制模型结构图,并显示层形状和参数信息
tf.keras.utils.plot_model(
    model,
    to_file='model_plot.png', # 保存图片的文件名
    show_shapes=True,          # 显示层的输入/输出形状
    show_layer_names=True,     # 显示层名称
    show_dtype=True,           # 显示层的数据类型
    expand_nested=False,       # 是否展开嵌套模型
    dpi=96,                    # 图片分辨率
    submodel=False
)

这会生成一张名为 model_plot.png 的图片,图片中会清晰地展示出模型的结构和各层的参数量。


特殊情况:统计特定层的参数

有时候你可能只想统计某一类层的参数,比如所有卷积层的参数。

示例代码

# 使用上面创建的同一个模型
conv_params = 0
dense_params = 0
# 遍历模型的所有层
for layer in model.layers:
    # 检查层类型
    if isinstance(layer, layers.Conv2D):
        # 获取该层的权重和偏置
        weights, biases = layer.get_weights()
        conv_params += weights.size + biases.size
    elif isinstance(layer, layers.Dense):
        weights, biases = layer.get_weights()
        dense_params += weights.size + biases.size
print(f"Total Conv2D params: {conv_params:,}")
print(f"Total Dense params: {dense_params:,}")

输出结果

Total Conv2D params: 55,744
Total Dense params: 37,578

说明:

  • layer.get_weights() 返回一个列表,包含了该层的所有权重张量,对于 Conv2DDense,通常是 [权重矩阵, 偏置向量]
  • tensor.sizetensor.shape.num_elements() 的一个便捷别名。
方法 优点 缺点 适用场景
model.summary() 最简单、最直观,信息全面,格式化好。 以文本形式输出,难以在代码中进一步处理。 日常开发、调试、快速查看模型信息的首选。
手动遍历变量 非常灵活,可编程控制,可进行复杂分析(如分类统计)。 代码稍显繁琐,需要自己实现计数逻辑。 需要将参数数量用于脚本逻辑、生成报告、或进行更细粒度分析时。
plot_model 可视化效果好,适合用于报告、论文和演示。 依赖图形库(如 pydot),输出是图片而非数据。 需要向他人展示模型结构和参数时。

对于绝大多数情况,model.summary() 已经完全足够并且是最好的选择,当你需要在代码中自动获取这些数字时,手动遍历变量则是最可靠的方法。

tensorflow 统计参数
(图片来源网络,侵删)
-- 展开阅读全文 --
头像
儿童智能手表真能守护安全吗?
« 上一篇 今天
S7 LED智能保护套,智能在哪?如何保护?
下一篇 » 今天

相关文章

取消
微信二维码
支付宝二维码

最近发表

标签列表

目录[+]