深度学习 - Tensorflow on iOS 入门 + MNIST

前言

本文主要参考了几篇文章,搭建了一个在iOS上跑Tensorflow MNIST模型的demo,本文会给出一个可用的Demo,写出当时我遇到的问题。想要把项目跑起来,需要详细的阅读我贴出来的几篇文章,某些具体步骤我会给出链接和索引。

如果你是tensorflow新手,想要知道如何读取训练好的MNIST模型并且做预测,你会从这篇文章得到帮助并节约时间,下载demo

注意:这个demo需要Tensorflow的库以及各种环境,你可以找到这个感受一下,直接下载在iOS10上的真机就可以运行。

你需要一些Python,Tensorflow和iOS的知识。

MNIST on iOS

Reference

1. python脚本,训练MNIST+用自己的图片做输入预测结果
Using TensorFlow to create your own handwriting recognition engine
GitHub 下载脚本

2. 工程如何搭建请参考这篇
Getting started with TensorFlow on iOS

3. 在iOS里怎么load模型和读取数据
Getting Started with Deep MNIST and TensorFlow on iOS

4. 深度学习,卷积,神经网络简单的解释看这篇
机器学习原来这么有趣!第三章:图像识别【鸟or飞机】?深度学习与卷积神经网络

5. 删除iOS不能支持的node
Drop dropout from Tensorflow

跑一个Tensorflow的例子

MNIST是一个手写数字0~9的数据集,通常机器学期的入门会使用这个数据集来跑一边例子,因为数据量不大,训练的时间比较短,可以很快看到结果。

  1. 参考Getting started with TensorFlow on iOS中的Installing TensorFlow,在mac上搭建起运行tensorflow的环境。

  2. 创建一个文件夹,比如名字叫train,下载train3.py,解压好下载的MNIST数据集在MNIST_data文件夹中,在terminal中直接
    python ./train3.py
    这时Tensorflow会帮我们进行训练。

  3. 下载predict_2.py,并随便的网上找几张手写的数字0~9的图片,使用我们刚才训练的模型做预测
    python ./predict_2.py ‘number1.png’
    在predict_2.py中,我们读取了一张图片,并对这张图片做了一些处理,包括使这张图变为黑白色,缩放图片到28*28的大小(也是MNIST数据集中图片的大小),读取图片的每一个像素并按照tensorflow需要的格式做处理,然后将数据输入到模型中,获取结果。

为iOS准备Tensorflow的环境

这里请详细参考Getting started with TensorFlow on iOS中的TensorFlow on iOS小结,文章里已经说得非常详细了。步骤不复杂,但是编译iOS需要的库要一些时间,我的macbook 13' 大概跑了2个多小时。

Freezing the graph

这一节也在Getting started with TensorFlow on iOS有所提及,细节问题,我在这里说明。

如果你跟着做到了这里,那我们现在有了训练好的模型,这一步我们需要对这个模型进行处理以便它可以用在iOS上面。

模型文件

上面的截图显示了你在运行过train3.py之后会生成的模型文件。Freeze graph指的是将这些模型和训练好的网络参数合并成一个文件,方便工程上的使用。

在terminal中,进入到tensorflow文件夹,复制粘贴执行:
bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=/mnist/model/graph.pb --input_checkpoint=/mnist/model/model.ckpt \ --output_node_names=softmax \ --output_graph=/mnist/model/frozen.pb

注意这个目录:/mnist/model/是指Macintosh HD下的/mnist/model/,也就是mac硬盘的根目录下面。

这样我们就把模型和参数合并到了一起,这里拿到的模型里面,有一些操作有可能是不能直接在iOS上面运行的,所以我在train3.py中移除了一些node使得这个模型可以直接放到ios上面。

接下来我们需要用optimize_for_inference优化这个模型,获得一个final.pb,这个才是最后用在iOS上的文件:
bazel-bin/tensorflow/python/tools/optimize_for_inference --input=/mnist/model/frozen.pb --output=/mnist/model/final.pb --output_names=softmax --frozen_graph=True --input_names=x

你可以在这里下载我训练并处理好的模型文件。

The iOS App

1.创建一个新的App工程

2.修改ViewController.m为.mm,因为我们需要使用c++

3.在Build Settings中,根据你编译好的tensorflow文件夹地址修改other link flags:

/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf-lite.a 

/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf.a 

-force_load /Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/lib/
libtensorflow-core.a

4.同样修改 library search path:

-force_load 

/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf-lite.a 

/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/protobuf_ios/lib/
libprotobuf.a 

/Users/matthijs/tensorflow/tensorflow/contrib/makefile/gen/lib/
libtensorflow-core.a

注意这里有 -force_load 不然runtime要出错

5.修改Header Search Paths:

~/tensorflow
~/tensorflow/tensorflow/contrib/makefile/downloads 
~/tensorflow/tensorflow/contrib/makefile/downloads/eigen 
~/tensorflow/tensorflow/contrib/makefile/downloads/protobuf/src 
~/tensorflow/tensorflow/contrib/makefile/gen/proto

6.修改Enable Bitcode: No

7.将final.pb拖入iOS项目中,记得勾选Add to target

0.参考Getting started with TensorFlow on iOS 里面的 The iOS App,这里有不懂的对照着看一下

iOS代码部分

详细代码请去github下载我的demo,可以配好环境运行一下

