tensorflow 4——模型的保存、读取

xiaoxiao2021-02-28  66

tf.train.Saver类为tensorflow的一个API 可通过import tensorflow as tf help(tf.train.Saver)来查看这个API的用法

import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1') v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2') v3=tf.Variable(tf.constant(3.0,shape=[1]),name='v3') result1=v1+v2 result2=v2+v3 print(result1,result2)

Tensor(“add:0”, shape=(1,), dtype=float32) Tensor(“add_1:0”, shape=(1,), dtype=float32)

可以看出两个张量的名称add:0和add_1:0。指的是加法的名称、次数以及一开始初始化的第一个值 接下来看如何保存图

saver=tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess,'/path/to/model.ckpt') #然后就可以在路径下发现存储的文件 #接下来加载这些文件 import tensorflow as tf saver1=tf.train.import_meta_graph('/path/to/model.ckpt.meta')#加载图 with tf.Session() as sess: saver1.restore(sess,'/path/to/model.ckpt') #将图上的数据加载进来 print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

INFO:tensorflow:Restoring parameters from /path/to/model.ckpt [ 3.]

若在上述会话中,print(result2)会出现result2 is not defined的问题。所以加载图和张量时,只能依靠张量的名称来获取其值,接下来看看如何重定义或加载本来保存在图里的张量v1,v2,result

a1=tf.Variable(tf.constant(3.0,shape=[1]),name='a1') a2=tf.Variable(tf.constant(4.0,shape=[1]),name='a2') saver=tf.train.Saver({'v1':a1,'v2':a2}) #将原本图上的变量v1,v2加载过来到新的张量a1,a2上,可以看到a1为1而不是3 with tf.Session() as sess: saver.restore(sess,'/path/to/model.ckpt') print(sess.run(a1))

INFO:tensorflow:Restoring parameters from /path/to/model.ckpt [ 1.] 一般情况下,从保存的模型文件中加载计算原图meta,再restore所需变量值就可以达到调控模型参数的目的。 即模型每训练1000次保存一个模型,假若发现第4000次训练过拟合,第3000次训练的模型不太理想,则可加载restore第3000次模型的计算图以及变量,并在这基础上训练500次,来方便地实现调控模型训练次数的方式而又避免了重复训练,并可作多项研究。

saver指定要保存的变量,saver.save则指定在某个会话下,模型保存的路径,以及全局迭代的次数。 | saver = tf.train.Saver(…variables…) | # Launch the graph and train, saving the model every 1,000 steps. | sess = tf.Session() | for step in xrange(1000000): | sess.run(..training_op..) | if step % 1000 == 0: | # Append the step number to the checkpoint name: | saver.save(sess, ‘my-model’, global_step=step)

#一个滑动平均类变量的保存 import tensorflow as tf v=tf.Variable(0,dtype=tf.float32,name='v') ema=tf.train.ExponentialMovingAverage(0.99) print(ema.variables_to_restore()) {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}

总结: 1.tf.train.Saver(…variables) 2.variables一般重命名加载,获取可由reader或者variables_to_restore()函数来获取相应的列表 3.保存的路径,以及global step的设置,默认情况下,保存的模型文件最多5个,可自行阅读参数修改 4.具体完整的应用在tensorflow代码梳理2的神经网络中

转载请注明原文地址: https://www.6miu.com/read-2622081.html

最新回复(0)