TensorFlow --- 图的基本操作

xiaoxiao2021-03-01  12

1.建立图 关于建立图的几个基本操作

tf.Graph() # 建立图 tf.get_default_graph() # 获取图 tf.reset_default_graph() # 重置图
import numpy as np import tensorflow as tf # 在默认图里建立的 c = tf.constant(0.0) # 建立了一个图,并且在新建的图里添加变量, # 可以通过变量的'.graph'获取所在的图 g = tf.Graph() # 表示使用tg.Graph函数来创建一个图 with g.as_default(): c1 = tf.constant(0.1) print(c1.graph) print(g) print(c.graph) # 获取默认图,所以跟c的值一样 g2 = tf.get_default_graph() print(g2) # 重建了一张图代替原来的默认图 tf.reset_default_graph() g3 = tf.get_default_graph() print(g3) 结果为: <tensorflow.python.framework.ops.Graph object at 0x7f9c90074f28> <tensorflow.python.framework.ops.Graph object at 0x7f9c90074f28> <tensorflow.python.framework.ops.Graph object at 0x7f9caa041b00> <tensorflow.python.framework.ops.Graph object at 0x7f9caa041b00> <tensorflow.python.framework.ops.Graph object at 0x7f9c90074e48>

使用tf.reset_default_graph函数时,必须保证当前图的资源已经全部释放,否则会报错。

2.获取张量 通过get_tensor_by_name可以获得图里面的张量

接上述例子 ... print(c1.name) t = g.get_tensor_by_name(name='Const:0') print(t) 结果为: Const:0 Tensor("Const:0", shape=(), dtype=float32)

3.获取节点的操作 通过get_operation_by_name来获取节点

... a = tf.constant([1.0, 2.0]) b = tf.constant([1.0], [3.0]) tensor1 = tf.matmul(a, b, name='exampleop') print(tensor1.name, tensor1) test = g3.get_tensor_by_name('exampleop:0') print(test) print(tensor1.op.name) testop = g3.get_operation_by_name('exampleop') print(testop) with tf.Session() as sess: test = sess.run(test) print(test) test = tf.get_default_graph().get_tensor_by_name('exampleop:0') print(test) 结果为: exampleop:0 Tensor("exampleop:0", shape=(1, 1), dtype=float32) Tensor("exampleop:0", shape=(1, 1), dtype=float32) # tensor1.op.name exampleop # get_operaion_by_name name: "exampleop" op: "MatMul" input: "Const" input: "Const_1" attr { key: "T" value { type: DT_FLOAT } } attr { key: "transpose_a" value { b: false } } attr { key: "transpose_b" value { b: false } } [[7.]] Tensor("exampleop:0", shape=(1, 1), dtype=float32)

4.获取元素列表 通过get_operaions获取图中的所有元素

... tt2 = g.get_operations() print(tt2) 结果为: [<tf.Operation 'Const' type=Const>]

5.获取对象 通过tf.Graph.as_graph_element(obj, allow_tensor=True, allow_operation=True)函数根据对象获取元素。即传入的是一个对象,返回的是一个张量或一个op。该函数具有验证和转换功能,在多线程方面偶尔会用到。

... tt3 = g.as_graph_element(c1) print(tt3) 结果为: Tensor("Const:0", shape=(), dtype=float32)
转载请注明原文地址: https://www.6miu.com/read-3850321.html

最新回复(0)