Tensorflow-iris-模型的保存与恢复-案例教程

首先请阅读和完成
Tensorflow-iris-案例教程-零基础-机器学习
在上篇文章中我们每次运行iris.py都会重新训练和测试模型,这很不科学。能不能把训练好的模型保存起来,下次直接使用呢?

--

checkpoints和SavedModel

Tensorflow可以将训练好的模型以两种形式保存:

  1. chekpoints检查点集,依赖于创建模型的代码
  2. SavedModel已保存模型,不依赖于创建模型的代码

修改iris.py文件中创建评估器/分类器的代码,添加model_dir保存模型的目录:

#选定估算器:深层神经网络分类器
models_path=os.path.join(dir_path,'models/')
classifier = tf.estimator.DNNClassifier(
   feature_columns=feature_columns,
   hidden_units=[10, 10],
   n_classes=3,
   model_dir=models_path)

保存需要往硬盘写入文件,所以需要操作系统的管理员权限才能运行,在windows下需要右键名利提示符工具选择【以管理员权限运行】:

python desktop/iris/iris.py

在MacOS下需要加sodu运行,回车后输入系统登陆密码:

 sudo python3 ~/desktop/iris/iris.py

运行起来后稍等一下模型训练train完成,就可以在桌面iris文件夹下看到一个models文件夹,打开它看起来类似下图的一些文件:


以前我们没有设定models_dir的时候,tensorflow也会把我们训练的数据放在默认的路径文件夹里,可以print(classifier.model_dir)查看具体地址。

注意到model.ckpt-1和-1000表示在我们训练的第1步step和第1000步都进行了保存,还记得classifier.train(input_fn=lambda:train_input_fn(train_x, train_y,batch_size),steps=1000)中的steps吗?

注意,如果我们修改了classifier或train的参数(比如hidden_units,batch_size,steps等),就会导致再次运行失败。这时候你需要手工删除models文件夹。

Tensorflow默认每10分钟保存一次,最多保留最近5次,训练第一步step和最后一步step时候一定会保存。

我们可以调整代码修改这个规则,先设定新规则ckpt_config,然后添加到train方法的括号里面config=ckpt_config:

#选定估算器:深层神经网络分类器
ckpt_config= tf.estimator.RunConfig(
    save_checkpoints_secs = 60,  # 每60秒保存一次
    keep_checkpoint_max = 10,       # 保留最近的10次
)
models_path=os.path.join(dir_path,'models/')
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir=models_path,
    config=ckpt_config) #

--

恢复使用

从数据文件(data files)输入到估算器训练(estimator,train),到保存检查点集checkpoints,然后利用保存好的检查点集再进行评估evaluate或应用模型进行预测predict,整个的流程如下图所示:


model_dir设置了模型存储的路径,同时,如果已经存储了,那么这也是自动读取模型的路径。
我们把train一行注释掉,然后再运行iris.py,可以发现可以更快速的开始预测,这是因为并没有重新用数据进行训练,而是读取了models文件夹已经存储的模型。

#classifier.train(input_fn=lambda:train_input_fn(train_x, train_y,batch_size),steps=1000)

--

整理文件

当然我们可以将整个iris文件拆分成2个文件
iris.load

  1. 读取两个数据文件
  2. 提供.load载入方法
  3. 提供.train_input_fn训练数据“喂食”方法
  4. 提供.eva_input_fn评估数据“喂食”方法
import os
import pandas as pd
import tensorflow as tf

