本机环境
centos6.8 + python2.7 + tensorflow0.11
下载tensorflow源码
从github上拉去代码并切换到0.11版本:
git clone https://github.com/tensorflow/tensorflow
git checkout r0.11
google-Inception模型示例
执行如下命令,利用google的inception模型识别图片space_shuttle.jpg
cd tensorflow/models/images/imagenet/
python classify_image.py --image_file /home/xiabing/TensorFlow_pics/space_shuttle.jpg
可以看到识别结果如下:
分析classify_image.py
下面看看classify_image.py的源码
classify_image.py会首先下载分类器模型:
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
下载后会放到本地/tmp/imagenet/路径下:
训练自己的分类模型
使用tensorflow中examples中的image_retraining来retraining谷歌的inception模型
准备图片数据
准备要训练的每个分类,需要有个对应的文件夹(因为每个子文件夹内的各个图片的label标签就是取分类文件夹名的)类似以下这种:
fruit/banna/
fruit/apple/
每个分类内的数据格式没有规定,本例如下:
使用retraining.py训练
调用如下命令开始训练,参数详解参见retrain.py文件:
python /home/xiabing/TensorFlow/tensorflow/tensorflow/examples/image_retraining/retrain.py --bottleneck_dir /home/xiabing/sd_classify_pics/bottleneck --how_many_training_steps 4000 --model_dir /home/xiabing/sd_classify_pics/model --output_graph /home/xiabing/sd_classify_pics/output_graph.pb --output_labels /home/xiabing/sd_classify_pics/output_labels.txt --image_dir /home/xiabing/TensorFlow_pics/fruit/
首次调用会出现如下错误:
ImportError: cannot import name graph_util
解决办法:
修改retrain.py,把
from tensorflow.python.framework import graph_util
替换为
from tensorflow.python.client import graph_util
再重新执行上面命令,看到如下打印表示训练完成:
训练结果
训练完成后,会在当前目录下生成下面两个文件。查看标签文件,会看到banana和apple。
使用训练好的模型
在训练结果路径下新建test.py文件,加入如下代码:
import tensorflow as tf
import sys
image_file = sys.argv[1]
#print(image_file)
image = tf.gfile.FastGFile(image_file, 'rb').read()
labels = []
for label in tf.gfile.GFile("output_labels.txt"):
labels.append(label.rstrip())
with tf.gfile.FastGFile("output_graph.pb", 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})
top = predict[0].argsort()[-len(predict[0]):][::-1]
for index in top:
human_string = labels[index]
score = predict[0][index]
print(human_string, score)
测试训练好的模型:
python /home/xiabing/sd_classify_pics/test.py /home/xiabing/TensorFlow_pics/1510114397170.jpg
原始图片:
测试结果: