pb文件的能够保存tensorflow计算图中的操作节点以及对应的各张量,方便我们日后直接调用之前已经训练好的计算图。
本文代码的运行软件为pycharm
保存pb文件
下面的代码展示了最简单的tensorflow四则运算计算图
import tensorflow as tf
x = tf.placeholder(tf.float32,name="input")
a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
c = tf.Variable(tf.constant(10.,shape=[1]),name="c")
d = tf.Variable(tf.constant(2.,shape=[1]),name="d")
tensor1 = tf.multiply(a,b,"mul")
tensor2 = tf.subtract(tensor1,c,"sub")
tensor3 = tf.div(tensor2,d,"div")
result = tf.add(tensor3,x,"add")
inial = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(inial)
print(sess.run(a))
print(result)
result = sess.run(result,feed_dict={x:1.0})
print(result)
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
with tf.gfile.FastGFile("wsj.pb", mode='wb') as f:
f.write(constant_graph.SerializeToString())
保存pb文件的功能主要是通过最后三行代码实现的
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["add"])
with tf.gfile.FastGFile("wsj.pb", mode='wb') as f:
f.write(constant_graph.SerializeToString())
第一行代码的作用是将计算图中的变量转化为常量,并指定输出节点为“add”
第二行代码用来生成一个名为wsj.pb的文件(未指定路径的话,默认在该python代码的同路径下生成)
第三行代码的作用是将计算图写入该pb文件中
读取pb文件
import tensorflow as tf
with tf.gfile.FastGFile("wsj.pb", "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
result, x = tf.import_graph_def(graph_def,return_elements=["add:0", "input:0"])
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(a))
result = sess.run(result, feed_dict={x: 5.0})
print(result)
上面代码主要分为两部分:读取pb文件并设置为默认的计算图;填充一个新的x值来计算结果。
读取pb文件时候需要注意的是,若要获取对应的张量必须用“tensor_name:0”的形式,这是tensorflow默认的。
若您觉得本文章对您有用,请您为我点上一颗小心心以表支持。感谢!