1.迁移学习概述
2.Tensorflow实现
2.1 详细步骤
(1)
#获取要恢复的变量列表
exclude=[]
variables_to_restore=slim.get_variables_to_restore(exclude=['Mixed_7c'])
(2)方法一、利用slim.assign_from_checkpoint_fn函数进行恢复
#一个完整的只取部分参数的例子:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
s = tf.Session(config=tf.ConfigProto(gpu_options={'allow_growth':True}))
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
predictions = nets.vgg.vgg_16(images, 200)
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', variables_to_restore)
s.run(init_assign_op, init_feed_dict)
(3)方法二、利用saver.restore进行恢复
3.相关函数解析
3.0 optimizer.minimize(loss_score,var_list = output_vars)训练特定层
#定义优化算子
optimizer = tf.train.AdamOptimizer(1e-3)
#选择待优化的参数
output_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='outpt')
train_step = optimizer.minimize(loss_score,var_list = output_vars)
sess.run(init)
把需要更新梯度的层放在get_collection这个函数里面,不需要更新的不放进去。
3.1 slim.assign_from_checkpoint_fn():
很方便的实现。其中,第一个参数 model_path 指定预训练模型 xxx.ckpt 文件的路径,第二个参数 var_list 指定需要导入对应预训练参数的所有变量,通过函数
slim.get_variables_to_restore(include=None,
exclude=None)
可以快速指定,如果需要排除一些变量,也就是如果想让某些变量随机初始化而不是直接使用预训练模型来初始化,则直接在参数 exclude 中指定即可。第三个参数 ignore_missing_vars 非常重要,一定要将其设置为 True,也就是说,一定要忽略那些在定义的模型结构中可能存在的而在预训练模型中没有的变量,因为如果自己定义的模型结构中存在一个参数,而这些参数在预训练模型文件 xxx.ckpt 中没有,那么如果不忽略的话,就会导入失败(这样的变量很多,比如卷积层的偏置项 bias,一般预训练模型中没有,所以需要忽略,即使用默认的零初始化)。最后一个参数 reshape_variabels 指定对某些变量进行变形,这个一般用不到,使用默认的 False 即可。
返回一个函数,它从checkpoint文件读取变量值并分配给给特定变量。如果ignore_missing_vars为True,并且在检查点中找不到变量,则返回None。函数的源码如下
def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False,
reshape_variables=False):
"""
Args:
model_path: 模型的checkpoint文件的绝对路径。为了得到最新的checkpoint文件,可以使用:model_path = tf.train.latest_checkpoint(checkpoint_dir)
var_list: A list of `Variable` objects or a dictionary mapping names in the
checkpoint to the corresponding variables to initialize. If empty or
`None`, it would return `no_op(), None`.
ignore_missing_vars: Bool型,如果为True,它将忽略在checkpoint文件中缺失的那些变量,。
reshape_variables: Bool型, 如果为真,那么那些与checkpoint文件中的变量有不同形状的变量将会自动被reshape。
Returns:
一个只需要一个参数(tf.Session)函数,它作用是进行赋值操作。如果在checkpoint文件中没有找到任何匹配的变量,那么将会返回None
"""
if not var_list:
raise ValueError('var_list cannot be empty')
if ignore_missing_vars:
reader = pywrap_tensorflow.NewCheckpointReader(model_path)
if isinstance(var_list, dict):
var_dict = var_list
else:
var_dict = {var.op.name: var for var in var_list}
available_vars = {}
for var in var_dict:
if reader.has_tensor(var):
available_vars[var] = var_dict[var]
else:
logging.warning(
'Variable %s missing in checkpoint %s', var, model_path)
var_list = available_vars
if var_list:
saver = tf_saver.Saver(var_list, reshape=reshape_variables)
def callback(session):
saver.restore(session, model_path)
return callback
else:
logging.warning('No Variables to restore')
return None
3.2 slim.get_model_variables()
import tensorflow as tf
import tensorflow.contrib.slim as slim
my_non_trainable = tf.get_variable("my_non_trainable",
shape=(),
trainable=False)
my_trainable = tf.get_variable("my_trainable",
shape=(),
trainable=True)
weights = slim.model_variable('weights',
shape=[10, 10, 3 , 3],
initializer=tf.truncated_normal_initializer(stddev=0.1),
regularizer=slim.l2_regularizer(0.05),
device='/CPU:0')
print('slim.get_variables',slim.get_variables())
print('slim.get_model_variables',slim.get_model_variables())
print('tf.global_variables',tf.global_variables())
print('tf.trainable_variables',tf.trainable_variables())
3.3 tf.get_collection()
参考资料
[1] Tensorflow读取并使用预训练模型:以inception_v3为例 ****
[2] tensorflow对自己的数据进行训练(选择性的恢复权值)(26)---《深度学习》 不是非常好,一般般,可参考
[3] TensorFlow 使用预训练模型 ResNet-50 要看
[6] 第二十四节,TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码) 特别详细,特别好,豁然开朗,是想要找的
TensorFlow微调的实现
[1] 使用slim从ckpt里导出指定层的参数 简介明了,很好
[2] 6.5.2 Tensorflow 实现迁移学习 简介明了,很好
[3] TensorFlow - TF-Slim 之 checkpoint 恢复模型
[1] 深度学习模型-13 迁移学习(Transfer Learning)技术概述 不包含代码
[2] Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning 包含代码,InceptionResnet_v1
[3] 深度学习入门篇--手把手教你用 TensorFlow 训练模型 用处不大,不用看
[4] TensorFlow-Slim图像分类库
[5] 【深度学习-微调模型】使用Tensorflow Slim fine-tune(微调)模型
[6] Tensorflow学习笔记:CNN篇(10)——Finetuning,猫狗大战,VGGNet的重新针对训练 有些用,还需要再看看
[7] 什么是迁移学习 (Transfer Learning)?这个领域历史发展前景如何?
[8] 深度学习 -> 强化学习 ->迁移学习(杨强教授报告)
[9] 迁移学习 Transfer Learning
[10] 迁移学习--综述
[11] 什么是迁移学习(Transfer Learning)?【精讲+代码实例】
[12] 什么是迁移学习 Transfer Learning
[13] 如何在tensorflow中进行FineTuning 看起来有些用,看没怎么看懂
函数解析参考资料
[1] Slim下的函数介绍(一)
[2] 【Tensorflow slim】slim losses包
[3] TF.slim简单用法
[4] 『TensorFlow』使用集合collection控制variables
[5] TensorFlow 学习笔记(二) 不同variables之间的区别
[6] tf.get_collection()
只训练部分层
[1] tensorflow加载部分层方法
[2] tensorflow冻结部分层,只训练某一层
[3] tensorflow 固定部分参数训练,只训练部分参数