简介
KNN是很多人接触机器学习的第一个算法,我也不例外。在利用OpenCV (C++)结合KNN处理MNIST数据,遇到了很多的坑,在这里和各位分享一下心得。
完整代码在这里,喜欢的可以Star,不喜欢的可以提建议!
环境是MacOS + OpenCV4
关键步骤概览
关键步骤的代码取自于我实现的部分,这里只是阐述关键步骤和一些心得,详细地可以看我代码,比较容易看懂的!
- 获得MNIST的训练集(包含图片和数据)
bool get_train_images_with_label_from_mnist(cv::Ptr<cv::ml::TrainData> &trainData)
- 获得MNIST的测试集(包含图片和数据)
bool get_test_images_with_label_from_mnist(cv::Mat &testData, cv::Mat &testLabel)
- 创建KNN模型,并设定一些基本的参数。
Ptr<ml::KNearest> knn_model = ml::KNearest::create();
knn_model->setDefaultK(K_value); // 指明KNN的K
knn_model->setIsClassifier(true); // 指明这个KNN是用来分类的
knn_model->setAlgorithmType(cv::ml::KNearest::Types::BRUTE_FORCE);
- 训练刚刚创建的KNN模型
knn_model->train(training_set, 0); // 利用训练集训练KNN
- 用
findNearest
进行预测
knn_model->findNearest(test_set, knn_model->getDefaultK(), result_set);
注意: 这里的result_set的结果返回的是CV_32F的类型,也就是说里面的元素是32位的float
,可能会和我们之后用的标记(可能会用int32_t
来存储),所以需要static_cast
。
- 利用测试集的标记
testLabel
和result_set
的比较来计算预测准确率。(如果它们类型不一样,比如一个是float32,另一个是int32,请记得cast)
如何处理MNIST数据集
这里给出3个关键的提示
MNIST 数据集是用大端的方式存储的,用Intel处理器的PC机一般是小端存储的,需要做转换。
cv::ml::TrainData::create()
只能处理CV_32F类型的,也就是32位float
, 但是NMIST中的像素是用unsigned byte
存的。MNIST中的图片是二维的,但你需要把它转存成一维的数组以便于它被
cv::Mat
处理。
完整代码
再说一遍,完整代码位置:
https://github.com/VinStarry/CV_codes/tree/master/elementary/knn
测试结果
准确率与K的取值散点图
错误结果示例
-
预测结果:9, 实际数字:4
-
预测结果:6,实际数字:4
-
预测结果:8,实际数字:9