14:Spark Streaming源码解读之State管理之updateStateByKey和mapWithState解密

首先简单解释一下什么是state(状态)管理?我们以wordcount为例。每个batchInterval会计算当前batch的单词计数,那如果需要计算从流开始到目前为止的单词出现的次数,该如计算呢?SparkStreaming提供了两种方法:updateStateByKey和mapWithState 。mapWithState 是1.6版本新增功能,目前属于实验阶段。mapWithState具官方说性能较updateStateByKey提升10倍。那么我们来看看他们到底是如何实现的。
一、updateStateByKey 解析
1.1 updateStateByKey 的使用实例
首先看一个updateStateByKey函数使用的例子:

object UpdateStateByKeyDemo {

 def main(args: Array[String]) {

 val conf = new SparkConf().setAppName("UpdateStateByKeyDemo")

 val ssc = new StreamingContext(conf,Seconds(20))

 //要使用updateStateByKey方法,必须设置Checkpoint。

 ssc.checkpoint("/checkpoint/")

 val socketLines = ssc.socketTextStream("localhost",9999)

 

 socketLines.flatMap(_.split(",")).map(word=>(word,1))

 .updateStateByKey( 
(currValues:Seq[Int],preValue:Option[Int]) =>{

     
val currValue = currValues.sum
//将目前值相加

 Some(currValue + preValue.getOrElse(0))
//目前值的和加上历史值

 }).print()

 

 ssc.start()

 ssc.awaitTermination()

 ssc.stop()

 

 }

}

代码很简单,关键地方写了详细的注释。

1.2 updateStateByKey 方法源码分析

我们知道map返回的是MappedDStream,而MappedDStream并没有updateStateByKey方法,并且它的父类DStream中也没有该方法。但是DStream的伴生对象中有一个隐式转换函数

 implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])

 (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):

 PairDStreamFunctions[K, V] = {

 new PairDStreamFunctions[K, V](stream)

 }

PairDStreamFunction 中updateStateByKey的源码如下:
 def updateStateByKey[S: ClassTag](

 updateFunc: (Seq[V], Option[S]) => Option[S]

 ): DStream[(K, S)] = ssc.withScope {

 updateStateByKey(updateFunc, defaultPartitioner())

 }

其中updateFunc就要传入的参数,他是一个函数,
Seq[V]表示当前key对应的所有值,Option[S] 是当前key的历史状态,返回的是新的状态。

最终会调用下面的方法:

 def updateStateByKey[S: ClassTag](

 updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],

 partitioner: Partitioner,

 rememberPartitioner: Boolean

 ): DStream[(K, S)] = ssc.withScope {

 new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)

 }

在这里面new出了一个StateDStream对象。在其compute方法中,会先获取上一个batch计算出的RDD(包含了至程序开始到上一个batch单词的累计计数),然后在获取本次batch中StateDStream的父类计算出的RDD(本次batch的单词计数)分别是prevStateRDD和parentRDD,然后在调用 computeUsingPreviousRDD 方法:

 private [this] def computeUsingPreviousRDD (

 parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {

 // Define the function for the mapPartition operation on cogrouped RDD;

 // first map the cogrouped tuple to tuples of required type,

 // and then apply the update function

 val updateFuncLocal = updateFunc

 val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {

 val i = iterator.map { t =>

 val itr = t._2._2.iterator

 val headOption = if (itr.hasNext) Some(itr.next()) else None

 (t._1, t._2._1.toSeq, headOption)

 }

 updateFuncLocal(i)

 }

 val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)

 val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)

 Some(stateRDD)

 }

两个RDD进行cogroup然后应用updateStateByKey传入的函数。cogroup的性能是比较低下的。

二、mapWithState方法解析
2.1 mapWithState方法使用实例:

object StatefulNetworkWordCount {

 def main(args: Array[String]) {

 if (args.length < 2) {

 System.err.println("Usage: StatefulNetworkWordCount <hostname> <port>")

 System.exit(1)

 }

 

 StreamingExamples.setStreamingLogLevels()

 

 val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")

 // Create the context with a 1 second batch size

 val ssc = new StreamingContext(sparkConf, Seconds(1))

 ssc.checkpoint(".")

 

 // Initial state RDD for mapWithState operation

 val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

 

 // Create a ReceiverInputDStream on target ip:port and count the

 // words in input stream of \n delimited test (eg. generated by 'nc')

 val lines = ssc.socketTextStream(args(0), args(1).toInt)

 val words = lines.flatMap(_.split(" "))

 val wordDstream = words.map(x => (x, 1))

 

 // Update the cumulative count using mapWithState

 // This will give a DStream made of state (which is the cumulative count of the words)

 val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {

 val sum = one.getOrElse(0) + state.getOption.getOrElse(0)

 val output = (word, sum)

 state.update(sum)

 output

 }

 

 val stateDstream = wordDstream.mapWithState(

 StateSpec.function(mappingFunc).initialState(initialRDD))

 stateDstream.print()

 ssc.start()

 ssc.awaitTermination()

 }

}

mapWithState接收的参数是一个StateSpec对象。在StateSpec中封装了状态管理的函数
mapWithState函数中创建了MapWithStateDStreamImpl对象

 def mapWithState[StateType: ClassTag, MappedType: ClassTag](

 spec: StateSpec[K, V, StateType, MappedType]

 ): MapWithStateDStream[K, V, StateType, MappedType] = {

 new MapWithStateDStreamImpl[K, V, StateType, MappedType](

 self,

 spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]

 )

 }

MapWithStateDStreamImpl 中创建了一个InternalMapWithStateDStream类型对象internalStream,在MapWithStateDStreamImpl的compute方法中调用了internalStream的getOrCompute方法。