加载model
- (void)viewDidLoad {
    [super viewDidLoad];
    // Do any additional setup after loading the view, typically from a nib.
    NSString *path = [[NSBundle mainBundle] pathForResource:@"final" ofType:@"pb"];
    if ([self loadGraphFromPath:path] && [self createSession]) {
        NSLog(@"load model and create session");
    }
}
-(BOOL)loadGraphFromPath:(NSString *)path {
    auto status = ReadBinaryProto(tensorflow::Env::Default(), path.fileSystemRepresentation, &graph);
    if (!status.ok()) {
        NSLog(@"Error reading graph: %s", status.error_message().c_str());
        return NO;
    }
    auto nodeCount = graph.node_size();
    NSLog(@"Node count: %d", nodeCount);
    for (auto i = 0; i < nodeCount; ++i) {
        auto node = graph.node(i);
        NSLog(@"Node %d: %s '%s'", i, node.op().c_str(), node.name().c_str());
    }
    return YES;
}
-(BOOL)createSession {
    tensorflow::SessionOptions options;
    auto status = tensorflow::NewSession(options, &session);
    if (!status.ok()) {
        NSLog(@"Error creating session: %s", status.error_message().c_str());
        return NO;
    }
    status = session->Create(graph);
    if (!status.ok()) {
        NSLog(@"Error creating session: %s", status.error_message().c_str());
        return NO;
    }
    return YES;
}
做预测
  1. 读取图片,将图片scale,读取像素做normalize
  2. 放入input
  3. 跑网络
  4. 拿到输出,获得结果
-(void)predict {
    // 1. 读取图片,将图片scale,读取像素做normalize
    UIImage *orignalImage = [UIImage imageNamed:@"9-1.png"];
    UIImage *scaledImage = [self scaleImage:orignalImage];
    UIImage *image = [self convertImageToGrayScale:scaledImage];
    UIImageView *imageView = [UIImageView new];
    imageView.frame = CGRectMake(0, 0, 50, 50);
    imageView.image = image;
    [self.view addSubview:imageView];
    tensorflow::Tensor x(tensorflow::DT_FLOAT, tensorflow::TensorShape({1,kInputLength}));
    
    NSArray *pixel = [self getRGBAsFromImage:image atX:0 andY:0 count:kInputLength];
    
    for (auto i = 0; i < kInputLength; i++) {
        UIColor *color = pixel[i];
        CGFloat red = 0.0, green = 0.0, blue = 0.0, alpha =0.0;
        [color getRed:&red green:&green blue:&blue alpha:&alpha];
        x.matrix<float>().operator()(0,i) = (255.0 - red) / 255.0f;
        NSLog(@"%f",x.matrix<float>().operator()(0,i));
    }
    // 2. 放入input
    std::vector<std::pair<tensorflow::string, tensorflow::Tensor>> inputs = {
        {"x", x}
    };
    
    std::vector<std::string> nodes = {
        {"softmax"}
    };
    
    const auto start = CACurrentMediaTime();
    
    std::vector<tensorflow::Tensor> outputs;
    // 3. 跑网络
    auto status = session->Run(inputs, nodes, {}, &outputs);
    if (!status.ok()) {
       NSLog(@"Error reading graph: %s", status.error_message().c_str());
        return;
    }
    
    NSLog(@"Time: %g seconds", CACurrentMediaTime() - start);
    // 4. 拿到输出,获得结果
    const auto outputMatrix = outputs[0].matrix<float>();
    float bestProbability = 0;
    int bestIndex = -1;
    for (auto i = 0; i < kOutputs; i++) {
        const auto probability = outputMatrix(i);
        if (probability > bestProbability) {
            bestProbability = probability;
            bestIndex = i;
        }
    }
    NSLog(@"!!!!!!!!!!! result %d",bestIndex);
}

至此,我们就成功的在iOS上用tensorflow跑起了我们训练好的模型,并做出预测了!

其他遇到的问题

当我使用create_model_2.py创建了一个模型,但是在iOS上却报这么一个错:

Invalid argument: No OpKernel was registered to support Op 'RandomUniform' with these attrs.  Registered devices: [CPU], Registered kernels:
  <no registered kernels>

     [[Node: dropout/random_uniform/RandomUniform = RandomUniform[T=DT_INT32, dtype=DT_FLOAT, seed=0, seed2=0](dropout/Shape)]]

显示我的模型里面,有iOS不能支持的node。Google了很久,发现我最开始使用的训练模型中,使用了dropout来防止训练过拟合,但是dropout中有iOS不能执行的node操作,并且freeze_graph和optimize_for_inference也不能删除iOS不支持的节点。

目前我知道的解决方式就是手动的删除model中iOS不可以支持的节点,在这里我们可以直接干掉dropout相关的节点。

具体方法参考:
Drop dropout from Tensorflow
optimize_for_inference.py should remove Dropout operations #5867

主意上面的train3.py这个脚本,这里使用的train3脚本是我对谷歌给出例子的修改,使得每训练1000个数据会自动保存一下模型,我们可以训练一会并ctrl+c取消训练,用已经保存的模型来做预测,虽然预测会不那么准。我在这个脚本中创建了graph.pb并且移除了dropout的操作,所以这里训练出来的模型不会遇到有node在iOS上不支持的问题。

最后

如果你看到这里还没有放弃,那希望你有一点点收获:P

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,670评论 5 460
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,928评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,926评论 0 320
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,238评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,112评论 4 356
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,138评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,545评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,232评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,496评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,596评论 2 310
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,369评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,226评论 3 313
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,600评论 3 299
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,906评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,185评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,516评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,721评论 2 335

推荐阅读更多精彩内容