From 90efeffb039a4c3458add840f98ea91d01cdc4a8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 4 May 2018 15:03:09 -0700 Subject: [PATCH 1/5] [SPARK-21274] Add a new generator function replicate_rows to support EXCEPT ALL and INTERSECT ALL --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 16 ++++ .../sql/catalyst/expressions/generators.scala | 46 +++++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 25 +++++ .../sql-tests/inputs/udtf_replicate_rows.sql | 29 ++++++ .../results/udtf_replicate_rows.sql.out | 93 +++++++++++++++++++ .../spark/sql/GeneratorFunctionSuite.scala | 31 +++++++ 7 files changed, 241 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 51bb6b0abe40..9580f3930bf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -212,6 +212,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[Stack]("stack"), + expression[ReplicateRows]("replicate_rows"), expression[CaseWhen]("when"), // math functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index b2817b0538a7..a54fd807adf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -62,6 +62,7 @@ object TypeCoercion { new ImplicitTypeCasts(conf) :: 
DateTimeOperations :: WindowFrameCoercion :: + ReplicateRowsCoercion :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -702,6 +703,21 @@ object TypeCoercion { } } + /** + * Coerces first argument in ReplicateRows expression and introduces a cast to Long + * if necessary. + */ + object ReplicateRowsCoercion extends TypeCoercionRule { + private val acceptedTypes = Seq(LongType, IntegerType, ShortType, ByteType) + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case s @ ReplicateRows(children) + if s.childrenResolved && acceptedTypes.contains(s.children.head.dataType) => + val numRowExpr = s.children.head + val castedExpr = ImplicitTypeCasts.implicitCast(numRowExpr, LongType).getOrElse(numRowExpr) + ReplicateRows(Seq(castedExpr) ++ s.children.tail) + } + } + /** * Coerces the types of [[Concat]] children to expected ones. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45..0d701a2e7a73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ + /** * An expression that produces zero or more rows given a single input row. * @@ -222,6 +223,51 @@ case class Stack(children: Seq[Expression]) extends Generator { } } +/** + * Replicate the row based N times. N is specified as the first argument to the function. 
+ * {{{ + * SELECT replicate_rows(2, "val1", "val2") -> + * 2 val1 val2 + * 2 val1 val2 + * }}} + */ +@ExpressionDescription( +usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `expr1`, ..., `exprk` into `n` rows.", +examples = """ + Examples: + > SELECT _FUNC_(2, "val1", "val2"); + 2 val1 val2 + 2 val1 val2 + """) +case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { + override def checkInputDataTypes(): TypeCheckResult = { + if (children.length < 2) { + TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.") + } else if (children.head.dataType != LongType) { + TypeCheckResult.TypeCheckFailure("The number of rows must be a positive long value.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def elementSchema: StructType = + StructType(children.zipWithIndex.map { + case (e, index) => StructField(s"col$index", e.dataType) + }) + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val numRows = children.head.eval(input).asInstanceOf[Long] + val values = children.map(_.eval(input)).toArray + Range.Long(0, numRows, 1).map { i => + val fields = new Array[Any](children.length) + for (col <- 0 until children.length) { + fields.update(col, values(col)) + } + InternalRow(fields: _*) + } + } +} + /** * Wrapper around another generator to specify outer behavior. This is used to implement functions * such as explode_outer. This expression gets replaced during analysis. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0acd3b490447..4d0d3815f695 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1353,6 +1353,31 @@ class TypeCoercionSuite extends AnalysisTest { SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)) ) } + + test("type coercion for ReplicateRows") { + val rule = TypeCoercion.ReplicateRowsCoercion + // Cast is setup to promote the first expression to Long + // for numeric types. + ruleTest(rule, + ReplicateRows(Seq(1.toShort, Literal("rowdata"))), + ReplicateRows(Seq(Cast(1.toShort, LongType), Literal("rowdata")))) + ruleTest(rule, + ReplicateRows(Seq(1, Literal("rowdata"))), + ReplicateRows(Seq(Cast(1, LongType), Literal("rowdata")))) + ruleTest(rule, + ReplicateRows(Seq(1.toByte, Literal("rowdata"))), + ReplicateRows(Seq(Cast(1.toByte, LongType), Literal("rowdata")))) + + // No cast here since the expected type is Long. + ruleTest(rule, + ReplicateRows(Seq(1L, Literal("rowdata"))), + ReplicateRows(Seq(1L, Literal("rowdata")))) + + // No type coercion when first expression is a non numeric type. 
+ ruleTest(rule, + ReplicateRows(Seq(Literal("invalid"), Literal("rowdata"))), + ReplicateRows(Seq(Literal("invalid"), Literal("rowdata")))) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql new file mode 100644 index 000000000000..dc2435da61d1 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql @@ -0,0 +1,29 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 'row1', 1.1), + (2, 'row2', 2.2), + (0, 'row3', 3.3), + (-1,'row4', 4.4), + (null,'row5', 5.5), + (3, 'row6', null) + AS tab1(c1, c2, c3); + +-- c1, c2 replicated c1 times +SELECT replicate_rows(c1, c2) FROM tab1; + +-- c1, c2, c2 replicated c1 times +SELECT replicate_rows(c1, c2, c2) FROM tab1; + +-- c1, c2, c2, c2, c3 replicated c1 times +SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1; + +-- Used as a derived table in FROM clause. +SELECT c2, c1 +FROM ( + SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1 +); + +-- column expression. 
+SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1; + +-- Clean-up +DROP VIEW IF EXISTS tab1; diff --git a/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out new file mode 100644 index 000000000000..9b16a46b36ce --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out @@ -0,0 +1,93 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 'row1', 1.1), + (2, 'row2', 2.2), + (0, 'row3', 3.3), + (-1,'row4', 4.4), + (null,'row5', 5.5), + (3, 'row6', null) + AS tab1(c1, c2, c3) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT replicate_rows(c1, c2) FROM tab1 +-- !query 1 schema +struct +-- !query 1 output +1 row1 +2 row2 +2 row2 +3 row6 +3 row6 +3 row6 + + +-- !query 2 +SELECT replicate_rows(c1, c2, c2) FROM tab1 +-- !query 2 schema +struct +-- !query 2 output +1 row1 row1 +2 row2 row2 +2 row2 row2 +3 row6 row6 +3 row6 row6 +3 row6 row6 + + +-- !query 3 +SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1 +-- !query 3 schema +struct +-- !query 3 output +1 row1 row1 row1 1.1 +2 row2 row2 row2 2.2 +2 row2 row2 row2 2.2 +3 row6 row6 row6 NULL +3 row6 row6 row6 NULL +3 row6 row6 row6 NULL + + +-- !query 4 +SELECT c2, c1 +FROM ( + SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1 +) +-- !query 4 schema +struct +-- !query 4 output +row1 1 +row2 2 +row2 2 +row6 3 +row6 3 +row6 3 + + +-- !query 5 +SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1 +-- !query 5 schema +struct +-- !query 5 output +1 row1... row1 +2 row2... row2 +2 row2... row2 +3 row6... row6 +3 row6... row6 +3 row6... 
row6 + + +-- !query 6 +DROP VIEW IF EXISTS tab1 +-- !query 6 schema +struct<> +-- !query 6 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec..0941c0983584 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -307,6 +307,37 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { sql("select * from values 1, 2 lateral view outer empty_gen() a as b"), Row(1, null) :: Row(2, null) :: Nil) } + + test("ReplicateRows generator") { + val df = spark.range(1) + + // Empty DataFrame suppress the result generation + checkAnswer(spark.emptyDataFrame.selectExpr("replicate_rows(1, 1, 2, 3)"), Nil) + + checkAnswer(df.selectExpr("replicate_rows(1, 2.5)"), Row(1, 2.5) :: Nil) + checkAnswer(df.selectExpr("replicate_rows(1, null)"), Row(1, null) :: Nil) + checkAnswer(df.selectExpr("replicate_rows(3, 'row1')"), + Row(3, "row1") :: Row(3, "row1") :: Row(3, "row1") :: Nil) + checkAnswer(df.selectExpr("replicate_rows(-1, 2.5)"), Nil) + + // The data for the same column should have the same type. + val msg1 = intercept[AnalysisException] { + df.selectExpr("replicate_rows(1)") + }.getMessage + assert(msg1.contains("requires at least 2 arguments")) + + // The data for the same column should have the same type. 
+ val msg2 = intercept[AnalysisException] { + df.selectExpr("replicate_rows('a', 1)") + }.getMessage + assert(msg2.contains("The number of rows must be a positive long value.")) + + val msg3 = intercept[AnalysisException] { + df.selectExpr("replicate_rows(null, 1)") + }.getMessage + assert(msg3.contains("The number of rows must be a positive long value.")) + + } } case class EmptyGenerator() extends Generator { From 748003ab5b8d9741a6dec79860cc6bef1083c14b Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 5 May 2018 17:00:47 -0700 Subject: [PATCH 2/5] Review comments --- .../sql/catalyst/analysis/TypeCoercion.scala | 9 ++- .../sql/catalyst/expressions/generators.scala | 5 +- .../catalyst/analysis/TypeCoercionSuite.scala | 25 -------- .../native/replicateRowsCoercion.sql | 9 +++ .../sql-tests/inputs/udtf_replicate_rows.sql | 9 +++ .../native/replicateRowsCoercion.sql.out | 43 +++++++++++++ .../results/udtf_replicate_rows.sql.out | 63 +++++++++++++------ .../spark/sql/GeneratorFunctionSuite.scala | 31 --------- 8 files changed, 112 insertions(+), 82 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/replicateRowsCoercion.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/replicateRowsCoercion.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a54fd807adf5..2dda02fc38a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -708,12 +708,11 @@ object TypeCoercion { * if necessary. 
*/ object ReplicateRowsCoercion extends TypeCoercionRule { - private val acceptedTypes = Seq(LongType, IntegerType, ShortType, ByteType) + private val acceptedTypes = Seq(IntegerType, ShortType, ByteType) override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case s @ ReplicateRows(children) - if s.childrenResolved && acceptedTypes.contains(s.children.head.dataType) => - val numRowExpr = s.children.head - val castedExpr = ImplicitTypeCasts.implicitCast(numRowExpr, LongType).getOrElse(numRowExpr) + case s @ ReplicateRows(children) if s.childrenResolved && + s.children.head.dataType != LongType && acceptedTypes.contains(s.children.head.dataType) => + val castedExpr = Cast(s.children.head, LongType) ReplicateRows(Seq(castedExpr) ++ s.children.tail) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 0d701a2e7a73..bba31c7aeb6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ - /** * An expression that produces zero or more rows given a single input row. * @@ -224,7 +223,7 @@ case class Stack(children: Seq[Expression]) extends Generator { } /** - * Replicate the row based N times. N is specified as the first argument to the function. + * Replicate the row N times. N is specified as the first argument to the function. 
* {{{ * SELECT replicate_rows(2, "val1", "val2") -> * 2 val1 val2 @@ -232,7 +231,7 @@ case class Stack(children: Seq[Expression]) extends Generator { * }}} */ @ExpressionDescription( -usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `expr1`, ..., `exprk` into `n` rows.", +usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `n`, `expr1`, ..., `exprk` into `n` rows.", examples = """ Examples: > SELECT _FUNC_(2, "val1", "val2"); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 4d0d3815f695..0acd3b490447 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1353,31 +1353,6 @@ class TypeCoercionSuite extends AnalysisTest { SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)) ) } - - test("type coercion for ReplicateRows") { - val rule = TypeCoercion.ReplicateRowsCoercion - // Cast is setup to promote the first expression to Long - // for numeric types. - ruleTest(rule, - ReplicateRows(Seq(1.toShort, Literal("rowdata"))), - ReplicateRows(Seq(Cast(1.toShort, LongType), Literal("rowdata")))) - ruleTest(rule, - ReplicateRows(Seq(1, Literal("rowdata"))), - ReplicateRows(Seq(Cast(1, LongType), Literal("rowdata")))) - ruleTest(rule, - ReplicateRows(Seq(1.toByte, Literal("rowdata"))), - ReplicateRows(Seq(Cast(1.toByte, LongType), Literal("rowdata")))) - - // No cast here since the expected type is Long. - ruleTest(rule, - ReplicateRows(Seq(1L, Literal("rowdata"))), - ReplicateRows(Seq(1L, Literal("rowdata")))) - - // No type coercion when first expression is a non numeric type. 
- ruleTest(rule, - ReplicateRows(Seq(Literal("invalid"), Literal("rowdata"))), - ReplicateRows(Seq(Literal("invalid"), Literal("rowdata")))) - } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/replicateRowsCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/replicateRowsCoercion.sql new file mode 100644 index 000000000000..5cde6a9abcb9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/replicateRowsCoercion.sql @@ -0,0 +1,9 @@ +SELECT replicate_rows(CAST(1 AS BYTE), 1); + +SELECT replicate_rows(CAST(1 AS INT), 1); + +SELECT replicate_rows(CAST(1 AS LONG), 1); + +SELECT replicate_rows(CAST(1 AS SHORT), 1); + +SELECT replicate_rows("abcd", 1); diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql index dc2435da61d1..d3a33700d2a6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql @@ -7,6 +7,15 @@ CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES (3, 'row6', null) AS tab1(c1, c2, c3); +-- Requires 2 arguments at minimum. +SELECT replicate_rows(c1) FROM tab1; + +-- First argument should be a numeric type. 
+SELECT replicate_rows("abcd", c2) FROM tab1; + +-- untyped null first argument +SELECT replicate_rows(null, c2) FROM tab1; + -- c1, c2 replicated c1 times SELECT replicate_rows(c1, c2) FROM tab1; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/replicateRowsCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/replicateRowsCoercion.sql.out new file mode 100644 index 000000000000..024aee350c1c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/replicateRowsCoercion.sql.out @@ -0,0 +1,43 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +SELECT replicate_rows(CAST(1 AS BYTE), 1) +-- !query 0 schema +struct +-- !query 0 output +1 1 + + +-- !query 1 +SELECT replicate_rows(CAST(1 AS INT), 1) +-- !query 1 schema +struct +-- !query 1 output +1 1 + + +-- !query 2 +SELECT replicate_rows(CAST(1 AS LONG), 1) +-- !query 2 schema +struct +-- !query 2 output +1 1 + + +-- !query 3 +SELECT replicate_rows(CAST(1 AS SHORT), 1) +-- !query 3 schema +struct +-- !query 3 output +1 1 + + +-- !query 4 +SELECT replicate_rows("abcd", 1) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows('abcd', 1)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out index 9b16a46b36ce..5c5255cf7458 100644 --- a/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 10 -- !query 0 @@ -18,10 +18,37 @@ struct<> -- !query 1 -SELECT replicate_rows(c1, c2) FROM tab1 +SELECT replicate_rows(c1) FROM 
tab1 -- !query 1 schema -struct +struct<> -- !query 1 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows(CAST(tab1.`c1` AS BIGINT))' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7 + + +-- !query 2 +SELECT replicate_rows("abcd", c2) FROM tab1 +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows('abcd', tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 + + +-- !query 3 +SELECT replicate_rows(null, c2) FROM tab1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows(NULL, tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 + + +-- !query 4 +SELECT replicate_rows(c1, c2) FROM tab1 +-- !query 4 schema +struct +-- !query 4 output 1 row1 2 row2 2 row2 @@ -30,11 +57,11 @@ struct 3 row6 --- !query 2 +-- !query 5 SELECT replicate_rows(c1, c2, c2) FROM tab1 --- !query 2 schema +-- !query 5 schema struct --- !query 2 output +-- !query 5 output 1 row1 row1 2 row2 row2 2 row2 row2 @@ -43,11 +70,11 @@ struct 3 row6 row6 --- !query 3 +-- !query 6 SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1 --- !query 3 schema +-- !query 6 schema struct --- !query 3 output +-- !query 6 output 1 row1 row1 row1 1.1 2 row2 row2 row2 2.2 2 row2 row2 row2 2.2 @@ -56,14 +83,14 @@ struct 3 row6 row6 row6 NULL --- !query 4 +-- !query 7 SELECT c2, c1 FROM ( SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1 ) --- !query 4 schema +-- !query 7 schema struct --- !query 4 output +-- !query 7 output row1 1 row2 2 row2 2 @@ -72,11 +99,11 @@ row6 3 row6 3 --- !query 5 +-- !query 8 SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1 --- !query 5 schema +-- !query 8 schema struct --- !query 5 output +-- !query 8 output 1 row1... row1 2 row2... row2 2 row2... row2 @@ -85,9 +112,9 @@ struct 3 row6... 
row6 --- !query 6 +-- !query 9 DROP VIEW IF EXISTS tab1 --- !query 6 schema +-- !query 9 schema struct<> --- !query 6 output +-- !query 9 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 0941c0983584..109fcf90a3ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -307,37 +307,6 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { sql("select * from values 1, 2 lateral view outer empty_gen() a as b"), Row(1, null) :: Row(2, null) :: Nil) } - - test("ReplicateRows generator") { - val df = spark.range(1) - - // Empty DataFrame suppress the result generation - checkAnswer(spark.emptyDataFrame.selectExpr("replicate_rows(1, 1, 2, 3)"), Nil) - - checkAnswer(df.selectExpr("replicate_rows(1, 2.5)"), Row(1, 2.5) :: Nil) - checkAnswer(df.selectExpr("replicate_rows(1, null)"), Row(1, null) :: Nil) - checkAnswer(df.selectExpr("replicate_rows(3, 'row1')"), - Row(3, "row1") :: Row(3, "row1") :: Row(3, "row1") :: Nil) - checkAnswer(df.selectExpr("replicate_rows(-1, 2.5)"), Nil) - - // The data for the same column should have the same type. - val msg1 = intercept[AnalysisException] { - df.selectExpr("replicate_rows(1)") - }.getMessage - assert(msg1.contains("requires at least 2 arguments")) - - // The data for the same column should have the same type. 
- val msg2 = intercept[AnalysisException] { - df.selectExpr("replicate_rows('a', 1)") - }.getMessage - assert(msg2.contains("The number of rows must be a positive long value.")) - - val msg3 = intercept[AnalysisException] { - df.selectExpr("replicate_rows(null, 1)") - }.getMessage - assert(msg3.contains("The number of rows must be a positive long value.")) - - } } case class EmptyGenerator() extends Generator { From 02ed0582348f12473fdde8779c1d9e59ecfd84b1 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 5 May 2018 20:37:23 -0700 Subject: [PATCH 3/5] Review comments --- .../sql/catalyst/analysis/TypeCoercion.scala | 2 +- .../sql-tests/inputs/udtf_replicate_rows.sql | 3 + .../results/udtf_replicate_rows.sql.out | 57 +++++++++++-------- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2dda02fc38a0..ea7318f90734 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -710,7 +710,7 @@ object TypeCoercion { object ReplicateRowsCoercion extends TypeCoercionRule { private val acceptedTypes = Seq(IntegerType, ShortType, ByteType) override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case s @ ReplicateRows(children) if s.childrenResolved && + case s @ ReplicateRows(children) if s.children.nonEmpty && s.childrenResolved && s.children.head.dataType != LongType && acceptedTypes.contains(s.children.head.dataType) => val castedExpr = Cast(s.children.head, LongType) ReplicateRows(Seq(castedExpr) ++ s.children.tail) diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql index d3a33700d2a6..881f2082c190 100644 --- 
a/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql @@ -7,6 +7,9 @@ CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES (3, 'row6', null) AS tab1(c1, c2, c3); +-- Requires 2 arguments at minimum. +SELECT replicate_rows() FROM tab1; + -- Requires 2 arguments at minimum. SELECT replicate_rows(c1) FROM tab1; diff --git a/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out index 5c5255cf7458..9fae15f3741f 100644 --- a/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 11 -- !query 0 @@ -18,37 +18,46 @@ struct<> -- !query 1 -SELECT replicate_rows(c1) FROM tab1 +SELECT replicate_rows() FROM tab1 -- !query 1 schema struct<> -- !query 1 output org.apache.spark.sql.AnalysisException -cannot resolve 'replicaterows(CAST(tab1.`c1` AS BIGINT))' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7 +cannot resolve 'replicaterows()' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7 -- !query 2 -SELECT replicate_rows("abcd", c2) FROM tab1 +SELECT replicate_rows(c1) FROM tab1 -- !query 2 schema struct<> -- !query 2 output org.apache.spark.sql.AnalysisException -cannot resolve 'replicaterows('abcd', tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 +cannot resolve 'replicaterows(CAST(tab1.`c1` AS BIGINT))' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7 -- !query 3 -SELECT replicate_rows(null, c2) FROM tab1 +SELECT replicate_rows("abcd", c2) FROM tab1 -- !query 3 schema struct<> -- !query 3 output 
org.apache.spark.sql.AnalysisException -cannot resolve 'replicaterows(NULL, tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 +cannot resolve 'replicaterows('abcd', tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 -- !query 4 -SELECT replicate_rows(c1, c2) FROM tab1 +SELECT replicate_rows(null, c2) FROM tab1 -- !query 4 schema -struct +struct<> -- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows(NULL, tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 + + +-- !query 5 +SELECT replicate_rows(c1, c2) FROM tab1 +-- !query 5 schema +struct +-- !query 5 output 1 row1 2 row2 2 row2 @@ -57,11 +66,11 @@ struct 3 row6 --- !query 5 +-- !query 6 SELECT replicate_rows(c1, c2, c2) FROM tab1 --- !query 5 schema +-- !query 6 schema struct --- !query 5 output +-- !query 6 output 1 row1 row1 2 row2 row2 2 row2 row2 @@ -70,11 +79,11 @@ struct 3 row6 row6 --- !query 6 +-- !query 7 SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1 --- !query 6 schema +-- !query 7 schema struct --- !query 6 output +-- !query 7 output 1 row1 row1 row1 1.1 2 row2 row2 row2 2.2 2 row2 row2 row2 2.2 @@ -83,14 +92,14 @@ struct 3 row6 row6 row6 NULL --- !query 7 +-- !query 8 SELECT c2, c1 FROM ( SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1 ) --- !query 7 schema +-- !query 8 schema struct --- !query 7 output +-- !query 8 output row1 1 row2 2 row2 2 @@ -99,11 +108,11 @@ row6 3 row6 3 --- !query 8 +-- !query 9 SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1 --- !query 8 schema +-- !query 9 schema struct --- !query 8 output +-- !query 9 output 1 row1... row1 2 row2... row2 2 row2... row2 @@ -112,9 +121,9 @@ struct 3 row6... 
row6 --- !query 9 +-- !query 10 DROP VIEW IF EXISTS tab1 --- !query 9 schema +-- !query 10 schema struct<> --- !query 9 output +-- !query 10 output From 17610689595f02a30730c0fc1a070c3652eabf7e Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 5 May 2018 23:55:28 -0700 Subject: [PATCH 4/5] more comments --- .../org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index ea7318f90734..5a50169918a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -711,7 +711,7 @@ object TypeCoercion { private val acceptedTypes = Seq(IntegerType, ShortType, ByteType) override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s @ ReplicateRows(children) if s.children.nonEmpty && s.childrenResolved && - s.children.head.dataType != LongType && acceptedTypes.contains(s.children.head.dataType) => + acceptedTypes.contains(s.children.head.dataType) => val castedExpr = Cast(s.children.head, LongType) ReplicateRows(Seq(castedExpr) ++ s.children.tail) } From 4ab3af0c1abfd0ac078c968dbe589baaaaf96091 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sun, 6 May 2018 17:52:51 -0700 Subject: [PATCH 5/5] fix --- .../spark/sql/catalyst/expressions/generators.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index bba31c7aeb6c..7e947b7e0a50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -239,8 +239,10 @@ examples = """ 2 val1 val2 """) case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback { + private lazy val numColumns = children.length + override def checkInputDataTypes(): TypeCheckResult = { - if (children.length < 2) { + if (numColumns < 2) { TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.") } else if (children.head.dataType != LongType) { TypeCheckResult.TypeCheckFailure("The number of rows must be a positive long value.") @@ -256,11 +258,12 @@ case class ReplicateRows(children: Seq[Expression]) extends Generator with Codeg override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val numRows = children.head.eval(input).asInstanceOf[Long] - val values = children.map(_.eval(input)).toArray + val values = children.tail.map(_.eval(input)).toArray Range.Long(0, numRows, 1).map { i => - val fields = new Array[Any](children.length) - for (col <- 0 until children.length) { - fields.update(col, values(col)) + val fields = new Array[Any](numColumns) + fields.update(0, numRows) + for (col <- 1 until numColumns) { + fields.update(col, values(col - 1)) } InternalRow(fields: _*) }