k-近邻算法是个挺好的算法,我喜欢,也希望大家喜欢。它简单小巧,如同一柄鱼肠剑,但同样锋利无比。上一篇我们解读了核心的13行代码,由于作者用了一番python的特色函数,所以写的短小精悍。我也会尝试写一个行数更多、跑的更慢,但更容易理解的,这在后文再说,到时候也会就性能等做个对比。现在,我们先来看看,除了核心代码外的一些实现。
先来看一下数据的准备,如何从文本文件里读出数据并转换成numpy数组。我们看下代码:
def file2matrix(filename):
love_dictionary = {'largeDoses':3, 'smallDoses':2, 'didntLike':1}
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines) #get the number of lines in the file
returnMat = np.zeros((numberOfLines, 3)) #prepare matrix to return
classLabelVector = [] #prepare labels return
index = 0
for line in arrayOLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
if(listFromLine[-1].isdigit()):
classLabelVector.append(int(listFromLine[-1]))
else:
classLabelVector.append(love_dictionary.get(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
如上代码就不一行行的解释了,相信大家看了也很容易懂,无外乎做了如下几步:
1)生成一个love_dictionary的字典,目的是用来解析替换心动标签为数字;
2)打开文本文件,逐行读取数据并存入arrayOLines;
3)生成一个三列的矩阵数组returnMat,用来存储从文本中读取到的样本数据,并生成一个数组classLabelVector用来存储标签;
4)遍历文本中所有的行,以tab作为分割,将样本的特征列存入returnMat,标签存入classLabelVector;
没了。代码行数多,但不代表难理解,只要是学过python文本处理读取处理、python数组处理(如数组-1代表最后一行),就基本可以了。
第二步,我们来看下,使用matplotlib观察数据:
fig = plt.figure()
ax = fig.add_subplot(111)
datingDataMat,datingLabels = kNN.file2matrix('datingTestSet.txt')
# ax.scatter(datingDataMat[:,1], datingDataMat[:,2])
ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*array(datingLabels), 15.0*array(datingLabels))
ax.axis([-2,25,-0.2,2.0])
plt.xlabel('Percentage of Time Spent Playing Video Games')
plt.ylabel('Liters of Ice Cream Consumed Per Week')
plt.show()
这一段其实也没啥好讲的,可以自己敲一敲就明白了。matplotlib是一款挺好的可视化库,只用短短的几行,就能够完成可视化过程,并且还可以生成立体透视(三维效果)。
第三步,是归一化数据:
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = np.zeros(np.shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - np.tile(minVals, (m, 1))
normDataSet = normDataSet/np.tile(ranges, (m, 1)) #element wise divide
return normDataSet, ranges, minVals
这段代码也比较简单。归一化的思路,就是将数据转换到0-1的区间。如何实现0-1的转换呢?
分母就是ranges, normDataSet = dataSet - np.tile(minVals, (m, 1))这行其实就是将minVals扩展到了整个矩阵大小,然后再拿dataSet去减,等于就是分子。最后再进行除法运算,实现了如上公式(还是不太玩得好这个矩阵扩展+运算的方式,但确实狠简练,虽然看起来会有点晕,原理是简单的)。
第四步,我们来看看测试代码:
def datingClassTest():
hoRatio = 0.50 #hold out 10%
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
if (classifierResult != datingLabels[i]): errorCount += 1.0
print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
print(errorCount)
hoRatio是设定了用来作为测试数据的数量,在这里是设置了50%。载入测试数据,并且归一化后,设置了errorCount作为计数。
之后就是跑一遍所有的测试数据,并计算错误率。这里主要就是数组的操作以及函数的调用,就不重点讲了。
第五步,进行约会的预测:
def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(input( "percentage of time spent playing video games?"))
ffMiles = float(input("frequent flier miles earned per year?"))
iceCream = float(input("liters of ice cream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = np.array([ffMiles, percentTats, iceCream, ])
classifierResult = classify0((inArr - minVals)/ranges, normMat, datingLabels, 3)
print("You will probably like this person: %s" % resultList[classifierResult - 1])
这段代码更简单了吧?无外乎就是让用户输入相应的数字,然后再去调用classify0的函数。哦对了,要先做归一化。最后再去根据预测结果的标签值,去resultList取到文字描述,O了。
上面的代码,再回顾一下,包括五段:
1)数据的准备,即从文本生成数据;
2)matplotlib,用来可视化数据;
3)归一化数据,使数据控制在0-1之间;
4)测试算法模型的有效性,看错误率是否可控;
5)将模型用于实际预测;
核心算法就是那个classify0,就像是宝剑,而核心算法外的代码很多也很重要,它们共同完成了算法的实施搭建,这些代码就像是宝剑的剑鞘,守护着算法的有效实施。
当然,如果你的剑鞘足够的“兼容”,那么就可以与宝剑自由搭配,并非一剑一鞘。这就是模块化,框架化,可组装。玩过keras等框架的都明白这种组装的便利性,就像是插接玩具一样。
所以,咱们学了这些算法后,是否能自己封装一个框架呢?这是后话了。
约会数据很好玩,那么其他数据是否k-近邻算法也很溜呢?这个咱们来验证下。正好书上有另一个应用,就是手写数字数据集的光学识别。我们在下一章实测下k-近邻算法在手写数字识别上的表现。