# 去掉 warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import numpy as np
# Save two variables to a checkpoint.
# NOTE: when restoring later, the receiving variables must be declared with the
# same dtype, shape and name as the ones defined here.
W = tf.Variable([[1, 2, 3], [1, 2, 3]], dtype=tf.float32, name="Weights")
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name="biases")

init = tf.global_variables_initializer()
saver = tf.train.Saver()

ckpt_path = "E:/tensorflow/aa/aa.ckpt"
# Saver.save does not create missing directories, so make sure the
# checkpoint's parent directory exists before saving.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, ckpt_path)
    print("Save to path:", save_path)
# Restore the saved values into freshly declared variables.
# Start from an empty default graph. Otherwise the variables created above
# already own the names "Weights"/"biases", the new ones below would be renamed
# to "Weights_1"/"biases_1", and a Saver over all graph variables would look up
# keys that do not exist in the checkpoint, so restore would fail.
tf.reset_default_graph()

# The receiving variables must match the saved ones in dtype, shape and name.
# Their initial values are irrelevant (restore overwrites them); only the shape
# matters. Two equivalent ways of declaring them:
# W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=np.float32, name="Weights")
# b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=np.float32, name="biases")
W = tf.Variable(tf.zeros([2, 3]), dtype=np.float32, name="Weights")
b = tf.Variable(tf.zeros([1, 3]), dtype=np.float32, name="biases")

# No initializer is run here: saver.restore assigns the variables' values
# from the checkpoint.
# tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "E:/tensorflow/aa/aa.ckpt")
    print("weights: ", sess.run(W))
    print("biases: ", sess.run(b))
# Expected output:
# weights:  [[ 1.  2.  3.]
#  [ 1.  2.  3.]]
# biases:  [[ 1.  2.  3.]]
注意:
- 用于提取(恢复)数据的变量必须与文件中保存的变量具有相同的 dtype、shape 和 name,否则会报错!