/** Internal implementation of the [[MapWithStateDStream]] */

private[streaming] class MapWithStateDStreamImpl[

 KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](

 dataStream: DStream[(KeyType, ValueType)],

 spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])

 extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {

 

 private val internalStream =

 new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

 

 override def slideDuration: Duration = internalStream.slideDuration

 

 override def dependencies: List[DStream[_]] = List(internalStream)

 

 override def compute(validTime: Time): Option[RDD[MappedType]] = {

 internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }

 }

InternalMapWithStateDStream中没有getOrCompute方法,这里调用的是其父类 DStream 的getOrCpmpute方法,该方法中最终会调用InternalMapWithStateDStream的Compute方法:

 /** Method that generates a RDD for the given time */

 override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {

 // Get the previous state or create a new empty state RDD

 val prevStateRDD = getOrCompute(validTime - slideDuration) match {

 case Some(rdd) =>

 if (rdd.partitioner != Some(partitioner)) {

 // If the RDD is not partitioned the right way, let us repartition it using the

 // partition index as the key. This is to ensure that state RDD is always partitioned

 // before creating another state RDD using it

 MapWithStateRDD.createFromRDD[K, V, S, E](

 rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)

 } else {

 rdd

 }

 case None =>

 MapWithStateRDD.createFromPairRDD[K, V, S, E](

 spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),

 partitioner,

 validTime

 )

 }

 

 

 // Compute the new state RDD with previous state RDD and partitioned data RDD

 // Even if there is no data RDD, use an empty one to create a new state RDD

 val dataRDD = parent.getOrCompute(validTime).getOrElse {

 context.sparkContext.emptyRDD[(K, V)]

 }

 val partitionedDataRDD = dataRDD.partitionBy(partitioner)

 val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>

 (validTime - interval).milliseconds

 }

 Some(new MapWithStateRDD(

 prevStateRDD
, partitionedDataRDD
, mappingFunction
, validTime, timeoutThresholdTime))

 }

根据给定的时间生成一个MapWithStateRDD,首先获取了先前状态的RDD:preStateRDD和当前时间的RDD:dataRDD,然后对dataRDD基于先前状态RDD的分区器进行重新分区获取partitionedDataRDD。最后将preStateRDD,partitionedDataRDD和用户定义的函数mappingFunction传给新生成的MapWithStateRDD对象返回。

下面看一下MapWithStateRDD的compute方法:

 override def compute(

 partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

 

 val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]

 val prevStateRDDIterator = prevStateRDD.iterator(

 stateRDDPartition.previousSessionRDDPartition, context)

 val dataIterator = partitionedDataRDD.iterator(

 stateRDDPartition.partitionedDataRDDPartition, context)

//prevRecord 代表一个分区的数据
 

 val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None

 val newRecord = MapWithStateRDDRecord.updateRecordWithData(

 prevRecord,

 dataIterator,

 mappingFunction,

 batchTime,

 timeoutThresholdTime,

 removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled

 )

 Iterator(newRecord)

 }

MapWithStateRDDRecord 对应MapWithStateRDD 的一个分区:

private[streaming] case class MapWithStateRDDRecord[K, S, E](

 var stateMap: StateMap[K, S], var mappedData: Seq[E])

其中stateMap存储了key的状态,mappedData存储了mapping function函数的返回值

看一下MapWithStateRDDRecord的
updateRecordWithData方法

 def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](

 prevRecord: Option[MapWithStateRDDRecord[K, S, E]],

 dataIterator: Iterator[(K, V)],

 mappingFunction: (Time, K, Option[V], State[S]) => Option[E],

 batchTime: Time,

 timeoutThresholdTime: Option[Long],

 removeTimedoutData: Boolean

 ): MapWithStateRDDRecord[K, S, E] = {

// 创建一个新的 state map 从过去的Recoord中复制 (如果存在) 否则创建一下空的StateMap对象

 val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }

 

 val mappedData = new ArrayBuffer[E]

    
//状态

 val wrappedState = new StateImpl[S]()

 

 // Call the mapping function on each record in the data iterator, and accordingly

 // update the states touched, and collect the data returned by the mapping function

 dataIterator.foreach { case (key, value) =>

    
//获取key对应的状态

 wrappedState.wrap(newStateMap.get(key))

    
//调用mappingFunction获取返回值

 val returned = mappingFunction(batchTime, key, Some(value), wrappedState)

    
//维护

newStateMap的值

 if (wrappedState.isRemoved) {

 newStateMap.remove(key)

 } else if (wrappedState.isUpdated

 || (wrappedState.exists && timeoutThresholdTime.isDefined)) {

 newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)

 }

 mappedData ++= returned

 }

 

 // Get the timed out state records, call the mapping function on each and collect the

 // data returned

 if (removeTimedoutData && timeoutThresholdTime.isDefined) {

 newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>

 wrappedState.wrapTimingOutState(state)

 val returned = mappingFunction(batchTime, key, None, wrappedState)

 mappedData ++= returned

 newStateMap.remove(key)

 }

 }

 

 MapWithStateRDDRecord(newStateMap, mappedData)

 }

最终返回MapWithStateRDDRecord
对象交个MapWithStateRDD的compute函数,MapWithStateRDD的compute函数将其封装成Iterator返回。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,457评论 5 459
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,837评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,696评论 0 319
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,183评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,057评论 4 355
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,105评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,520评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,211评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,482评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,574评论 2 309
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,353评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,213评论 3 312
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,576评论 3 298
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,897评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,174评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,489评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,683评论 2 335

推荐阅读更多精彩内容