Tensorflow win10 c++ 运行 python训练出的模型

简介

由于生产环境使用windows、C++,而tensorflow模型训练使用python更为方便,因此存在需求:在windows环境使用tensorflow的c++接口载入训练好的tensorflow模型,并进行测试。类似的文档比较缺乏,并且由于tf本身一直在完善,相比现有的博客各个步骤都有进一步的简化,这里针对1.2.0版本梳理对应的最简单的一种流程:

  1. 利用tensorflow的python API定义、训练自己的模型
  2. 利用tensorflow的python API保存模型,并进一步将模型中的变量都转化为常量,通过这样“freeze graph”使得模型导出为一个文件,便于c++调用
  3. 编译tensorflow的源码来使用tensorflow的c++接口
  4. 在tensorflow的tutorrials Image Recognition 的基础上修改代码,利用模型进行测试。

利用tf的python API训练模型

这部分属于tensorflow的基础,官方文档getting started有相当详细的介绍和描述,在此不做赘述。值得注意的是tf的命名方式,在python代码中的变量名和在tf的graph中的变量名是两个概念,因此至少针对输入输出要定义tf的graph中的变量名,定义变量名的语法类似loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')。此外也可以利用tf.name_scope
来规划命名。

导出tf模型并freeze graph

这部分有官方工具的代码freeze_graph.py,对应的博客也很多。这里我推荐博客TensorFlow: How to freeze a model and serve it with a python API
  freeze graph就是把原本的图中的变量(卷积核、偏置)等都使用训练好的模型中的值来代替,变成常量。frozen graph的意义在于(freeze_graph.py的注释)

It's useful to do this when we need to load a single file in C++, especially in environments like mobile or embedded where we may not have access to the RestoreTensor ops and file loading calls that they rely on.

推荐的主要原因在于博客中使用方法saver = tf.train.Saver();last_chkp = saver.save(sess, 'results/graph.chkp')是最为简单的保存模型的方法,同时博客提供了freeze graph的代码,核心采用graph_util.convert_variables_to_constants 方法来进行freeze graph,使得不需要使用官方工具freeze_graph.py。对应freeze_graph的代码引用如下(其中注意到write使用参数‘wb'写为二进制):


import os, argparse

import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath(__file__))

def freeze_graph(model_folder):
    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path
    
    # We precise the file fullname of our freezed graph
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_folder + "/frozen_model.pb"

    # Before exporting our graph, we need to precise what is our output node
    # This is how TF decides what part of the Graph he has to keep and what part it can dump
    # NOTE: this variable is plural, because you can have multiple output nodes
    output_node_names = "Accuracy/predictions"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True
    
    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            input_graph_def, # The graph_def is used to retrieve the nodes 
            output_node_names.split(",") # The output node names are used to select the usefull nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_folder", type=str, help="Model folder to export")
    args = parser.parse_args()

    freeze_graph(args.model_folder)

编译源码来使用tf的c++ API

编译源码的方式官方有文档Installing TensorFlow from Sources,其中有段:

We don't officially support building TensorFlow on Windows; however, you may try to build TensorFlow on Windows if you don't mind using the highly experimental Bazel on Windows or TensorFlow CMake build.

在两种方案中,我选择采用cmake,理由是相对来说环境配置更为容易,但可能使用google自己的bazel相对支持度更高。
  参考官方readme一步一步来,值得注意的有两点,一个是git clone的时候推荐git对应的稳定版本的分支(直接master可能会有编译错误和未知bug);另一个是要用命令行进行编译,直接采用vs2015 IDE进行编译会出错C1060,原因应该是默认的编译器调用的不是native 64位的toolset,如何设置使得能够使用IDE直接编译调试的方法还没有找到。
  相比于官方的项目tf_tutorials_example_trainer.vcxproj,更有参考意义的项目是tf_label_image_example.vcxproj,对应的详尽官方教程Image Recognition,这个教程使用inception模型来进行识别,对应运行时可能需要修改图片和文件的路径才能正确输出结果。

修改代码实现自己的模型

教程源码提供了模型读取,图片读取,Label读取等核心步骤,修改对应代码进行编译能够很容易上手完成任务,下面贴一下保存图片的代码,总体是读取图片的逆向过程:

// Given an output tensor with 4d, reduce dim and output jpg image
Status SaveTensorToImageFile(const string& file_name, const Tensor* out_tensor) {
    auto root = tensorflow::Scope::NewRootScope();
    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

    auto output_image_data = tensorflow::ops::Reshape(root, *out_tensor, { 256, 256, 3 });
    auto output_image_data_cast = tensorflow::ops::Cast(root, output_image_data, tensorflow::DT_UINT8);
    auto output_image = tensorflow::ops::EncodeJpeg(root, output_image_data_cast);
    auto output_op = tensorflow::ops::WriteFile(root.WithOpName("output/image"), file_name/*"D:/tf_face/trained_model_fast/output.jpg"*/, output_image);
    string output_name = "output/image";
    // This runs the GraphDef network definition that we've just constructed, and
    // returns the results in the output tensor.
    tensorflow::GraphDef graph;
    TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_RETURN_IF_ERROR(session->Create(graph));
    Status writeResult = session->Run({}, {}, { output_name }, {});
    return writeResult;
}

代码中图片的尺寸可以自行定义,其中要注意的是c++中session->Run函数传入的参数无论是ops或是Tensor都是要使用tf定义的名字root.WithOpName("output/image")而不是c++代码中定义的局部变量output_op,以上在tf的CPU版本上流程走通。

参考链接

Tensorflow C++ API调用预训练模型和生产环境编译 (unix )
TensorFlow: How to freeze a model and serve it with a python API
TensorFlow CMake build
Tensorflow Tutorial Image Recognition

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,098评论 5 476
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,213评论 2 380
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 149,960评论 0 336
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,519评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,512评论 5 364
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,533评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,914评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,574评论 0 256
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,804评论 1 296
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,563评论 2 319
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,644评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,350评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,933评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,908评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,146评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,847评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,361评论 2 342

推荐阅读更多精彩内容