tensorflow 获取模型所有参数总和数量

xiaoxiao2021-02-28  103

from functools import reduce from operator import mul def get_num_params(): num_params = 0 for variable in tf.trainable_variables(): shape = variable.get_shape() num_params += reduce(mul, [dim.value for dim in shape], 1) return num_params
转载请注明原文地址: https://www.6miu.com/read-70418.html

最新回复(0)