实用价值
也叫KNN算法,它能根据已有数据推测将来可能发生的事情。比如说个性化推荐,股票预测和机器学习。
思路
KNN把要处理的事物抽象出许多属性特征,并且把每个属性特征看成是一个维度,这种抽象和分组的过程就是KNN算法的第一个步骤。
而与某目标的接近程度就用它与每个事物在这些维度中的距离来衡量,距离越短越接近。
当然,此处的目标事物的数据都是以往的,除了目标事物以外的其他事物的数据可能是新的也可能是旧的,而现在要做的预测它未来可能是什么样子。
现在选取距离目标事物最近的若干个事物,因为这些事物和目标事物非常接近,所以可以用来预测。
把这些事物的各个属性值求平均就是理论上的目标事物的未来可能了。
其实也是很简单的。
用法
package com.company;
import java.util.Iterator;
import java.util.Map;
public class Main {
public static void main(String[] args) {
// write your code here
Entity targetEntity = new Entity("猪");
targetEntity.characterMap.put("喜剧片",5);
targetEntity.characterMap.put("动作片",4);
targetEntity.characterMap.put("生活片",4);
targetEntity.characterMap.put("恐怖片",5);
targetEntity.characterMap.put("爱情片",3);
Entity neighourEntity0 = new Entity("兔子");
neighourEntity0.characterMap.put("喜剧片",3);
neighourEntity0.characterMap.put("动作片",4);
neighourEntity0.characterMap.put("生活片",4);
neighourEntity0.characterMap.put("恐怖片",1);
neighourEntity0.characterMap.put("爱情片",4);
Entity neighourEntity1 = new Entity("狗");
neighourEntity1.characterMap.put("喜剧片",4);
neighourEntity1.characterMap.put("动作片",3);
neighourEntity1.characterMap.put("生活片",5);
neighourEntity1.characterMap.put("恐怖片",1);
neighourEntity1.characterMap.put("爱情片",5);
Entity neighourEntity2 = new Entity("猫");
neighourEntity2.characterMap.put("喜剧片",2);
neighourEntity2.characterMap.put("动作片",5);
neighourEntity2.characterMap.put("生活片",1);
neighourEntity2.characterMap.put("恐怖片",3);
neighourEntity2.characterMap.put("爱情片",1);
Entity[] neighourEntities = {
neighourEntity0,
neighourEntity1,
neighourEntity2
};
Entity resultEntity = KNN.KNearesNeighours(targetEntity,neighourEntities);
Iterator iterator = resultEntity.characterMap.entrySet().iterator();
System.out.println("预测" + targetEntity.getName() + "的结果是:");
while (iterator.hasNext()) {
Map.Entry entry = (Map.Entry) iterator.next();
String propertyKey = (String) entry.getKey();
Integer propertyValue = (Integer) entry.getValue();
System.out.println(propertyKey + ":" + propertyValue);
}
}
}
输出
兔子与猪距离的平方为:21.0
狗与猪距离的平方为:23.0
猫与猪距离的平方为:27.0
兔子与猪最接近
预测猪的结果是:
动作片:5
恐怖片:2
喜剧片:4
爱情片:4
生活片:4
Process finished with exit code 0
实现
package com.company;
import java.util.Iterator;
import java.util.Map;
public class KNN {
/**
* 这就是KNN算法的完全体
* 就两步
* 1:分类,就是把各种属性特征剥离出来,分组。
* 2:回归,就是选取一定数量的最近邻,然后求平均数。
* @param targetEntity
* @param neighourEntities
* @return
*/
static public Entity KNearesNeighours(Entity targetEntity,Entity[] neighourEntities) {
Entity[] sortedEntityArrays = KNN.getMostLikeArray(targetEntity,neighourEntities);
return KNN.gainRegressionResult(sortedEntityArrays,3);
}
/**
* KNN算法把每个特征都看成是一个维度
* 所谓维度呢,就是一维是一条数轴,二维就是两条数轴,三维就是三条数轴。
* 所谓相似呢?就是在所有维度之内计算每个邻居实体与目标实体之间的距离。
* 然后找出其中距离最短者,被视为最相似。
* @param targetEntity
* @param neighourEntities
* @return
*/
static private Entity[] getMostLikeArray(Entity targetEntity,Entity[] neighourEntities) {
Entity[] sortedEntities = new Entity[neighourEntities.length];
int sortedEntityPointer = -1;
//此处为距离的平方,因为不计算开发不影响比较。
for (Entity element:neighourEntities) {
if (element == null)continue;
System.out.print(element.getName() + "与" + targetEntity.getName() + "距离的平方为:");
double distanceSquare = 0;
Iterator iterator = targetEntity.characterMap.entrySet().iterator();
//判断特征是否一致的标志位。因为特征不一致根本就没法比较,不具备可比性。
boolean areSameCharacters = true;
while (iterator.hasNext()) {
Map.Entry entry = (Map.Entry) iterator.next();
Integer targetValue = (Integer) entry.getValue();
String targetKey = (String) entry.getKey();
if (!element.characterMap.containsKey(targetKey)) {
areSameCharacters = false;
break;
}
Integer neighourValue = element.characterMap.get(targetKey);
//这里采用的是距离公式,而不是余弦公式。前者计算的是距离,后者计算的是角度。
distanceSquare += Math.pow(targetValue - neighourValue,2);
}
if (areSameCharacters) {
System.out.println(distanceSquare);
try {
Entity copyedElement = (Entity) element.clone();
copyedElement.setDistance(distanceSquare);
sortedEntities[++sortedEntityPointer] = copyedElement;
} catch (CloneNotSupportedException e) {
e.printStackTrace();
}
}
}
if (sortedEntities.length > 0)
System.out.println(sortedEntities[0].getName() + "与" + targetEntity.getName() + "最接近");
//现在用堆排序对邻居进行排序
HeapSort.heapSort0(sortedEntities,sortedEntityPointer + 1);
return sortedEntities;
}
/**
* 这个是KNN算法中的回归步骤,其实就是找到最近邻的n个邻居求其每个特征的平均数
* @param sortedNeighours
* @param amount
* @return
*/
static private Entity gainRegressionResult(Entity[] sortedNeighours,int amount) {
if (amount < 1 || sortedNeighours.length == 0)return null;
int counter = 0;
Entity resultEntity = new Entity("RESULT");
Iterator iterator = sortedNeighours[0].characterMap.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry entry = (Map.Entry) iterator.next();
Integer targetValue = (Integer) entry.getValue();
String targetKey = (String) entry.getKey();
resultEntity.characterMap.put(targetKey,targetValue);
}
for (;counter < amount && counter < sortedNeighours.length && sortedNeighours[counter] != null;counter++) {
Iterator iterator0 = sortedNeighours[counter].characterMap.entrySet().iterator();
while (iterator0.hasNext()) {
Map.Entry entry = (Map.Entry) iterator0.next();
Integer targetValue = (Integer) entry.getValue();
String targetKey = (String) entry.getKey();
Integer newNeighourValue = resultEntity.characterMap.get(targetKey) + targetValue;
resultEntity.characterMap.put(targetKey,newNeighourValue);
}
}
//求平均
Iterator iterator1 = resultEntity.characterMap.entrySet().iterator();
while (iterator1.hasNext()) {
Map.Entry entry = (Map.Entry) iterator1.next();
Integer targetValue = (Integer) entry.getValue();
String targetKey = (String) entry.getKey();
resultEntity.characterMap.put(targetKey,targetValue / counter);
}
return resultEntity;
}
}
堆排序
package com.company;
public class HeapSort {
/**
* 传说中的堆排序,用于大量数据,
* 本次采用小根堆排序。
* 其原理,你在网上搜一下堆排序图
* 解就能一目了然。
* 现在我想要得到一个递减的有序序
* 列,就应该使用小根堆来做。
* 个人感觉堆排序有点麻烦
* 堆是个完全二叉树,所以可以应用
* 定理——非叶子结点序号乘以2+1是其
* 左孩子结点的序号。
* 反之,数组长度除以2-1就是父结点
* 的序号。
* @param sourceArray
*/
static public void heapSort0(Entity[] sourceArray,int length) {
//一开始的时候由数组构成的二叉树是完全二叉树,
// 但是还称不上是小根堆,需要先进行调整才行。
// 是从最后一个非叶子结点往上开始的。
for (int counter = length / 2 - 1;counter > -1;counter--)
HeapSort.adjustToSmallHeap(sourceArray,counter,length);
//如此一来完全二叉树变成了有序的小根堆,
// 堆顶的结点就是整个数组最小的值。
// counter>0就可以了,因为此时待
// 调整的元素只有2个,它俩调整完了就不需要调整了。
for (int counter = length - 1;counter > 0;counter--) {
//把最小值放到数组最后面。
Entity tempElement = sourceArray[counter];
sourceArray[counter] = sourceArray[0];
sourceArray[0] = tempElement;
HeapSort.adjustToSmallHeap(sourceArray,0,counter);
}
}
/**
* 不过本步骤有一个前期,
* 那就是它会认为本次调整
* 的结点下面的结点全都是
* 已经调整好的,不然得出
* 的结果是错误的。所以很
* 显然第一次调整的时候的
* rootIndex必须是最后一
* 个非叶子结点才行。
* 本方法是要把一个非小
* 根堆调整成小根堆的方法
* 也是堆排序的核心所在
* @param sourceArray
* @param rootIndex 需要调整的堆的顶部结点的index
* @param adjustLength 需要调整的长度
*/
static private void adjustToSmallHeap(Entity[] sourceArray,int rootIndex,int adjustLength) {
//应用定理父结点的序号等于
// 该结点所在长度除以2再减
// 一,遍历只能是非叶子结
// 点,所以边界值设置为adjustLength / 2。
// 这样可以减少遍历的次数,也算是精确控制吧
for (int counter = rootIndex;counter < adjustLength / 2;) {
//2*rootIndex必然是它的左结点的序号
// 并且是存在的,而2*rootIndex+1却
// 未必存在。
//因为adjustLength/2是非叶子结点的
// 最大范围加一,所以其范围内的非叶子
// 结点必然有叶子结点存在。因为2*rootIndex+1
// 和2*rootIndex求余的结果是一样的,
// 所以最后一个非叶子结点不一定有右孩
// 子结点,因此如下所示。
int leftChildIndex = 2 * counter + 1;
int rightChildIndex = leftChildIndex + 1;
int smallerPointer = leftChildIndex;
//如果右孩子结点的值小于左孩子结点
// ,那说明右孩子结点应该优先被交
// 换,前提是存在可能的话。
if (rightChildIndex < adjustLength && sourceArray[leftChildIndex].getDistance() < sourceArray[rightChildIndex].getDistance())
smallerPointer = rightChildIndex;
//因为是小根堆,所以只要父结点比
// 较小的孩子结点大就应该交换。
if (sourceArray[counter].getDistance() < sourceArray[smallerPointer].getDistance()) {
Entity tempElement = sourceArray[counter];
sourceArray[counter] = sourceArray[smallerPointer];
sourceArray[smallerPointer] = tempElement;
//此时应该把循环指针指向较小的子结点的位置
counter = smallerPointer;
//一旦不满足if中的条件,就说明不再会发生
// 交换了,因为本次调整是在已经被调整好的
// 前提下的。这样也可以省去很多不必要的遍历。
} else break;
}
}
}
事物
package com.company;
import java.util.HashMap;
import java.util.Map;
public class Entity implements Cloneable {
//存储特征的map。
public Map<String,Integer> characterMap = new HashMap<>();
private String name;
private double distance = 0;
public Entity(String name) {
this.name = name;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
@Override
public Object clone() throws CloneNotSupportedException {
Object object = null;
try {
object = super.clone();
} catch (Exception e) {
e.printStackTrace();
}
return super.clone();
}
}