FUTURES = ['SepalLength', 'SepalWidth','PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
    
#格式化数据文件的目录地址
dir_path = os.path.dirname(os.path.realpath(__file__))
train_path=os.path.join(dir_path,'iris_training.csv')
test_path=os.path.join(dir_path,'iris_test.csv')
    
#载入数据函数
def load():    
    #载入训练数据
    train = pd.read_csv(train_path, names=FUTURES, header=0)
    train_x, train_y = train, train.pop('Species')

    #载入测试数据
    test = pd.read_csv(test_path, names=FUTURES, header=0)
    test_x, test_y = test, test.pop('Species') 
    
    return (train_x, train_y),(test_x, test_y)    
    
#针对训练的喂食函数
def train_input_fn(features, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    dataset = dataset.shuffle(1000).repeat().batch(batch_size) #每次随机调整数据顺序
    return dataset


#针对测试的喂食函数
def eval_input_fn(features, labels, batch_size):
    features=dict(features)
    inputs=(features,labels)
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    dataset = dataset.batch(batch_size)
    return dataset

iris_premade.py

  1. estamator()函数用于根据my_cfg设置生成估算分类器classifier
  2. train()函数执行训练命令,用户输入自定义四个设置如10,10,100,1000
  3. evalute()函数执行评估命令,对训练的模型进行评估
    1.predict()函数执行预测命令,对输入的四个测量数据进行评估
import os
import tensorflow as tf
import iris_load as dts
import shutil

#利用iris_load.py读取训练数据和测试数据
(train_x, train_y), (test_x, test_y) = dts.load()
    
#设定特征值的名称
feature_columns = []
for key in train_x:
    feature_columns.append(tf.feature_column.numeric_column(key=key))   

#估算器存储路径
dir_path = os.path.dirname(os.path.realpath(__file__))
models_path=os.path.join(dir_path,'models/')

#估算器存储设置选项
ckpt_config= tf.estimator.RunConfig(
    save_checkpoints_secs = 60,  #每60秒保存一次
    keep_checkpoint_max = 10,    #保留最近的10次
)

#估算器预设
my_cfg=dict() 
my_cfg['layer1'],my_cfg['layer2'],my_cfg['batch_size'],my_cfg['steps']=10,10,100,1000

#生产估算器函数:深层神经网络分类器
def estimator():
    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[my_cfg['layer1'], my_cfg['layer2']],
        n_classes=3,
        model_dir=models_path,
        config=ckpt_config)
    return classifier 

#训练模型函数
def train():
    print('Please input:layer1 nodes,layer2 nodes,batch_size,steps')
    params=input().split(',')
    if len(params)>3:
        if os.path.exists(models_path):
            print('Removing models folder...')
            shutil.rmtree(models_path) #移除models目录
            
        my_cfg['layer1'],my_cfg['layer2'],my_cfg['batch_size'],my_cfg['steps'] = map(int, params)
        
    print('Training...')
    classifier=estimator()   
    classifier.train(input_fn=lambda:dts.train_input_fn(
            train_x,
            train_y,
            my_cfg['batch_size']),
         steps=my_cfg['steps'])
    print('Train OK')         
        

#评估模型函数
def evalute():
    print('Evaluating...') 
    classifier=estimator()  
    eval_result = classifier.evaluate(
        input_fn=lambda:dts.eval_input_fn(test_x, test_y,my_cfg['batch_size']))
    print('Evaluate result:',eval_result)
    
def predict():
    print('Please enter features: SepalLength,SepalWidth,PetalLength,PetalWidth;0 for exit.')
    params=input().split(',');
    if len(params)>3:
        predict_x = {
            'SepalLength': [float(params[0])],
            'SepalWidth': [float(params[1])],
            'PetalLength': [float(params[2])],
            'PetalWidth': [float(params[3])],
        }    

        #进行预测
        classifier=estimator()        
        predictions = classifier.predict(
                input_fn=lambda:dts.eval_input_fn(predict_x,
                                                labels=[0],
                                                batch_size=my_cfg['batch_size']))

        #预测结果是数组,尽管实际我们只有一个
        for pred_dict in predictions:
            class_id = pred_dict['class_ids'][0]
            probability = pred_dict['probabilities'][class_id]
            print('Predict result:',dts.SPECIES[class_id],100 * probability)
    else:
        print('Input format error,ignored.')

#定义入口主函数
def main(args):
    while 1==1:
        print('Please enter train,evalute or predict:')
        cmd = input() #捕获用户输入的数字
        if cmd=='train':
            train()
        elif cmd=='evalute':
            evalute()
        elif cmd=='predict':
            predict()
        elif cmd=='retrain':
            retrain()
            
#运行主函数            
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main)

注意my_cfg=dict() 这个字典数据用法,它和下面一行生成类似下面这种数据结构

my_dict={
  'layer1':10,
  'layer2':10,
  'batch_size':100,
  'steps':1000
}

然后我们才能在def定义的函数中修改它并使其在estimator()方法中生效。

def定义的函数里面无法直接修改外面的数据,比如下面代码,打印出来是100而不是99

a=100
def change():
    a=99
print(a)

同样,下面的代码输出的是101,而不是11:

n=100

def estimator():
    b=n+1
    print(b)
    
def train():
    n=10

def evalute():
    estimator()
    
train() #这行并不能真正改变n
evalute()

所以,如果不使用dict字典,那么当我们在train训练时候改变参数(hidden_units,batch_size,steps)的时候,evalute和predict都不能使用到这些参数。

如果遇到问题,您也可以点击这里直接下载代码
提取密码: 83qe


探索人工智能的新边界

如果您发现文章错误,请不吝留言指正;
如果您觉得有用,请点喜欢;
如果您觉得很有用,感谢转发~


END

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,723评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,485评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,998评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,323评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,355评论 5 374
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,079评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,389评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,019评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,519评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,971评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,100评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,738评论 4 324
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,293评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,289评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,517评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,547评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,834评论 2 345

推荐阅读更多精彩内容