本文介绍了 Estimators 模型的保存和恢复。(官方文档连接:https://www.tensorflow.org/guide/checkpoint)
TensorFlow提供了两种模型格式:
checkpoints:这种格式依赖于创建模型的代码。
SavedModel:这种格式与创建模型的代码无关。
1、Checkpoints
checkpoints是什么?
- 在tensorflow中checkpoints文件是一个二进制文件,用于存储所有的weights,biases,gradients和其他variables的值。.meta文件则用于存储 graph中所有的variables, operations, collections等。简言之一个存储参数,一个存储图。
-“checkpoint”文件仅用于告知某些TF函数,这是最新的检查点文件。
-.ckpt-meta
包含元图,即计算图的结构,没有变量的值(基本上你可以在tensorboard / graph中看到)。
-.ckpt-data
包含所有变量的值,没有结构。要在python中恢复模型,您通常会使用元数据和数据文件(但也可以使用.pb
文件):saver = tf.train.import_meta_graph(path_to_ckpt_meta) saver.restore(sess, path_to_ckpt_data)
-.ckpt-index是内部需要的某种索引来正确映射前两个文件。它通常不是必需的,可以只用
.ckpt-meta和恢复一个模型
.ckpt-data。
.pb文件可以保存您的整个图表(元+数据),要在c ++中加载和使用(但不训练)图形,通常会使用它来创建[
freeze_graph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py),它会
.pb从元数据和数据创建文件。要小心,(至少在以前的TF版本和某些人中)py提供的功能
freeze_graph不能正常工作,所以你必须使用脚本版本。Tensorflow还提供了一种
tf.train.Saver.to_proto()`方法。保存经过部分训练的模型
Estimator自动将如下内容写入磁盘
- checkpoints: 训练期间所创建的模型版本
- event files: 包含有TensorBoard用于创建可视化图标的全部信息
如果要指定模型的顶级存储目录,可以使用Estimator构造函数的可选参数model_dir
,设置代码如下所示:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir="./models_dir")
当调用Estimator的train
方法时,Estimator会将checkpoint和其他文件保存到model_dir
目录中,保存之后,这个目录中的文件如下所示:
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
这个目录存储的是Estimator在第一步训练开始和第200不训练结束时创建的checkpoints
-
Checkpoint频率
默认情况下,Estimator按照如下时间将checkpoint保存到model_dir
中
- 每600秒保存一次
- 在train方法开始以及完成时都要保存checkpoint
- 在目录中最多保留5个最近的checkpoints
可以通过如下步骤来更改默认设置:
- 创建
RunConfig
对象来自定义设置 - 在实例化Estimator时,将该```RunConfig
对象传递个Estimatro的
config``参数
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60,
keep_checkpoint_max = 10,
)
-
变量的保存与恢复
tf.train.Checkpoint
TensorFlow 提供了tf.train.Checkpoint
这一强大的变量保存与恢复类,可以使用其save()
和restore()
方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer
或者tf.keras.Model
实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:
checkpoint = tf.train.Checkpoint(model=model)
这里tf.train.Checkpoint()
接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model
的模型实例 model`` 和一个继承
tf.train.Optimizer的优化器
optimizer`` ,我们可以这样写:
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
这里myAwesomeModel
是我们为待保存的模型 model`` 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。 接下来,当模型训练完成需要保存的时候,使用(
save_path_with_prefix``` 是保存文件的目录 + 前缀。):
checkpoint.save(save_path_with_prefix)
例如,在源代码目录建立一个名为save
的文件夹并调用一次 checkpoint.save('./save/model.ckpt')
,我们就可以在可以在 save 目录下发现名为 ``checkpoint 、
model.ckpt-1.index 、
model.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。
checkpoint.save() 方法可以运行多次,每运行一次都会得到一个
. index文件和
. data ```文件,序号依次累加。
当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:
model_to_be_restored = MyModel() # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)
当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path)
这个辅助函数f。例如如果save
目录下有 model.ckpt-1.index
到 model.ckpt-10.index
的 10 个保存文件, tf.train.latest_checkpoint('./save')
即返回 ./save/model.ckpt-10
。
总体而言,恢复与保存变量的典型代码框架如下:
# train.py 模型训练阶段
model = MyModel()
# 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型训练代码)
# 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
checkpoint.save('./save/model.ckpt')
# test.py 模型使用阶段
model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model) # 实例化Checkpoint,指定恢复对象为model
checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数
# 模型使用代码
使用 TensorFlow 的 tf.train.CheckpointManager
设置保存的数量:
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)
# 此处, directory 参数为文件保存的路径, checkpoint_name 为文件名前缀(不提供则默认为 ckpt ),
# max_to_keep 为保留的 Checkpoint 数目。