本文使用的数据集由 TPC-H 工具生成,并存放在 Hive 中;数据集的相关细节请参见:
https://www.jianshu.com/p/154069c0e721
ColleborativeFilter2.scala
传入参数:model保存路径 迭代次数
作用:使用数据训练模型,最后将模型保存至本地
说明:将用户购买某件物品的次数(即订单明细行数,SQL 中的 count(*))作为rating值
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.HiveContext
/**
 * Trains an ALS collaborative-filtering model from the TPC-H `orders` and
 * `lineitem` Hive tables and saves the trained model.
 *
 * Program arguments: `<modelPath> <numIterations>`.
 *
 * The rating of a (customer, part) pair is the number of line-item rows that
 * join the two, i.e. the `count(*)` of the query in [[prepareData]].
 */
object ColleborativeFilter2 {

  /** Number of latent factors passed to ALS. */
  private val Rank = 10

  /** Regularization parameter (lambda) passed to ALS. */
  private val Lambda = 0.01

  /**
   * Entry point: validates the arguments, builds the rating set, trains the
   * model and saves it to the given path.
   *
   * @param args `args(0)` is the model output path, `args(1)` the number of
   *             ALS iterations (a positive integer)
   */
  def main(args: Array[String]): Unit = {
    SetLogger

    // Fail with a usage message instead of a raw ArrayIndexOutOfBounds /
    // NumberFormatException when the arguments are missing or malformed.
    val numIterations: Int =
      if (args.length < 2) 0
      else scala.util.Try(args(1).trim.toInt).getOrElse(0)
    if (numIterations <= 0) {
      System.err.println("Usage: ColleborativeFilter2 <modelPath> <numIterations (positive integer)>")
      sys.exit(1)
    }
    val path = args(0)

    println("==========程序初始化===============")
    // setIfMissing keeps local[2] as the default for IDE runs but no longer
    // overrides a master supplied through spark-submit.
    val sparkConf = new SparkConf().setAppName("CF").setIfMissing("spark.master", "local[2]")
    val spark = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
    try {
      println("==========数据准备阶段===============")
      val ratings: RDD[Rating] = prepareData(spark)
      println("==========训练阶段===============")
      val model = ALS.train(ratings, Rank, numIterations, Lambda)
      println("==========训练完成===============")
      model.save(spark.sparkContext, path)
      println("保存到:" + path)
      ratings.unpersist()
    } finally {
      // Always release the Spark context, also when training or saving fails.
      spark.stop()
    }
  }

  /**
   * Runs the Hive query that counts line items per (customer, part) pair and
   * converts the result into ALS ratings.
   *
   * Also prints a preview of the query result and the number of ratings,
   * distinct customers and distinct parts.
   *
   * @param spark Hive-enabled session used to run the query
   * @return cached RDD of ratings (user = customer key, product = part key)
   */
  private def prepareData(spark: SparkSession): RDD[Rating] = {
    // The Hive-enabled SparkSession replaces the deprecated HiveContext.
    spark.sql("use tpch")
    // Query the data through Hive.
    val resultDf = spark
      .sql("select o.O_CUSTKEY customer,l.L_PARTKEY part,count(*) rating" +
        " from orders o,lineitem l where o.O_ORDERKEY=l.L_ORDERKEY" +
        " group by o.O_CUSTKEY,l.L_PARTKEY")
    resultDf.show()

    // count(*) is a bigint column, so read it as Long instead of round-tripping
    // through a String. Cache the result: it feeds the three counting actions
    // below and the ALS training in main, which would otherwise each re-run
    // the Hive join.
    val ratings = resultDf.rdd
      .map(row => Rating(row.getInt(0), row.getInt(1), row.getLong(2).toDouble))
      .cache()

    val numRatings = ratings.count()
    val numUsers = ratings.map(_.user).distinct().count()
    val numParts = ratings.map(_.product).distinct().count()
    println("共计:ratings: " + numRatings + " User " + numUsers + " Part " + numParts)
    ratings
  }

  /** Turns off log4j output and the Spark console progress bar. */
  def SetLogger: Unit = {
    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("com").setLevel(Level.OFF)
    System.setProperty("spark.ui.showConsoleProgress", "false")
    Logger.getRootLogger().setLevel(Level.OFF)
  }
}
输出结果:
==========程序初始化===============
==========数据准备阶段===============
+--------+------+------+
|customer| part|rating|
+--------+------+------+
| 25001|115772| 1|
| 103915|175999| 1|
| 79666| 56901| 1|
| 126154|192471| 1|
| 147884|165801| 1|
| 92054| 75664| 1|
| 40555|187715| 1|
| 22195| 14042| 1|
| 51124| 31213| 1|
| 96481|193796| 1|
| 32779| 14503| 1|
| 129082| 73486| 1|
| 134419| 97723| 1|
| 26981|116112| 1|
| 125698|109181| 1|
| 23536|148693| 1|
| 43201|129019| 1|
| 135277| 82917| 1|
| 63298| 19008| 1|
| 78565|119137| 1|
+--------+------+------+
only showing top 20 rows
共计:ratings: 6000127 User 99996 Part 200000
==========训练阶段===============
==========训练完成===============
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
保存到:file:///Users/david/IdeaProjects/ideaTest/MySpark/target/tmp/myCollaborativeFilter
TestModel.scala
传入参数: model位置 文件存储位置(当前代码未使用该参数)
作用:读取模型,进行推荐
package com.example.spark
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
/**
 * Interactive console tool that loads a saved MatrixFactorizationModel and
 * prints recommendations for user or product ids typed by the operator.
 *
 * Program arguments: `<modelPath>`. Any further argument is ignored.
 */
