# This tutorial shows a practical way to obtain a deep-learning model's FLOPs
# (floating-point operation count) and trainable-parameter count
# programmatically; hopefully it helps.
import tensorflow as tf

# The following line is required: the TF1 profiler used below inspects a
# static graph, so eager execution must be disabled before any model is built.
tf.compat.v1.disable_eager_execution()
print(tf.__version__)


def stats_graph(graph):
    """Profile *graph*, print its FLOP and trainable-parameter counts, return them.

    This is the commonly recommended recipe; typical usage::

        stats_graph(tf.compat.v1.get_default_graph())

    NOTE(review): the profiler skips ops whose tensor shapes are not fully
    defined (e.g. an unknown ``None`` batch dimension), so build the model with
    a fixed batch size or the FLOP total will be incomplete.

    Args:
        graph: the ``tf.Graph`` to profile.

    Returns:
        A ``(total_float_ops, total_parameters)`` tuple as reported by
        ``tf.compat.v1.profiler.profile``.
    """
    flops = tf.compat.v1.profiler.profile(
        graph,
        options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation(),
    )
    params = tf.compat.v1.profiler.profile(
        graph,
        options=tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter(),
    )
    print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
    return flops.total_float_ops, params.total_parameters


def get_flops_params(graph=None):
    """Print (and return) the FLOPs and trainable parameters of a graph.

    Args:
        graph: graph to profile. Defaults to ``tf.compat.v1.get_default_graph()``,
            the same graph a bare ``tf.compat.v1.Session()`` would launch, so no
            session has to be created (and left unclosed) just to reach it.

    Returns:
        A ``(total_float_ops, total_parameters)`` tuple.
    """
    if graph is None:
        graph = tf.compat.v1.get_default_graph()
    return stats_graph(graph)


def get_flops(model):
    """Return the total float-operation count of the Keras session graph.

    Args:
        model: accepted for API compatibility but not used: the profiler
            inspects the whole graph behind
            ``tf.compat.v1.keras.backend.get_session()``, so every op in that
            graph is counted, not only those belonging to *model*.

    Returns:
        ``total_float_ops`` reported by the profiler (the value is returned,
        not printed).
    """
    # A fresh, empty RunMetadata: it is not populated by any session run here.
    run_meta = tf.compat.v1.RunMetadata()
    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    # Use the Keras session graph in the call to the profiler.
    flops = tf.compat.v1.profiler.profile(
        graph=tf.compat.v1.keras.backend.get_session().graph,
        run_meta=run_meta,
        cmd='op',
        options=opts,
    )
    return flops.total_float_ops
# FLOPs can only be obtained with the Keras bundled in TensorFlow: every layer
# and model class must come from tensorflow.keras; mixing in standalone keras
# raises errors.
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Sequential

model = Sequential()
# Fix the batch size at 1 so every op has a fully defined shape. With the
# batch dimension left as None (plain `input_shape=(28, 28, 1)`), the TF
# profiler cannot compute FLOP stats for Conv2D/MatMul and drops them from the
# total. The reported FLOPs are therefore per single sample.
model.add(Conv2D(filters=64, kernel_size=(3, 3), batch_input_shape=(1, 28, 28, 1), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(units=100, activation='relu'))
model.add(Dense(units=10, activation='softmax'))

# Print per-layer parameter details of the model.
model.summary()
# Print the model's total floating-point operation count and total trainable
# parameters. NOTE(review): the graph-wide FLOP total may also include ops
# outside the forward pass (e.g. weight initialisers) — confirm if exact
# numbers matter.
get_flops_params()

# Related articles:
# - A summary of several convolution types commonly used in deep learning
# - Computing parameters, memory size, MAdd, FLOPs and other metrics of a
#   network built with PyTorch