Skip to content

Commit 68213fb

Browse files
committed
Unify with UnsafeProjectionCreator.
1 parent 36f90cf commit 68213fb

File tree

5 files changed

+73
-49
lines changed

5 files changed

+73
-49
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodegenObjectFactory.scala

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,49 +31,37 @@ object CodegenError {
3131
}
3232
}
3333

34+
trait CodegenObjectFactoryBase[IN, OUT] {
35+
protected def createObject(in: IN): OUT
36+
protected def createCodeGeneratedObject(in: IN): OUT
37+
protected def createInterpretedObject(in: IN): OUT
38+
}
39+
3440
/**
35-
* A factory class which can be used to create objects that have both codegen and interpreted
41+
* A factory which can be used to create objects that have both codegen and interpreted
3642
* implementations. This tries to create codegen object first, if any compile error happens,
3743
* it fallbacks to interpreted version.
3844
*/
39-
abstract class CodegenObjectFactory[IN, OUT] {
40-
41-
def createObject(in: IN): OUT = try {
45+
trait CodegenObjectFactory[IN, OUT] extends CodegenObjectFactoryBase[IN, OUT] {
46+
override protected def createObject(in: IN): OUT = try {
4247
createCodeGeneratedObject(in)
4348
} catch {
4449
case CodegenError(_) => createInterpretedObject(in)
4550
}
46-
47-
protected def createCodeGeneratedObject(in: IN): OUT
48-
protected def createInterpretedObject(in: IN): OUT
4951
}
5052

51-
object UnsafeProjectionFactory extends CodegenObjectFactory[Seq[Expression], UnsafeProjection]
52-
with UnsafeProjectionCreator {
53-
54-
override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
55-
UnsafeProjection.createProjection(in)
56-
}
57-
58-
override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
59-
InterpretedUnsafeProjection.createProjection(in)
60-
}
61-
62-
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection =
63-
createObject(exprs)
64-
65-
/**
66-
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
67-
* The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example,
68-
* when fallbacking to interpreted execution, it is not supported.
69-
*/
70-
def create(
71-
exprs: Seq[Expression],
72-
inputSchema: Seq[Attribute],
73-
subexpressionEliminationEnabled: Boolean): UnsafeProjection = try {
74-
UnsafeProjection.create(exprs, inputSchema, subexpressionEliminationEnabled)
75-
} catch {
76-
case CodegenError(_) => InterpretedUnsafeProjection.create(exprs, inputSchema)
77-
}
53+
/**
54+
* A factory which can be used to create codegen objects without fallback to interpreted version.
55+
*/
56+
trait CodegenObjectFactoryWithoutFallback[IN, OUT] extends CodegenObjectFactoryBase[IN, OUT] {
57+
override protected def createObject(in: IN): OUT =
58+
createCodeGeneratedObject(in)
7859
}
7960

61+
/**
62+
* A factory which can be used to create objects with interpreted implementation.
63+
*/
64+
trait InterpretedCodegenObjectFactory[IN, OUT] extends CodegenObjectFactoryBase[IN, OUT] {
65+
override protected def createObject(in: IN): OUT =
66+
createInterpretedObject(in)
67+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,11 @@ class InterpretedUnsafeProjection(expressions: Array[Expression]) extends Unsafe
8787
/**
8888
* Helper functions for creating an [[InterpretedUnsafeProjection]].
8989
*/
90-
object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
91-
90+
object InterpretedUnsafeProjection {
9291
/**
9392
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
9493
*/
95-
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
94+
protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
9695
// We need to make sure that we do not reuse stateful expressions.
9796
val cleanedExpressions = exprs.map(_.transform {
9897
case s: Stateful => s.freshCopy()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,20 @@ abstract class UnsafeProjection extends Projection {
108108
override def apply(row: InternalRow): UnsafeRow
109109
}
110110

111-
trait UnsafeProjectionCreator {
111+
/**
112+
* The factory base class for `UnsafeProjection`.
113+
*/
114+
abstract class UnsafeProjectionCreator extends
115+
CodegenObjectFactoryBase[Seq[Expression], UnsafeProjection] {
116+
117+
override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
118+
GenerateUnsafeProjection.generate(in)
119+
}
120+
121+
override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
122+
InterpretedUnsafeProjection.createProjection(in)
123+
}
124+
112125
protected def toBoundExprs(
113126
exprs: Seq[Expression],
114127
inputSchema: Seq[Attribute]): Seq[Expression] = {
@@ -157,18 +170,36 @@ trait UnsafeProjectionCreator {
157170
/**
158171
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
159172
*/
160-
protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection
173+
protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection =
174+
createObject(exprs)
161175
}
162176

163-
object UnsafeProjection extends UnsafeProjectionCreator {
164-
165-
override protected[sql] def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
166-
GenerateUnsafeProjection.generate(exprs)
177+
// A `UnsafeProjectionCreator` can fallback to interpreted version if codegen compile error happens.
178+
object UnsafeProjection extends UnsafeProjectionCreator
179+
with CodegenObjectFactory[Seq[Expression], UnsafeProjection] {
180+
/**
181+
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
182+
* The param `subexpressionEliminationEnabled` doesn't guarantee to work. For example,
183+
* when fallbacking to interpreted execution, it is not supported.
184+
*/
185+
def create(
186+
exprs: Seq[Expression],
187+
inputSchema: Seq[Attribute],
188+
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
189+
val unsafeExprs = toUnsafeExprs(toBoundExprs(exprs, inputSchema))
190+
try {
191+
GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled)
192+
} catch {
193+
case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs)
194+
}
167195
}
196+
}
168197

198+
// Codegen version `UnsafeProjectionCreator`.
199+
object CodegenUnsafeProjectionCreator extends UnsafeProjectionCreator
200+
with CodegenObjectFactoryWithoutFallback[Seq[Expression], UnsafeProjection] {
169201
/**
170202
* Same as other create()'s but allowing enabling/disabling subexpression elimination.
171-
* TODO: refactor the plumbing and clean this up.
172203
*/
173204
def create(
174205
exprs: Seq[Expression],
@@ -179,6 +210,10 @@ object UnsafeProjection extends UnsafeProjectionCreator {
179210
}
180211
}
181212

213+
// Interpreted version `UnsafeProjectionCreator`.
214+
object InterpretedUnsafeProjectionCreator extends UnsafeProjectionCreator
215+
with InterpretedCodegenObjectFactory[Seq[Expression], UnsafeProjection]
216+
182217
/**
183218
* A projection that could turn UnsafeRow into GenericInternalRow
184219
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,10 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
196196
expression: Expression,
197197
expected: Any,
198198
inputRow: InternalRow = EmptyRow): Unit = {
199-
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, UnsafeProjection)
200-
checkEvaluationWithUnsafeProjection(expression, expected, inputRow, InterpretedUnsafeProjection)
199+
checkEvaluationWithUnsafeProjection(expression, expected, inputRow,
200+
CodegenUnsafeProjectionCreator)
201+
checkEvaluationWithUnsafeProjection(expression, expected, inputRow,
202+
InterpretedUnsafeProjectionCreator)
201203
}
202204

203205
protected def checkEvaluationWithUnsafeProjection(
@@ -228,7 +230,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
228230
protected def evaluateWithUnsafeProjection(
229231
expression: Expression,
230232
inputRow: InternalRow = EmptyRow,
231-
factory: UnsafeProjectionCreator = UnsafeProjection): InternalRow = {
233+
factory: UnsafeProjectionCreator = CodegenUnsafeProjectionCreator): InternalRow = {
232234
// SPARK-16489 Explicitly doing code generation twice so code gen will fail if
233235
// some expression is reusing variable names across different instances.
234236
// This behavior is tested in ExpressionEvalHelperSuite.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
3737
name: String)(
3838
f: UnsafeProjectionCreator => Unit): Unit = {
3939
test(name) {
40-
f(UnsafeProjection)
41-
f(InterpretedUnsafeProjection)
40+
f(CodegenUnsafeProjectionCreator)
41+
f(InterpretedUnsafeProjectionCreator)
4242
}
4343
}
4444

0 commit comments

Comments
 (0)