前言
我遇到了在同一个python脚本中加载两个同样的网络结构,并且他们包含同样的name的模块,在tensorflow中如果name相同的模块就会冲突,要想解决问题,利用两个graph分别包裹他们,并且用两个session来处理。如果能带来帮助很开心!
示例代码
lefteye_graph = tf.Graph()
righteye_graph = tf.Graph()
with lefteye_graph.as_default():
.....net construct.......
sess_left = tf.Session(graph=lefteye_graph)
sess_left.run(tf.global_variables_initializer())
sess_left.run(tf.local_variables_initializer())
left_saver = tf.train.Saver(tf.global_variables())
graph = tf.get_default_graph()
left_saver.restore(sess_left, model_name)
...................
with righteye_graph.as_default():
...............
sess_right = tf.Session(graph=righteye_graph)
sess_right.run(tf.global_variables_initializer())
sess_right.run(tf.local_variables_initializer())
right_saver = tf.train.Saver()
right_saver.restore(sess_right, model_name)
graph = tf.get_default_graph()
..........................
sess_left.close()
sess_right.close()
代码理解
其中的as_default()是必须的,表示在该模块中该图作为默认图
遇到的问题
其间查阅了很多种写法,但是我一直遇到后面的graph报找不到某个tensor的错误,我的错误是因为代码是从上个模块copy然后修改的,后面在run中sess用成上面的,又由于tensorflow在with内部的变量也作为全局变量,所以不会提示错误,所以一直报错。
最后看到链接中的文章认为代码一定没错误,然后把后半部分移到前面去发现没有定义的变量才发现问题
这篇文章写得不错