TensorFlow: Saving and Loading a Model

xiaoxiao  2021-02-28

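import tensorflow as tf


def save_model():
    # Two one-element variables and a named op that adds them.
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
    res = tf.add(v1, v2, name='add_res')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes the checkpoint files, including ./save/model.ckpt.meta (the graph).
        saver.save(sess, './save/model.ckpt')


def restore_model():
    # Rebuild the graph from the .meta file, then load the variable values.
    saver = tf.train.import_meta_graph('./save/model.ckpt.meta')
    with tf.Session() as sess:
        saver.restore(sess, './save/model.ckpt')
        # Look up the add op's output tensor by name and evaluate it.
        print(sess.run(tf.get_default_graph().get_tensor_by_name('add_res:0')))


if __name__ == '__main__':
    save_model()  # Comment out this line when running the program a second time.
    restore_model()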
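The program above asks you to comment out the save call on the second run. A likely reason is that both functions work on the same default graph, so importing the .meta file would add a second copy of the nodes that the save step already created. Below is a minimal sketch of one way around that, not the original author's method: each function builds or imports its graph inside its own tf.Graph. It assumes the tf.compat.v1 module is available (TensorFlow 1.13 or later, including 2.x); the tf1 alias, the ckpt_path parameter, and the temporary directory are illustrative choices that do not appear in the original.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: tf.compat.v1 exposes the TF1-style graph/session API used in the post.
tf1 = tf.compat.v1


def save_model(ckpt_path):
    # Build the model in a private graph so the default graph is left untouched.
    with tf.Graph().as_default():
        v1 = tf1.Variable(tf.constant(1.0, shape=[1]), name='v1')
        v2 = tf1.Variable(tf.constant(2.0, shape=[1]), name='v2')
        tf.add(v1, v2, name='add_res')
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, ckpt_path)


def restore_model(ckpt_path):
    # Import the saved graph definition into a fresh, empty graph.
    with tf.Graph().as_default() as graph:
        saver = tf1.train.import_meta_graph(ckpt_path + '.meta')
        with tf1.Session() as sess:
            saver.restore(sess, ckpt_path)
            # Evaluate the add op's output tensor, looked up by name.
            return sess.run(graph.get_tensor_by_name('add_res:0'))


if __name__ == '__main__':
    # Hypothetical checkpoint location in a temporary directory.
    ckpt = os.path.join(tempfile.mkdtemp(), 'model.ckpt')
    save_model(ckpt)
    print(restore_model(ckpt))
```

Because neither function touches the default graph, both can be called in the same run, in either a script or an interactive session, without editing the code between runs.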
When reposting, please credit the original source: https://www.6miu.com/read-200186.html
