愿天堂没有Tensorflow! 阿门。
NotFoundError (see above for traceback): Key local3/weights not found in checkpoint
这是一个困扰我好久的问题,在我们保存一个训练好的模型,然后找了一些测试数据来调用该模型测试模型的效果时,出现了上述错误,local3/weights可能会随机变化(比如conv1/weights)。下面调用模型的代码是Tensorflow官网上的。
with tf.Session() as sess:
tf.get_variable_scope().reuse_variables()
print("Reading checkpoints...")
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
if ckpt and ckpt.model_checkpoint_path:
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
saver.restore(sess, ckpt.model_checkpoint_path)
print('Loading success, global_step is %s' % global_step)
else:
print('No checkpoint file found')
看起来无懈可击,这个错误无从下手。再仔细读一下这个Error,有没有一种checkpoint模型保存的参数名字和实际网络模型参数的名字不一样的感觉?(哈哈,反正我有)。看一下自己的checkpoint和网络参数名字:
此时我们会产生这样一个大胆的想法(小姐姐,我想...):难道checkpoint里的参数名字和我们网络的参数名字不一样吗??
可是如何去验证这样一个大胆的想法呢? 如何去看checkpoint里的参数名呢? 如何讨得小姐姐的芳心呢?(哦哦,跑题了QAQ)我们可以使用下面的代码:
import os
model_dir = '/home/mml/siamese_net/logs/train/'
from tensorflow.python import pywrap_tensorflow
#checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
checkpoint_path = os.path.join(model_dir, "model.ckpt-9999")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
print(reader.get_tensor(key))
运行完上述代码后,发现水落石出:
果然,checkpoint参数名和网络的参数名是不一样的,当然会导致无法在checkpoint里找到local5,因为checkpoint里只有siamese/local5,所以只要修改统一参数名,即可顺利消除错误。