Spark WholeStageCodeGen Source Code Study Notes

Environment and Versions

  • OS: CentOS 7
  • JDK: 1.8
  • Spark: 2.1.0
  • Scala: 2.11
  • IDE: IntelliJ IDEA 14.1.4
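
For reference, a minimal sbt dependency sketch matching this environment (the coordinates are the standard Maven ones; the exact Scala patch version is an assumption):

// build.sbt (sketch)
scalaVersion := "2.11.8"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "2.1.0" % "provided",
  "org.apache.spark" %% "spark-sql"  % "2.1.0" % "provided"
)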

Introduction to WholeStageCodeGen

Spark 2.0 ships with the second-generation Tungsten engine. In our tests it shows a clear performance improvement over Spark 1.6, and one of the key features behind that is WholeStageCodeGen. The Databricks blog has a detailed article on this feature:
https://databricks.com/blog/2016/05/23/apache-spark-as-a-compiler-joining-a-billion-rows-per-second-on-a-laptop.html
In short, with WholeStageCodeGen the operators of a query pipeline are treated as a single unit, and the code generated for them performs close to hand-written code.

Source Code Walkthrough

Example

Let's start with a simple example, follow the whole execution with a remote debugger, and focus on the parts related to WholeStageCodeGen.

The example code:

val ss = getSparkSession("SparkSQLDemo")
val df = ss.read.parquet("/home/demodata/part-00000-90aead26-1478-474c-bb70-d190c5c7500b.snappy.parquet")
df.createOrReplaceTempView("demo_table")
val df2 = ss.sql("select * from demo_table where a='123'")
val res_cnt = df2.collect()
println(res_cnt.mkString(","))
ss.stop()

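getSparkSession here is just a small helper of the author's; presumably it does little more than the following (a hypothetical sketch, not the original helper):

import org.apache.spark.sql.SparkSession

def getSparkSession(appName: String): SparkSession =
  SparkSession.builder()
    .appName(appName)
    .getOrCreate()
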
demo_table is a temporary view registered on top of a Parquet file; the data has two columns, a and b. Printing the plan with explain gives:

== Physical Plan ==
*Project [a#0, b#1]
+- *Filter (isnotnull(a#0) && (a#0 = 123))
   +- *FileScan parquet [a#0,b#1] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/home/demodata/part-00000-90aead26-1478-474c-bb70-d190c5c7500b.snappy.parq..., PartitionFilters: [], PushedFilters: [IsNotNull(a), EqualTo(a,123)], ReadSchema: struct<a:string,b:string>

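The * prefix in front of an operator marks it as part of a whole-stage-codegen'd subtree. Besides explain, Spark's debug package can dump the generated code directly; on Spark 2.x something like this should work (using the df2 from the example above):

import org.apache.spark.sql.execution.debug._

// Prints each whole-stage-codegen subtree together with the Java source generated for it.
df2.debugCodegen()
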
Execution Flow

Since the whole execution is driven from the DataFrame (Dataset) side, let's cut straight to the collect method of the Dataset class:

private def collect(needCallback: Boolean): Array[T] = {
  def execute(): Array[T] = withNewExecutionId {
    queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
  }

  if (needCallback) {
    withCallback("collect", toDF())(_ => execute())
  } else {
    execute()
  }
}

So Dataset hands the real collect work over to queryExecution.executedPlan. (Digging deeper into the QueryExecution class reveals a whole chain of lazy members; the reference to queryExecution.executedPlan here is what actually triggers generation of the physical plan. We treat physical-plan generation as a black box and keep the focus on code generation.) queryExecution.executedPlan is an instance of a concrete subclass of SparkPlan, so let's keep following SparkPlan's executeCollect method:

def executeCollect(): Array[InternalRow] = {
  val byteArrayRdd = getByteArrayRdd()

  val results = ArrayBuffer[InternalRow]()
  byteArrayRdd.collect().foreach { bytes =>
    decodeUnsafeRows(bytes).foreach(results.+=)
  }
  results.toArray
}

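executeCollect calls decodeUnsafeRows on each collected byte array. That implementation is not shown here, but conceptually it just walks the [size] [bytes of UnsafeRow] ... [-1] stream that getByteArrayRdd (below) produces. A minimal, hypothetical decoder sketch (not Spark's actual code; numFields is an assumed parameter):

import java.io.{ByteArrayInputStream, DataInputStream}
import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// Walk the compressed [size][row bytes]... stream until the -1 terminator.
def decodeRows(bytes: Array[Byte], numFields: Int): Iterator[UnsafeRow] = {
  val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
  val in = new DataInputStream(codec.compressedInputStream(new ByteArrayInputStream(bytes)))
  new Iterator[UnsafeRow] {
    private var nextSize = in.readInt()
    override def hasNext: Boolean = nextSize >= 0
    override def next(): UnsafeRow = {
      val buf = new Array[Byte](nextSize)
      in.readFully(buf)
      val row = new UnsafeRow(numFields)
      row.pointTo(buf, nextSize)
      nextSize = in.readInt()
      row
    }
  }
}
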
Back to the main trail: executeCollect itself is straightforward, so let's follow getByteArrayRdd:

/**
 * Packing the UnsafeRows into byte array for faster serialization.
 * The byte arrays are in the following format:
 * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
 *
 * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
 * compressed.
 */
private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = {
  execute().mapPartitionsInternal { iter =>
    var count = 0
    val buffer = new Array[Byte](4 << 10)  // 4K
    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(codec.compressedOutputStream(bos))
    while (iter.hasNext && (n < 0 || count < n)) {
      val row = iter.next().asInstanceOf[UnsafeRow]
      out.writeInt(row.getSizeInBytes)
      row.writeToStream(out, buffer)
      count += 1
    }
    out.writeInt(-1)
    out.flush()
    out.close()
    Iterator(bos.toByteArray)
  }
}

The main job of this method is to pack the UnsafeRows into byte arrays and compress them; the only call that can possibly involve WholeStageCodeGen is execute(), so let's step into execute():

/**
 * Returns the result of this query as an RDD[InternalRow] by delegating to `doExecute` after
 * preparations.
 *
 * Concrete implementations of SparkPlan should override `doExecute`.
 */
final def execute(): RDD[InternalRow] = executeQuery {
  doExecute()
}

This method simply calls doExecute() on the concrete SparkPlan subclass. So which subclass's doExecute() is called here; in other words, what is the runtime type of the queryExecution.executedPlan mentioned above? Debugging shows that queryExecution.executedPlan is a WholeStageCodegenExec object. That class extends the UnaryExecNode trait and holds a child, and the whole physical-plan object tree looks like this:

WholeStageCodegenExec
  --child: ProjectExec
    --child: FilterExec
      --child: FileSourceScanExec

Comparing this with the explain output above, the whole plan has been "wrapped" in WholeStageCodegenExec. Let's continue with WholeStageCodegenExec's doExecute():

override def doExecute(): RDD[InternalRow] = {
  val (ctx, cleanedSource) = doCodeGen()
  // try to compile and fallback if it failed
  try {
    CodeGenerator.compile(cleanedSource)
  } catch {
    case e: Exception if !Utils.isTesting && sqlContext.conf.wholeStageFallback =>
      // We should already saw the error message
      logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString")
      return child.execute()
  }
  val references = ctx.references.toArray

  val durationMs = longMetric("pipelineTime")

  val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
  assert(rdds.size <= 2, "Up to two input RDDs can be supported")
  if (rdds.length == 1) {
    rdds.head.mapPartitionsWithIndex { (index, iter) =>
      val clazz = CodeGenerator.compile(cleanedSource)
      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
      buffer.init(index, Array(iter))
      new Iterator[InternalRow] {
        override def hasNext: Boolean = {
          val v = buffer.hasNext
          if (!v) durationMs += buffer.durationMs()
          v
        }
        override def next: InternalRow = buffer.next()
      }
    }
  } else {
    // Right now, we support up to two input RDDs.
    rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
      Iterator((leftIter, rightIter))
      // a small hack to obtain the correct partition index
    }.mapPartitionsWithIndex { (index, zippedIter) =>
      val (leftIter, rightIter) = zippedIter.next()
      val clazz = CodeGenerator.compile(cleanedSource)
      val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
      buffer.init(index, Array(leftIter, rightIter))
      new Iterator[InternalRow] {
        override def hasNext: Boolean = {
          val v = buffer.hasNext
          if (!v) durationMs += buffer.durationMs()
          v
        }
        override def next: InternalRow = buffer.next()
      }
    }
  }
}

WholeStageCodegenExec's doExecute() does the following, in order:

  1. Call doCodeGen() to generate the source code and the associated codegen context (the details of code generation are skipped for now and covered in the next section).
  2. Try to compile the generated code (CodeGenerator's source shows that compilation uses the Janino library). If compilation fails, fall back to the "classic" path that does not use WholeStageCodegenExec (see the config sketch below).
  3. Actually compile the code and instantiate a GeneratedClass object via reflection (CodeGenerator caches internally, so the real compilation and instantiation happen only once). Its generate method returns an object extending BufferedRowIterator, which is then used to drive the actual execution.

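Both the feature itself and the fallback behaviour are governed by SQL configs; a quick sketch of toggling them (sqlContext.conf.wholeStageFallback in the code above corresponds to spark.sql.codegen.fallback):

// Disable whole-stage codegen entirely, forcing the classic Volcano-style path.
ss.conf.set("spark.sql.codegen.wholeStage", "false")

// Keep codegen on, but control whether a compilation failure falls back to the
// non-codegen path instead of failing the query.
ss.conf.set("spark.sql.codegen.fallback", "true")
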
Printing out cleanedSource gives:

public Object generate(Object[] references) {
  return new GeneratedIterator(references);
}

final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
  private Object[] references;
  private scala.collection.Iterator[] inputs;
  private scala.collection.Iterator scan_input;
  private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows;
  private org.apache.spark.sql.execution.metric.SQLMetric scan_scanTime;
  private long scan_scanTime1;
  private org.apache.spark.sql.execution.vectorized.ColumnarBatch scan_batch;
  private int scan_batchIdx;
  private org.apache.spark.sql.execution.vectorized.ColumnVector scan_colInstance0;
  private org.apache.spark.sql.execution.vectorized.ColumnVector scan_colInstance1;
  private UnsafeRow scan_result;
  private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder scan_holder;
  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter scan_rowWriter;
  private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows;
  private UnsafeRow filter_result;
  private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder;
  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter;
  private UnsafeRow project_result;
  private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;

  public GeneratedIterator(Object[] references) {
    this.references = references;
  }

  public void init(int index, scala.collection.Iterator[] inputs) {
    partitionIndex = index;
    this.inputs = inputs;
    wholestagecodegen_init_0();
    wholestagecodegen_init_1();
  }

  private void wholestagecodegen_init_0() {
    scan_input = inputs[0];
    this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
    this.scan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
    scan_scanTime1 = 0;
    scan_batch = null;
    scan_batchIdx = 0;
    scan_colInstance0 = null;
    scan_colInstance1 = null;
    scan_result = new UnsafeRow(2);
    this.scan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(scan_result, 64);
    this.scan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(scan_holder, 2);
    this.filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
    filter_result = new UnsafeRow(2);
    this.filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 64);
    this.filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 2);
    project_result = new UnsafeRow(2);
    this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 64);
  }

  private void scan_nextBatch() throws java.io.IOException {
    long getBatchStart = System.nanoTime();
    if (scan_input.hasNext()) {
      scan_batch = (org.apache.spark.sql.execution.vectorized.ColumnarBatch)scan_input.next();
      scan_numOutputRows.add(scan_batch.numRows());
      scan_batchIdx = 0;
      scan_colInstance0 = scan_batch.column(0);
      scan_colInstance1 = scan_batch.column(1);
    }
    scan_scanTime1 += System.nanoTime() - getBatchStart;
  }

  private void wholestagecodegen_init_1() {
    this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 2);
  }

  protected void processNext() throws java.io.IOException {
    if (scan_batch == null) {
      scan_nextBatch();
    }
    while (scan_batch != null) {
      int numRows = scan_batch.numRows();
      while (scan_batchIdx < numRows) {
        int scan_rowIdx = scan_batchIdx++;
        boolean scan_isNull = scan_colInstance0.isNullAt(scan_rowIdx);
        UTF8String scan_value = scan_isNull ? null : (scan_colInstance0.getUTF8String(scan_rowIdx));
        if (!(!(scan_isNull))) continue;
        boolean filter_isNull2 = false;
        Object filter_obj = ((Expression) references[3]).eval(null);
        UTF8String filter_value4 = (UTF8String) filter_obj;
        boolean filter_value2 = false;
        filter_value2 = scan_value.equals(filter_value4);
        if (!filter_value2) continue;
        filter_numOutputRows.add(1);
        boolean scan_isNull1 = scan_colInstance1.isNullAt(scan_rowIdx);
        UTF8String scan_value1 = scan_isNull1 ? null : (scan_colInstance1.getUTF8String(scan_rowIdx));
        project_holder.reset();
        project_rowWriter.zeroOutNullBytes();
        project_rowWriter.write(0, scan_value);
        if (scan_isNull1) {
          project_rowWriter.setNullAt(1);
        } else {
          project_rowWriter.write(1, scan_value1);
        }
        project_result.setTotalSize(project_holder.totalSize());
        append(project_result);
        if (shouldStop()) return;
      }
      scan_batch = null;
      scan_nextBatch();
    }
    scan_scanTime.add(scan_scanTime1 / (1000 * 1000));
    scan_scanTime1 = 0;
  }
}

The generated code shows that generate returns a GeneratedIterator extending BufferedRowIterator, and that GeneratedIterator overrides processNext. Reading that method carefully, you can see it performs both the Filter and the Project steps in a single pass, which is exactly what the class comment of WholeStageCodegenExec promises: "compile a subtree of plans that support codegen together into single Java function".
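
Stripped of the metrics and UnsafeRow plumbing, the generated processNext is logically just the tight loop one would write by hand for select * from demo_table where a='123'; roughly the following (a conceptual sketch, not code that appears anywhere in Spark):

// Conceptual hand-written equivalent of the fused Scan -> Filter -> Project pipeline:
// one pass over the rows, with no operator-to-operator iterator calls in between.
def handWritten(rows: Iterator[(String, String)]): Iterator[(String, String)] =
  rows.filter { case (a, _) => a != null && a == "123" }

The point of whole-stage codegen is that Scan, Filter and Project collapse into one such loop, rather than being chained as three Volcano-style iterators.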

doCodeGen()

When we looked at WholeStageCodegenExec's doExecute() above, the details of doCodeGen() were skipped; let's now use the example to look at how the code is actually generated. The method:

/**
 * Generates code for this subtree.
 *
 * @return the tuple of the codegen context and the actual generated source.
 */
def doCodeGen(): (CodegenContext, CodeAndComment) = {
  val ctx = new CodegenContext
  val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
  val source = s"""
    public Object generate(Object[] references) {
      return new GeneratedIterator(references);
    }

    ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")}
    final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

      private Object[] references;
      private scala.collection.Iterator[] inputs;
      ${ctx.declareMutableStates()}

      public GeneratedIterator(Object[] references) {
        this.references = references;
      }

      public void init(int index, scala.collection.Iterator[] inputs) {
        partitionIndex = index;
        this.inputs = inputs;
        ${ctx.initMutableStates()}
        ${ctx.initPartition()}
      }

      ${ctx.declareAddedFunctions()}

      protected void processNext() throws java.io.IOException {
        ${code.trim}
      }
    }
    """.trim

  // try to compile, helpful for debug
  val cleanedSource = CodeFormatter.stripOverlappingComments(
    new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments()))

  logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
  (ctx, cleanedSource)
}

From WholeStageCodegenExec's point of view, doCodeGen() only writes the code skeleton (the source variable above); the actual processing logic comes from the call child.produce(ctx, this), which kicks off the produce/consume protocol of CodegenSupport. Since the doProduce() implementations of ProjectExec and FilterExec simply delegate code generation to their child, the core loop is ultimately generated by FileSourceScanExec's doProduce(), while Filter and Project splice in their own logic through doConsume() as each row is pushed back up the tree. The sketch below illustrates the protocol.
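
Here is a heavily simplified, hypothetical sketch of that produce/consume contract. The method names mirror Spark's CodegenSupport (produce / doProduce / consume / doConsume), but everything else, including MiniScan, MiniFilter and the generated strings, is illustrative only:

// Simplified produce/consume protocol: produce() is called top-down to build the
// driving loop, consume() is called bottom-up so each parent splices in its logic.
trait MiniCodegenSupport {
  var parent: MiniCodegenSupport = null

  def produce(parent: MiniCodegenSupport): String = {
    this.parent = parent
    doProduce()
  }

  // Leaf operators emit the loop; unary operators usually delegate to their child.
  def doProduce(): String

  // Push one produced row upwards so the parent can append its own code.
  def consume(row: String): String = parent.doConsume(row)

  def doConsume(row: String): String
}

// A scan-like leaf: it owns the loop and pushes every row to its parent.
class MiniScan extends MiniCodegenSupport {
  def doProduce(): String =
    s"""|while (input.hasNext()) {
        |  InternalRow row = (InternalRow) input.next();
        |  ${consume("row")}
        |}""".stripMargin
  def doConsume(row: String): String = ""            // a leaf never consumes
}

// A filter-like operator: delegates the loop to its child, injects the predicate on consume.
class MiniFilter(child: MiniCodegenSupport) extends MiniCodegenSupport {
  def doProduce(): String = child.produce(this)
  def doConsume(row: String): String =
    s"if (!predicate($row)) continue; ${consume(row)}"
}

// The root (WholeStageCodegenExec in Spark) starts produce() and appends each surviving row.
object MiniDemo extends App {
  val sink = new MiniCodegenSupport {
    def doProduce(): String = ""
    def doConsume(row: String): String = s"append($row);"
  }
  println(new MiniFilter(new MiniScan).produce(sink))
}

Running it prints a tiny loop in which the scan, the filter predicate and the final append all live in one function body, which is the same shape as the processNext we saw above.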