此文翻译自:A quick complete tutorial to save and restore Tensorflow models
这篇tensorflow的教程,将解释:
1. Tensorflow模型是什么样的?
2. 如何保存一个Tensorflow模型??
3. 如何恢复一个Tensorflow模型,用于预测或者迁移学习?
4. 如何利用导入的预训练好模型,进行fine-tuning或改造。
这篇教程,假定读者对神经网络的训练有基本的了解。如果不是,请先阅读Tensorflow Tutorial 2: image classifier using convolutional neural network,然后阅读本文。
1. Tensorflow 模型是什么?
当训练完一个神经网络,你就会保存它,以便日后使用和产品发布。所以,Tensorflow模型是如何表示的呢?Tensorflow模型主要包含网络设计(Graph)和训练好的参数的值。因此,Tensorflow模型包含两个主要的文件:
a) Meta graph:
这是一个协议缓冲区(protocol buffer,google推出的数据存储格式),保存完整的Tensorflow的graph信息;例如:所有的变量,操作(ops),集合(collection)等。此文件带有.meta扩展。
b) Checkpoint file:
它是一个二进制文件,包含所有的权重,偏置,导数和其他保存变量的值。文件后缀为: .ckpt。但自从0.11版本之后,Temsorflow作了改变,不再是一个单独的.ckpt文件,取而代之的是两个文件:
<<mymodel.data-00000-of-00001>>
<<mymodel.index>>
.data文件包含着训练好的变量的值,除此之外,Tensorflow还有一个名为checkpoint的文件,持续记录着最新的保存数据。
所以,总结下来,0.10之后的Tensorflow模型如下图所示:
而,0.11版本之前的Tensorflow模型,仅仅包含三个文件:
<<inception_v1.meta>>
<<inception_v1.ckpt>>
<<checkpoint>>
2. 保存一个Tensorflow模型:
假设,你正在训练一个卷积神经网络,用于图片分类。作为一个标准操作,你持续观测Loss function和Accuracy。一旦你看到网络收敛,你可以人为停止训练或者只训练固定数目的epochs。当训练完成之后,我们想要保存所有的变量和网络图(network graph)到一个文件,以便日后使用。因此,在Tensorflow中,为了保存graph和变量,我们应该新建一个tf.train.Saver()类。
谨记Tensorflow的变量只有在一个session中才是有效的。因此,你不得不在一个session中保存模型,使用刚刚新建的saver对象,调用save方法,如下:
这里,sess是一个session对象,“my-test-model”是你想要保存的模型的名字。完整的例子如下:
如果,我们想要在1000次迭代之后保存模型,可以传入表示步数的参数:
这行代码将添加‘-1000’至模型的名字,以下文件将被建立:
假设,训练时,我们每隔1000次迭代保存一次模型,因此,.meta文件第1000次迭代生成.meta文件后,我们不必要每次新建.meta文件(即在2000,3000次等迭代无须新建.meta文件)。我们仅仅保存最新的迭代模型。因为graph结构并没有改变,因此,也没必要写meta-graph,使用如下代码:
如果你想要只记录最新的4个模型,并每隔2个小时保存一个模型,可以使用这两个参数:max_to_keep和keep_checkpoint_every_n_hours,如下:
需要指出的是,如果我们在tf.train.Saver()中不指定任何事情,它将保存所有的变量。如果,我们不想保存所有的变量,仅仅是一部分。我们可以指定想要保存的变量或集合。当新建tf.train.Saver实例时,传递给它一个想要保存的变量的列表或者字典。看下面的例子:
可以保存Tensorflow Graph的指定的需要的部分。
3. 导入预训练的模型
如果你想要使用别人训练好的模型做fine-tuning,有两件事需要做:
a) 构建网络:
你可以写python代码,像写预训练的网络一样,人为地复原每一层或者每一个模块。但是,如果你想到我们已经将网络保存到.meta文件里了,就可以使用tf.train.import()函数,恢复网络结构,如下:
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
记住,import_meta_graph方法将预定义在.meta文件的网络添加到当前网络。因此,该方法构造graph结构,但我们仍需要加载预训练的参数的值。
b) 加载参数:
我们可以通过tf.train.Saver()的restore方法,恢复网络的参数:
执行完上述代码,w1和w2张量的值就被恢复了,可以通过如下代码获取:
所以,至此你已经理解了如何保存和导入Tensorflow模型的工作。下一章节,我将描述加载任意预训练模型的实际使用。
4. 使用恢复模型
既然你已经理解如何保存并恢复Tensorflow模型,让我们养成一个规范去恢复任意预训练模型,并使用它做预测,fine-tuning或者进一步训练。不管什么时候使用Tensorflow,你将定义一个Graph,包含输入,一些超参数,如learning rate, global step等。一个标准的喂入数据和超参数的方式是使用placeholders。让我们构建一个小的使用placeholders的网络,并保存它。值得指出的是。当网络被保存。placeholders的值并未保存。
现在,当我们想要恢复模型时,不仅需要恢复graph和权重,也需要准备新的feed_dict去喂新的训练数据给网络。我们可以通过graph.get_tensor_by_name()等方法得到保存的ops和placeholder变量的引用。
如果我们仅仅想要在网络上跑不同的数据,可以通过feed_dict传递新的数据给网络。
如果想要增加更多的操作(增加更多的layers)到graph里,并训练它。当然,你也可以如下:
但是,可以只恢复一部分的graph然后增加一些操作进行fine-tuning么?当然可以。利用graph.get_tensor_by_name()方法得到相应操作的引用,在顶层构建网络。这里有个实际的例子。我们加载一个预训练的VGG网络,改变输出的单元数目为2,利用新的训练数据fine-tuning。
希望这篇文章能让你清晰地理解Tensorflow模型的保存和恢复。
转载请注明来源,谢谢。