一、Distinct aggregation 算法
包含 distinct 关键字的 aggregation 由 4 个物理执行步骤组成。我们使用以下 query 来介绍:
val dataset = Seq(
(1, "a"), (1, "a"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset.groupBy($"nr").agg(functions.countDistinct("letter")).explain(true)
① partial aggregation 步骤
第一步是创建一个 partial aggregate,此 partial aggregate 的 grouping key 将不仅包括 query 中定义的 grouping key(nr),还包含 distinct 的列(letter),效果如 group by nr、letter
,执行计划如下:
HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
② partial merge aggregation 步骤
这一步将通过 shuffle 将具有相同 grouping key(此处为 nr、letter)的数据划分为同一分区:
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
③ partial aggregation for distinct 步骤
第三步,Spark 最终开始执行聚合,执行的是 partial aggregate:
+- HashAggregate(keys=[nr#5], functions=[partial_count(distinct letter#6)], output=[nr#5, count#18L])
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
④ final aggregation 步骤
第四步,partial aggregate(第三步)的结果将合并到最终结果中,并进行返回。它涉及 shuffle:
HashAggregate(keys=[nr#5], functions=[count(distinct letter#6)], output=[nr#5, count(DISTINCT letter)#12L])
+- Exchange hashpartitioning(nr#5, 200)
+- HashAggregate(keys=[nr#5], functions=[partial_count(distinct letter#6)], output=[nr#5, count#18L])
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- Exchange hashpartitioning(nr#5, letter#6, 200)
+- HashAggregate(keys=[nr#5, letter#6], functions=[], output=[nr#5, letter#6])
+- LocalTableScan [nr#5, letter#6]
我们用下面的这张图来总结上述几个步骤:
二、无 Distinct aggregation 算法
无 Distinct aggregation 会简单一些,仅包含两个步骤,我们通过下面的例子来说明:
val dataset = Seq(
(1, "a"), (1, "a"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset.groupBy($"nr").count().explain(true)
①、partial aggregations 步骤
第一步即进行局部聚合:
HashAggregate(keys=[nr#5], functions=[partial_count(1)], output=[nr#5, count#17L])
+- PlanLater LocalRelation [nr#5]
②、final aggregation 步骤
第二步,毫无疑问,对部分结果进行了最终汇总:
HashAggregate(keys=[nr#5], functions=[count(1)], output=[nr#5, count#12L])
+- HashAggregate(keys=[nr#5], functions=[partial_count(1)], output=[nr#5, count#17L])
+- PlanLater LocalRelation [nr#5]
三、Hash-based 和 Sort-based aggregation
上述两种模式都会调用到 createAggregate
方法,该方法为以下 3 种策略创建物理执行计划:
- hash-based
- object-hash-based
- sort-based
这 3 中策略有一些共性。一个 Spark Sql aggregation 主要由两部分组成:
- 一个 agg buffer(聚合缓冲区:包含 grouping keys 和 agg value)
- 一个 agg state(聚合状态:仅 agg value)
每次调用 GROUP BY key
并对其使用一些聚合时,框架都会创建一个聚合缓冲区,保留给定的聚合(GROUP BY key)。指定 key(COUNT,SUM等)所涉及的聚合都在此聚合缓冲区存储其部分(partial)或最终聚合结果,称为聚合状态。该状态的存储格式取决于聚合:
- 对于 AVG,它将是2个值,一个是出现次数,另一个是值的总和
- 对于 MIN,它将是到目前为止所看到的最小值
依此类推
hash-based
策略使用可变的、原始的、固定 size 的类型来作为 agg state,包括:
- NullType
- BooleanType
- ByteType
- ShortType
- IntegerType
- LongType
- FloatType
- DoubleType
- DateType
- TimestampType
这里的可变能力非常重要,因为 Spark 会直接修改该值(如对于 count 来说,遇到新的 row,就会把 count 的值(agg state)加上 1)。
对于 agg state 的值是其他类型的情况,使用 object-hash-based
策略,该策略自 2.2.0 版本引入,目的是为了解决 hash-based
策略的局限性(必须使用可变的、原始的、固定 size 的类型来作为 agg state)。在 2.2.0 之前,针对 HashAggregateExec 不支持的其他类型执行的聚合都会转换为 sort-based
的策略。大部分情况下,sort-based
的性能会比 hash-based
的差,因为在聚合前会进行额外的排序。通过参数 spark.sql.execution.useObjectHashAggregateExec
来控制是否使用 object-hash-based
聚合,默认为 true。我们通过下面的例子来理解 sort-based
和 object-hash-based
的区别:
查询
val dataset2 = Seq(
(1, "a"), (1, "aa"), (1, "a"), (2, "b"), (2, "b"), (3, "c"), (3, "c")
).toDF("nr", "letter")
dataset2.groupBy("nr").agg(functions.collect_list("letter").as("collected_letters")).explain(true)
如你所见,上图两个物理执行计划均只进行一次 shuffle,但 sort-based
聚合相对于 object-hash-based
额外多了两次排序,带来性能开销。
另一个值得关注的点是,hash-based
和 object-hash-based
运行过程中如果内存不够用,会切换成 sort-based
聚合。对于 object-hash-based
聚合,通过参数 spark.sql.objectHashAggregate.sortBased.fallbackThreshold
控内存中(一种 hashMap)最多持有多少个 agg buffer(一个 grouping key 的组合一个),若超过该值,则切换为 sort-based
agg,该配置默认值为 128。如果切换为 sort-based
agg,会打印如下日志:
ObjectAggregationIterator: Aggregation hash map reaches threshold capacity (128 entries), spilling and falling back to sort based aggregation. You may change the threshold by adjust option spark.sql.objectHashAggregate.sortBased.fallbackThreshold
对于 hash-based
,该值为 Integer.MaxValue