参考:《深度学习图像识别技术--基于TensorFlow Object Detection API 和 OpenVINO》
问题:假设你的生物学家,要对鸢尾(Iris)花分类。Iris有300多类,这里仅仅对Iris setosa,Iris virginica,Iris versicolor 这三类进行识别,如下图所示
本文范例程序下载地址:IrisClassifier.py
方法有很多种,比如,基于CNN的深度学习,直接学习图像。这里采用更加简单的方法,通过 sepals(花萼)和 petals(花瓣)的长度和宽度数据,进行模型训练和分类,这样更加适合初学者。
收集和构架数据集要花很多时间,幸运的是,已经有现成的Iris flower data set,which contains a set of 150 records under 5 attributes - Petal Length , Petal Width , Sepal Length , Sepal width and Class 如下图所示
基于这样的数据集(DataSet),可以让我们更加专注于学习机器学习的算法,而不需要花大量时间准备数据
第一步:下载训练数据集:
我们需要把dataset文件下载到本地,然后把它转化为Python可以使用的数据结构。范例代码如下:
打开文件:C:\Users\tf\.keras\datasets\iris_training.csv
可以看到有120行数据,跟Iris data set wiki里面说的不大一样,不过没有关系,不影响训练。
前四列是Features,分别是:Petal Length , Petal Width , Sepal Length , Sepal width
第五列是label,分别用整型数来代表花的种类,对机器来说,用整型数比用字符串更加方便,但我们要知道整型数和花种类之间的映射:
0: Iris setosa
1: Iris versicolor
2: Iris virginica
第二步:解析(Parse)数据集
下载到本地的数据集iris_training.csv 是一个 CSV格式的文本文件, TensorFlow模型还不能直接使用。我们需要把feature和label的值按照TensorFlow模型的数据输入要求,重新格式化。
创建一个函数 parse_csv
输入参数是:iris_training.csv文件的一行(line),
功能是:把 前四个 feature 值合并成为一个List,并reshape成为一个 single tensor;把最后一个 label 变量reshape成为一个single tensor.
返回值: features 和 label tensors
如下所示:
tf.decode_csv函数功能是:Convert CSV records to tensors. Each column maps to one tensor.
tf.reshape(tensor, shape,name=None)函数的功能是:Given tensor, this operation returns a tensor that has the same values as tensor with shape shape
第三步:创建训练 tf.data.Dataset
TensorFlow's Dataset API 用于feeding data into a model,它负责读取data,并将data转换为适合模型训练的格式
代码如下所示:
执行结果如下所示: