一、概述
数据集:数据集来自 scikit-learn 内置的红酒数据集,包括数据 'data', 目标分类 'target', 目标分类名'target_names', 数据描述'DESCR', 以及特征变量的名称'features_names',共 178 个样本,每个样本有13 个特征变量,最终被归入 3 个类别中。
二、代码逻辑
1.载入数据集
2.切分训练集和测试集(默认情况下 75%及对应标签归为训练集,25%及对应标签归为测试集)
3.用K近邻算法进行建模
4.在训练数据集上进行训练,在测试集上验证测试结果,不理想就调节参数
5.预测新样本的分类
三、源码
import numpy as np
from sklearn.datasets import load_wine #载入内置酒数据集
from sklearn.model_selection import train_test_split #数据集拆分工具
from sklearn.neighbors import KNeighborsClassifier #导入 KNN 分类器
wine_data = load_wine()
#酒数据的键值查看
print('===============')
print("红酒数据集中的键:\n{}".format(wine_data.keys()))
print('\n')
print('===============')
print('数据概况:{}'.format(wine_data['data'].shape))
print('\n')
print('===============')
print('红酒分类:{}'.format(wine_data['target']))
#拆分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(wine_data['data'], wine_data['target'], random_state=0)
#设定模型参数为 1
knn = KNeighborsClassifier(n_neighbors=1) #KNN
#数据拟合
knn.fit(X_train, y_train)
#训练数据预测
print('==========================')
print('训练数据得分:{:.2f}'.format(knn.score(X_train, y_train)))
#测试数据预测
print('==========================')
print('测试数据得分:{:.2f}'.format(knn.score(X_test, y_test)))
#假定得到一瓶新的酒,预测酒的分类
X_new = np.array([[13.2, 2.77, 2.51, 18.5, 96.6, 1.04, 2.55, 0.57, 1.47, 6.2, 1.05, 3.33, 820]])
#进行预测
prediction = knn.predict(X_new)
print('==========================')
print('预测新红酒的分类是:{}'.format(wine_data['target_names'][prediction]))
执行结果为
===============
红酒数据集中的键:
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
===============
数据概况:(178, 13)
===============
红酒分类:[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
==========================
训练数据得分:1.00
==========================
测试数据得分:0.76
==========================
预测新红酒的分类是:['class_2']
四、关于调参
可以使用网格法调参,后续补充,主要调节的是 n_neighbors 值。
但是,K近邻算法 对超大规模数据集拟合时间长、对高维数据集拟合欠佳、以及对稀疏矩阵束手无策