tensorflow导入部分checkpoint

xiaoxiao2025-05-19  38

版权声明: https://blog.csdn.net/b876144622/article/details/79962727 现实中碰到一个问题,训练好分类模型,比如训练保存了一个10分类的模型,但是实际用的时候呢,可能是做20分类,但是还想继续使用前面保存的模型。那么相当于是只加载前几层的参数,最后一层做一些修改。 一般实验情况下保存的时候,都是用的saver类来保存,如下 saver = tf.train.Saver() saver.save(sess,"model.ckpt")加载时的代码 saver.restore(sess,"model.ckpt") 前面的描述相当于是保存了所有的参数,然后加载所有的参数。但是目前的情况有所变化了,不能加载所有的参数,最后一层的参数不一样了,需要随机初始化。如何操作呢? 首先对每一层添加name scope,如下: with name_scope('conv1'): xxx with name_scope('conv2'): xxx with name_scope('fc1'): xxx with name_scope('output'): xxx 然后根据变量的名字,选择加载哪些变量, #得到该网络中,所有可以加载的参数 variables = tf.contrib.framework.get_variables_to_restore() #删除output层中的参数 variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='output'] #构建这部分参数的saver saver = tf.train.Saver(variables_to_restore) saver.restore(sess,'model.ckpt') 在tensorflow中,有多种方式可以得到变量的信息: tf.contrib.framework.get_variables_to_restore() tf.all_variables() tf.trainable_varialbes()

等等,可以多看看API

转载请注明原文地址: https://www.6miu.com/read-5030346.html

最新回复(0)