模型预览
假设新模型分encoder+decoder两部分。其中encoder模块要导入预训练的参数,并且数值固定,不参与训练。decoder则是在encoder的基础上增加的分支,需要通过数据训练不断优化参数。
大体步骤
主要分为四个步骤:
1. 绘制整体网络图
2. 固定encoder参数
3. 导入encoder参数
4. 训练 + 模型保存
代码
part1:画图
#设置网络整体结构....
part2:固定参数
# 选择decode部分的参数
train_var_list = [var for var in tf.trainable_variables() if 'decode' in var.name]
# 优化器只优化选中的参数list
with tf.control_dependencies():
optimizer = optimizer.minimize(loss, global_step=global_step, var_list = train_var_list) #自行选择优化器
part3 导入旧参
# 选择encode部分参数
no_train_var = [var for var in tf.global_variables() if 'encode' in var.name] #这里的'encode'是在设置网络过程中某个scope的命名
# saver选择要导入的参数
saver = tf.train.Saver(no_train_var)
# 对整个网络所有参数做初始化
init = tf.global_variables_initializer()
sess.run(init)
# encode部分参数覆盖
saver.restore(sess, weights_path) #这里的weights_path是ckpt文件保存路径
part4 训练+保存
# 训练......
# 保存模型
# 重新定义saver为选中所有参数,否则最后将只保存no_train_var
saver = tf.train.Saver()
saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
其他
- 对于该网络还有另外一种方法:encode前向传播保存结果,将其作为decode网络输入,进行训练。
- 模型导入还有其他方法,可参考https://blog.csdn.net/CV_YOU/article/details/80698942。
不同类型的模型(npy, ckpt)导入保存方式有差异。 - 固定参数还可以在构建网络的时候选择变量的trainable为False,或者设置变量学习率为0.
参数导入方法2
当预训练模型和新模型的图不同时,无法用Saver导入参数,这时候要用到tf.assign
函数。
假设预训练模型只有encode部分,新模型encode+decode。遍历模型参数,用预训练参数进行替换。
代码
part3 导入旧参
# 导入所有参数
saver = tf.train.Saver()
# 对整个网络所有参数做初始化
init = tf.global_variables_initializer()
sess.run(init)
#读取预训练模型
reader = pywrap_tensorflow.NewCheckpointReader(weights_path)
# 逐层遍历参数并替换
for vv in tf.trainable_variables():
if 'encode' in vv.name:
weights = reader.get_tensor(weights_key)
_op = tf.assign(vv, weights)
sess.run(_op)