How Spark Implements Aggregate

Preparing the data

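import spark.implicits._ // needed for toDF outside spark-shell; assumes a SparkSession named spark
Seq((1, 1, 1, 1), (1, 2, 1, 1)).toDF("a", "b", "c", "d").createTempView("v")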

Aggregate without DISTINCT

Query: SELECT count(b) FROM v GROUP BY a

== Analyzed Logical Plan ==
count(b): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(b)#296L as string) AS count(b)#299]
      +- Aggregate [a#287], [count(b#288) AS count(b)#296L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(1) as string) AS count(b)#299]
      +- Project [_1#278 AS a#287]
         +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Physical Plan ==
CollectLimit 21
+- *(2) HashAggregate(keys=[a#287], functions=[count(1)], output=[count(b)#299])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#108]
      +- *(1) HashAggregate(keys=[a#287], functions=[partial_count(1)], output=[a#287, count#302L])
         +- *(1) Project [_1#278 AS a#287]
            +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]

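Generating the physical Aggregate

  • Convert every original groupingExpressions entry to a NamedExpression, yielding the new namedGroupingExpressions

  • Collect the AggregateExpression instances to get aggregateExpressions: count(b#288) in the analyzed plan above, which the optimizer has already rewritten to count(1) because b is non-nullable

  • Rewrite all result expressions in terms of the new namedGroupingExpressions and the AggregateExpression instances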
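
The physical plan above runs the count in two phases: a partial HashAggregate inside each partition, an Exchange that shuffles by a, and a final HashAggregate that merges the partial counts. A minimal sketch of that idea with plain Scala collections follows; it is not Spark code, and the object TwoPhaseCount and its method names are hypothetical, chosen only for illustration.

```scala
object TwoPhaseCount {
  // One row of view v: (a, b, c, d)
  type Row = (Int, Int, Int, Int)

  // Partial phase (partial_count(1)): each partition counts its own rows per key a.
  def partial(partition: Seq[Row]): Map[Int, Long] =
    partition.groupBy(_._1).map { case (a, rows) => a -> rows.size.toLong }

  // Final phase (count(1)): after shuffling by a, add up the partial counts per key.
  def merge(partials: Seq[Map[Int, Long]]): Map[Int, Long] =
    partials.flatten.groupBy(_._1).map { case (a, kvs) => a -> kvs.map(_._2).sum }

  // The whole pipeline; each inner Seq stands for one partition.
  def countByA(partitions: Seq[Seq[Row]]): Map[Int, Long] =
    merge(partitions.map(partial))
}
```

The result does not depend on how the rows are split across partitions, which is what makes the partial phase safe to run before the shuffle.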

One Distinct Aggregate

Query: SELECT count(DISTINCT b), sum(c) FROM v GROUP BY a

Plan

== Analyzed Logical Plan ==
count(DISTINCT b): string, sum(c): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(DISTINCT b)#297L as string) AS count(DISTINCT b)#303, cast(sum(c)#298L as string) AS sum(c)#304]
      +- Aggregate [a#287], [count(distinct b#288) AS count(DISTINCT b)#297L, sum(c#289, None) AS sum(c)#298L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(distinct b#288) as string) AS count(DISTINCT b)#303, cast(sum(c#289, None) as string) AS sum(c)#304]
      +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289]
         +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Physical Plan ==
CollectLimit 21
+- *(3) HashAggregate(keys=[a#287], functions=[sum(c#289, None), count(distinct b#288)], output=[count(DISTINCT b)#303, sum(c)#304])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#123]
      +- *(2) HashAggregate(keys=[a#287], functions=[merge_sum(c#289, None), partial_count(distinct b#288)], output=[a#287, sum#308L, count#311L])
         +- *(2) HashAggregate(keys=[a#287, b#288], functions=[merge_sum(c#289, None)], output=[a#287, b#288, sum#308L])
            +- Exchange hashpartitioning(a#287, b#288, 5), ENSURE_REQUIREMENTS, [id=#118]
               +- *(1) HashAggregate(keys=[a#287, b#288], functions=[partial_sum(c#289, None)], output=[a#287, b#288, sum#308L])
                  +- *(1) Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289]
                     +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]

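The four HashAggregate nodes, numbered from the bottom of the physical plan upward:

  • Agg1 and Agg2: group by the GROUP BY columns plus the DISTINCT columns, and compute the regular aggregate expressions along the way
  • Agg3 and Agg4: group by the GROUP BY columns only, and compute both the regular aggregate expressions and the distinct aggregate expressions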
partialAggregate (  // Agg1
  groupingExpressions = groupingExpressions ++ distinctExpressions,
  aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
)
   ||
   \/
partialAggregate (  // Agg2
  requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes),
  groupingExpressions = groupingAttributes ++ distinctAttributes,
  aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
)
   ||
   \/
partialAggregate (  // Agg3
  groupingExpressions = groupingAttributes,  // the global grouping attributes
  aggregateExpressions =
    functionsWithoutDistinct.map(_.copy(mode = PartialMerge))  // non-distinct functions, PartialMerge mode
      // The earlier stages deduplicated on groupingAttributes ++ distinctAttributes, which does not
      // guarantee deduplication on groupingAttributes alone, so the rewritten functionsWithDistinct
      // still keep DISTINCT; their mode is Partial.
      ++ distinctAggregateExpressions,
  aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes
)
   ||
   \/
finalAggregate (  // Agg4
  requiredChildDistributionExpressions = Some(groupingAttributes),  // shuffle by the global grouping attributes
  groupingExpressions = groupingAttributes,
  aggregateExpressions =
    functionsWithoutDistinct.map(_.copy(mode = Final))  // non-distinct functions, Final mode
      ++ distinctAggregateExpressions,  // as above, with the mode changed to Final
  aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes
)
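
To make the four stages concrete, here is a minimal sketch of the same data flow for count(DISTINCT b), sum(c) grouped by a, written with plain Scala collections rather than Spark operators. The object OneDistinctAggregate and its method names are hypothetical; Agg3 and Agg4 are folded into one function because, without a real shuffle, the partial and final steps collapse into a single pass.

```scala
object OneDistinctAggregate {
  // One row of view v: (a, b, c, d)
  type Row = (Int, Int, Int, Int)

  // Agg1: inside each partition, group by (a, b) and compute partial_sum(c).
  def agg1(partition: Seq[Row]): Map[(Int, Int), Long] =
    partition.groupBy(r => (r._1, r._2)).map { case (k, rs) => k -> rs.map(_._3.toLong).sum }

  // Agg2: after a shuffle by (a, b), merge the partial sums.
  // Every (a, b) pair now appears exactly once.
  def agg2(partials: Seq[Map[(Int, Int), Long]]): Map[(Int, Int), Long] =
    partials.flatten.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }

  // Agg3 + Agg4: group by a only; count the distinct b values and merge the sums.
  // Returns a -> (count(DISTINCT b), sum(c)).
  def agg3And4(dedup: Map[(Int, Int), Long]): Map[Int, (Long, Long)] =
    dedup.toSeq.groupBy(_._1._1).map { case (a, kvs) =>
      a -> (kvs.map(_._1._2).distinct.size.toLong, kvs.map(_._2).sum)
    }

  // The whole pipeline; each inner Seq stands for one partition.
  def run(partitions: Seq[Seq[Row]]): Map[Int, (Long, Long)] =
    agg3And4(agg2(partitions.map(agg1)))
}
```

Carrying the partial sum of c through the (a, b) grouping is what lets a single pass over the data serve both the regular and the distinct aggregate.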

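Multiple DISTINCT aggregates on the same expression

Query: SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i

Here v is a different view from the one prepared above: as the plan shows, it has three columns, the first two named i and j.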

== Optimized Logical Plan ==
Aggregate [i#284], [sum(distinct j#285, None) AS sum(DISTINCT j)#292, max(j#285) AS max(DISTINCT j)#293]
+- Project [_1#277 AS i#284, _2#278 AS j#285]
   +- LocalRelation [_1#277, _2#278, _3#279]

== Physical Plan ==
*(3) HashAggregate(keys=[i#284], functions=[max(j#285), sum(distinct j#296, None)], output=[sum(DISTINCT j)#292, max(DISTINCT j)#293])
+- Exchange hashpartitioning(i#284, 5), ENSURE_REQUIREMENTS, [id=#119]
   +- *(2) HashAggregate(keys=[i#284], functions=[merge_max(j#285), partial_sum(distinct j#296, None)], output=[i#284, max#298, sum#301])
      +- *(2) HashAggregate(keys=[i#284, j#296], functions=[merge_max(j#285)], output=[i#284, j#296, max#298])
         +- Exchange hashpartitioning(i#284, j#296, 5), ENSURE_REQUIREMENTS, [id=#114]
            +- *(1) HashAggregate(keys=[i#284, knownfloatingpointnormalized(normalizenanandzero(j#285)) AS j#296], functions=[partial_max(j#285)], output=[i#284, j#296, max#298])
               +- *(1) Project [_1#277 AS i#284, _2#278 AS j#285]
                  +- *(1) LocalTableScan [_1#277, _2#278, _3#279]

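The first pair of Aggregates effectively groups by (i, j). After that, each j within a given i is a distinct value, exposed under the new attribute j#296.

The second pair of Aggregates groups by i and computes max(j) and sum(distinct j). Note that the optimized plan has already dropped DISTINCT from max, since duplicates cannot change a maximum, so max is planned as a regular aggregate.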
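
The following sketch mirrors those two pairs of Aggregates with plain Scala collections. It is an illustration only: the object DistinctSumAndMax is hypothetical, and j is modeled as an Int for simplicity, whereas the floating-point normalization in the plan above suggests the real column is a floating-point type.

```scala
object DistinctSumAndMax {
  // One row of the view, reduced to the two columns the query uses: (i, j)
  type Row = (Int, Int)

  // Returns i -> (sum(DISTINCT j), max(j)).
  def run(rows: Seq[Row]): Map[Int, (Long, Int)] = {
    // First pair of Aggregates: group by (i, j), carrying max(j) as a regular aggregate.
    val byIJ: Map[(Int, Int), Int] =
      rows.groupBy(identity).map { case (k, rs) => k -> rs.map(_._2).max }
    // Second pair: group by i. Each j is now unique within its i,
    // so a plain sum over the grouping key equals sum(DISTINCT j).
    byIJ.toSeq.groupBy(_._1._1).map { case (i, kvs) =>
      i -> (kvs.map(_._1._2.toLong).sum, kvs.map(_._2).max)
    }
  }
}
```

For the rows (1, 2), (1, 2), (1, 3) the duplicate 2 is summed once, while the maximum is unaffected by the deduplication.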

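Multiple DISTINCT aggregates on different expressions

When generating the physical plan, Spark can handle only one distinct group. If a query has more than one, the RewriteDistinctAggregates rule rewrites the DISTINCT aggregates into an Expand, after which the query is planned like an Aggregate without DISTINCT.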

Query: SELECT count(DISTINCT b, c), count(DISTINCT c, d) FROM v GROUP BY a

How RewriteDistinctAggregates works:

  • Find every distinct group and the expressions that belong to it:

    • Group Set(b#288, c#289) maps to List(count(distinct b#288, c#289))
    • Group Set(c#289, d#290) maps to List(count(distinct c#289, d#290))
  • Extract all key columns of the distinct groups and remap them: distinctAggChildren = (b, c, d). From here on, deduplicating on groupByAttrs + these columns + gid is enough to guarantee that the data is distinct

  • Process the distinct groups from the first step. Note that the newly generated expressions read their inputs from the projection:

    • Distinct group count(distinct b#288, c#289) gets projection = (b, c, null, gid(1)), and the new expression is count(if ((gid#307 = 1)) v.b#308 else null, if ((gid#307 = 1)) v.c#309 else null)
    • Distinct group count(distinct c#289, d#290) gets projection = (null, c, d, gid(2)), and the new expression is count(if ((gid#307 = 2)) v.c#309 else null, if ((gid#307 = 2)) v.d#310 else null)
  • Process the regularAggExpression entries

  • Build the new plan: Expand -> firstAggregate -> Aggregate
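
The steps above can be sketched with plain Scala collections for the two distinct groups (b, c) and (c, d). This is an illustration of the rewrite's data flow, not Spark's implementation; the object ExpandRewrite and its members are hypothetical, and Option stands in for SQL NULL.

```scala
object ExpandRewrite {
  // One row of view v: (a, b, c, d)
  type Row = (Int, Int, Int, Int)
  // One expanded row: (a, b, c, d, gid); None plays the role of SQL NULL.
  type Expanded = (Int, Option[Int], Option[Int], Option[Int], Int)

  // Expand: emit one projection per distinct group, tagged with that group's gid.
  def expand(r: Row): Seq[Expanded] = Seq(
    (r._1, Some(r._2), Some(r._3), None, 1), // group 1: (b, c, null, gid = 1)
    (r._1, None, Some(r._3), Some(r._4), 2)  // group 2: (null, c, d, gid = 2)
  )

  // Returns a -> (count(DISTINCT b, c), count(DISTINCT c, d)).
  def run(rows: Seq[Row]): Map[Int, (Long, Long)] = {
    // First Aggregate: group by (a, b, c, d, gid) with no functions, i.e. deduplicate.
    val dedup = rows.flatMap(expand).distinct
    // Second Aggregate: group by a. count(if (gid = n) x else null, ...) only counts
    // rows of group n, because count skips rows where any argument is NULL.
    dedup.groupBy(_._1).map { case (a, es) =>
      val bc = es.count(e => e._5 == 1 && e._2.isDefined && e._3.isDefined)
      val cd = es.count(e => e._5 == 2 && e._3.isDefined && e._4.isDefined)
      a -> (bc.toLong, cd.toLong)
    }
  }
}
```

For the two sample rows, the gid = 1 projections stay distinct while the two gid = 2 projections collapse into one, which is why the deduplicating first Aggregate makes a plain count sufficient in the second.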

DAG

== Analyzed Logical Plan ==
count(DISTINCT b, c): string, count(DISTINCT c, d): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(DISTINCT b, c)#297L as string) AS count(DISTINCT b, c)#303, cast(count(DISTINCT c, d)#298L as string) AS count(DISTINCT c, d)#304]
      +- Aggregate [a#287], [count(distinct b#288, c#289) AS count(DISTINCT b, c)#297L, count(distinct c#289, d#290) AS count(DISTINCT c, d)#298L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null) as string) AS count(DISTINCT b, c)#303, cast(count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null) as string) AS count(DISTINCT c, d)#304]
      +- Aggregate [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
         +- Expand [ArrayBuffer(a#287, b#288, c#289, null, 1), ArrayBuffer(a#287, null, c#289, d#290, 2)], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Physical Plan ==
CollectLimit 21
+- *(3) HashAggregate(keys=[a#287], functions=[count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null), count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null)], output=[count(DISTINCT b, c)#303, count(DISTINCT c, d)#304])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#128]
      +- *(2) HashAggregate(keys=[a#287], functions=[partial_count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null), partial_count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null)], output=[a#287, count#313L, count#314L])
         +- *(2) HashAggregate(keys=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], functions=[], output=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307])
            +- Exchange hashpartitioning(a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307, 5), ENSURE_REQUIREMENTS, [id=#123]
               +- *(1) HashAggregate(keys=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], functions=[], output=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307])
                  +- *(1) Expand [ArrayBuffer(a#287, b#288, c#289, null, 1), ArrayBuffer(a#287, null, c#289, d#290, 2)], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
                     +- *(1) Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
                        +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]

