Spark -- 基于RDD实现 KNN
在上一篇《基于DataFrame实现KNN》中，由于中间使用了笛卡尔积以及大规模排序，运算性能受到较大影响。经过一定的调整，笔者找到了一个相对更好的实现方法：将训练集广播到各个节点，对每个测试样本只保留距离最近的 k 个训练样本，从而避免了笛卡尔积和全局排序。
/**
 * Predicts the label of every row in `testSet` by a k-nearest-neighbour majority vote
 * against `trainSet`.
 *
 * The training set is collected to the driver and broadcast, so each test row is scored
 * locally against all training rows; no Cartesian product or global sort is needed.
 * For each test row only the k closest training rows are kept, in a bounded max-heap.
 *
 * @param trainSet labelled training rows; every column except `cl` is parsed as a numeric
 *                 feature (non-numeric values are filtered out when parsing)
 * @param testSet  rows to classify; assumed to have the same feature columns, in the same
 *                 order, as `trainSet` (the `cl` column, if present, is dropped)
 * @param k        number of neighbours that vote; must be positive
 * @param cl       name of the label column (read as a String)
 * @return the test features followed by the predicted label in column `cl`
 * @throws IllegalArgumentException if `k <= 0` or `trainSet` is empty
 */
def runKnn(trainSet: DataFrame, testSet: DataFrame, k: Int, cl: String) = {
  require(k > 0, s"k must be positive, got $k")

  // Feature vectors of the test rows (every column except the label).
  val testFeatures: RDD[Seq[Double]] = testSet
    .drop(cl).map(row => {
      val features: Seq[Double] = row.mkString(",").split(",").map(_.toDouble)
      features
    }).rdd

  // (label, features) pairs of the training rows; the non-numeric label is filtered out
  // of the feature vector.
  val trainFeatures: RDD[(String, Seq[Double])] = trainSet.map(row => {
    val label = row.getAs[String](cl)
    val features: Seq[Double] = row.mkString(",")
      .split(",").filter(NumberUtils.isNumber(_)).map(_.toDouble)
    (label, features)
  }).rdd

  // Fail on the driver instead of inside every task when there is nothing to vote with.
  val trainRows = trainFeatures.collect()
  require(trainRows.nonEmpty, "trainSet must not be empty")
  // Broadcast the training set so every executor can score its test rows locally.
  val trainBroad = spark.sparkContext.broadcast(trainRows)

  val resRDD: RDD[Row] = testFeatures.map(testTp => {
    // Max-heap on distance (a strict, valid ordering): its head is the farthest of the
    // neighbours kept so far, so evicting it keeps the k closest in O(log k).
    val byDistance = Ordering.fromLessThan[(String, Double)](_._2 < _._2)
    val nearest = mutable.PriorityQueue.empty[(String, Double)](byDistance)
    trainBroad.value.foreach { case (label, trainTp) =>
      nearest.enqueue(label -> distance.Euclidean(testTp, trainTp))
      if (nearest.size > k) nearest.dequeue()
    }
    // Majority vote; ties go to the class whose neighbours are closer in total.
    val predicted = nearest.toSeq.groupBy(_._1)
      .minBy { case (_, members) => (-members.size, members.map(_._2).sum) }._1
    Row.fromSeq(testTp :+ predicted)
  })

  // Feature columns in their original order followed by the label column, so the schema
  // matches the rows built above even when `cl` is not the last column of `trainSet`.
  val resultSchema = trainSet.drop(cl).schema.add(trainSet.schema(cl))
  spark.createDataFrame(resRDD, resultSchema)
}
算法测试
val iris = spark.read
  .option("header", true)
  .option("inferSchema", true)
  .csv(inputFile)
// Split the iris data into a test set (30%) and a training set (70%).
val Array(testSet, trainSet) = iris.randomSplit(Array(0.3, 0.7), 1234L)
// Cache the test split so the predictions and the true labels compared below are read
// from the same materialized partitions.
testSet.cache()
val knnMode2 = new KNNRunner(spark)
val res2 = knnMode2.runKnn(trainSet, testSet, 10, "class")
res2.show(truncate = false)
// 1 when the predicted label equals the true label, else 0 (Scala == is null-safe).
val check = udf((f1: String, f2: String) => {
  if (f1 == f2) 1 else 0
})
// Pair every prediction with the true label of the row it was computed from. res2 is
// derived from testSet through partition-preserving maps only, so zip lines the rows up
// one-to-one. Joining on the four feature columns instead would multiply rows whose
// feature values are duplicated in the data and inflate the counts.
res2.select("class").rdd.map(_.getString(0))
  .zip(testSet.select("class").rdd.map(_.getString(0)))
  .toDF("class", "yclass")
  .withColumn("check", check($"class", $"yclass"))
  .groupBy("check").count().show()
+-----------+----------+-----------+----------+---------------+
|sepalLength|sepalWidth|petalLength|petalWidth|class |
+-----------+----------+-----------+----------+---------------+
|4.6 |3.2 |1.4 |0.2 |Iris-setosa |
|4.8 |3.0 |1.4 |0.1 |Iris-setosa |
|4.8 |3.4 |1.6 |0.2 |Iris-setosa |
+-----+-----+
|check|count|
+-----+-----+
| 1| 53|
| 0| 2|
+-----+-----+
从结果看，两种实现的结果是一致的，但本文的方法更高效，因为它避免了笛卡尔积和全局排序。