object TestModel {

  /** Number of recommendations printed per query. */
  private val TopN = 10

  /**
   * Entry point: loads the model from `args(0)` and starts the prompt loop.
   *
   * @param args `args(0)` is the path the model was saved to
   */
  def main(args: Array[String]): Unit = {
    SetLogger
    if (args.isEmpty) {
      System.err.println("Usage: TestModel <modelPath>")
      sys.exit(1)
    }
    println("==========模型加载阶段===============")
    val modelPath = args(0)
    // setIfMissing keeps local[2] as the default for IDE runs but no longer
    // overrides a master supplied through spark-submit.
    val conf = new SparkConf().setAppName("TM").setIfMissing("spark.master", "local[2]")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      val model = MatrixFactorizationModel.load(spark.sparkContext, modelPath)
      println("模型加载成功:path=" + modelPath)
      println("==========推荐阶段===============")
      recommend(model)
    } finally {
      // Release the Spark context when the operator leaves or loading fails.
      spark.stop()
    }
  }

  /**
   * Prompt loop: "1" recommends products for a user, "2" recommends users for
   * a product, "3" (or end of input) leaves. Any other input re-prompts.
   *
   * @param model the loaded recommendation model
   */
  def recommend(model: MatrixFactorizationModel): Unit = {
    var running = true
    while (running) {
      print("请选择要推荐类型 1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?")
      readInput() match {
        // None means stdin was closed; treat it like choosing "3".
        case None | Some("3") => running = false
        case Some("1") => readId("请输入用户id?").foreach(id => RecommendMovies(model, id))
        case Some("2") => readId("请输入产品的 id?").foreach(id => RecommendUsers(model, id))
        case Some(_) => () // unknown choice: show the menu again
      }
    }
  }

  /** Reads one trimmed line from stdin; None when the input has ended. */
  private def readInput(): Option[String] =
    Option(scala.io.StdIn.readLine()).map(_.trim)

  /**
   * Prints `prompt` and reads an integer id.
   *
   * @return the id, or None (after printing a hint) when the input is not an integer
   */
  private def readId(prompt: String): Option[Int] = {
    print(prompt)
    val parsed = readInput().flatMap(text => scala.util.Try(text.toInt).toOption)
    if (parsed.isEmpty) println("输入无效,请输入整数 id")
    parsed
  }

  /**
   * Prints the top product recommendations for one user.
   *
   * @param model       the loaded recommendation model
   * @param inputUserID id of the user to recommend products for
   */
  def RecommendMovies(model: MatrixFactorizationModel, inputUserID: Int): Unit =
    try {
      val recommended = model.recommendProducts(inputUserID, TopN)
      println("针对用户id" + inputUserID + "推荐下列产品:")
      recommended.zipWithIndex.foreach { case (r, idx) =>
        println(s"${idx + 1}.${r.product}评分:${r.rating}")
      }
    } catch {
      // Keep the interactive session alive when the lookup fails, e.g. for an
      // id the model cannot serve.
      case scala.util.control.NonFatal(e) =>
        println(s"无法为用户id${inputUserID}生成推荐:${e.getMessage}")
    }

  /**
   * Prints the top user recommendations for one product.
   *
   * @param model        the loaded recommendation model
   * @param inputMovieID id of the product to recommend users for
   */
  def RecommendUsers(model: MatrixFactorizationModel, inputMovieID: Int): Unit =
    try {
      val recommended = model.recommendUsers(inputMovieID, TopN)
      println("针对产品 id" + inputMovieID + "推荐下列用户id:")
      recommended.zipWithIndex.foreach { case (r, idx) =>
        println(s"${idx + 1}用户id:${r.user} 评分:${r.rating}")
      }
    } catch {
      // Keep the interactive session alive when the lookup fails, e.g. for an
      // id the model cannot serve.
      case scala.util.control.NonFatal(e) =>
        println(s"无法为产品 id${inputMovieID}生成推荐:${e.getMessage}")
    }

  /** Turns off log4j output and the Spark console progress bar. */
  def SetLogger: Unit = {
    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("com").setLevel(Level.OFF)
    System.setProperty("spark.ui.showConsoleProgress", "false")
    Logger.getRootLogger().setLevel(Level.OFF)
  }
}
输出结果:
==========初始化模型===============
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
请选择要推荐类型 1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?1
请输入用户id?125698
针对用户id125698推荐下列产品:
1.194564评分:3.3862003302193537
2.95318评分:3.227529363190912
3.107867评分:3.0270877690246434
4.86599评分:2.908890211972091
5.165007评分:2.8965168519326028
6.152244评分:2.8816292303536546
7.127895评分:2.832183626366389
8.37218评分:2.8070734618310933
9.43516评分:2.7800139701236577
10.162949评分:2.755918520650188
请选择要推荐类型 1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?2
请输入产品的 id?148693
针对产品 id148693推荐下列用户id:
1用户id:74882 评分:3.2278715120519137
2用户id:60653 评分:2.980748402528624
3用户id:147077 评分:2.900603769820539
4用户id:75080 评分:2.7945391669012976
5用户id:44345 评分:2.7765308146132384
6用户id:110015 评分:2.7676577792488897
7用户id:57929 评分:2.5332419522978946
8用户id:136910 评分:2.4901329135980883
9用户id:124451 评分:2.442147327035805
10用户id:109289 评分:2.360915024772536
请选择要推荐类型 1.针对用户推荐产品 2.针对产品推荐感兴趣的用户 3.离开?3