diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 51bb6b0abe40..9580f3930bf4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -212,6 +212,7 @@ object FunctionRegistry {
     expression[Rand]("rand"),
     expression[Randn]("randn"),
     expression[Stack]("stack"),
+    expression[ReplicateRows]("replicate_rows"),
     expression[CaseWhen]("when"),
 
     // math functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index b2817b0538a7..5a50169918a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -62,6 +62,7 @@ object TypeCoercion {
       new ImplicitTypeCasts(conf) ::
       DateTimeOperations ::
       WindowFrameCoercion ::
+      ReplicateRowsCoercion ::
       Nil
 
   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -702,6 +703,23 @@ object TypeCoercion {
     }
   }
 
+  /**
+   * Widens the row-count argument of a [[ReplicateRows]] expression: once its children are
+   * resolved, a first child of IntegerType, ShortType or ByteType is wrapped in a Cast to
+   * LongType. Expressions whose first child has any other type are left untouched.
+   */
+  object ReplicateRowsCoercion extends TypeCoercionRule {
+    private val narrowIntegralTypes = Seq(IntegerType, ShortType, ByteType)
+
+    override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      case replicate: ReplicateRows
+          if replicate.children.nonEmpty && replicate.childrenResolved &&
+            narrowIntegralTypes.contains(replicate.children.head.dataType) =>
+        val countAsLong = Cast(replicate.children.head, LongType)
+        ReplicateRows(countAsLong +: replicate.children.tail)
+    }
+  }
+
   /**
    * Coerces the types of [[Concat]] children to expected ones.
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 3af4bfebad45..7e947b7e0a50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -222,6 +222,63 @@ case class Stack(children: Seq[Expression]) extends Generator {
   }
 }
 
+/**
+ * Replicates the input row N times, where N is the value of the first argument. N itself is
+ * kept as the first output column. A NULL, zero or negative N yields no rows.
+ * {{{
+ *   SELECT replicate_rows(2, "val1", "val2") ->
+ *   2 val1 val2
+ *   2 val1 val2
+ * }}}
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `n`, `expr1`, ..., `exprk` into `n` rows.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(2, "val1", "val2");
+       2 val1 val2
+       2 val1 val2
+  """)
+case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback {
+  // Number of output columns: the row count plus every replicated expression.
+  private lazy val numColumns = children.length
+
+  /**
+   * Requires at least two arguments and a first argument of LongType. Only the type is
+   * checked here; the sign of the count is not validated (non-positive counts emit no rows).
+   */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (numColumns < 2) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
+    } else if (children.head.dataType != LongType) {
+      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive long value.")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  /** One output column per child, named `col0`, `col1`, ... and typed like that child. */
+  override def elementSchema: StructType =
+    StructType(children.zipWithIndex.map {
+      case (e, index) => StructField(s"col$index", e.dataType)
+    })
+
+  /**
+   * Evaluates every child once against `input` and lazily emits one copy of the resulting
+   * row per requested repetition.
+   */
+  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+    // Treat a NULL count explicitly as zero instead of relying on null-to-Long unboxing.
+    val count = children.head.eval(input)
+    val numRows = if (count == null) 0L else count.asInstanceOf[Long]
+    val fields = (numRows +: children.tail.map(_.eval(input))).toArray[Any]
+    // Produce rows on demand so that a large count does not materialize every row up front.
+    Iterator.iterate(0L)(_ + 1L)
+      .takeWhile(_ < numRows)
+      .map(_ => InternalRow(fields.clone(): _*))
+  }
+}
+
 /**
  *
Wrapper around another generator to specify outer behavior. This is used to implement functions * such as explode_outer. This expression gets replaced during analysis. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/replicateRowsCoercion.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/replicateRowsCoercion.sql new file mode 100644 index 000000000000..5cde6a9abcb9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/replicateRowsCoercion.sql @@ -0,0 +1,9 @@ +SELECT replicate_rows(CAST(1 AS BYTE), 1); + +SELECT replicate_rows(CAST(1 AS INT), 1); + +SELECT replicate_rows(CAST(1 AS LONG), 1); + +SELECT replicate_rows(CAST(1 AS SHORT), 1); + +SELECT replicate_rows("abcd", 1); diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql new file mode 100644 index 000000000000..881f2082c190 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf_replicate_rows.sql @@ -0,0 +1,41 @@ +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 'row1', 1.1), + (2, 'row2', 2.2), + (0, 'row3', 3.3), + (-1,'row4', 4.4), + (null,'row5', 5.5), + (3, 'row6', null) + AS tab1(c1, c2, c3); + +-- Requires 2 arguments at minimum. +SELECT replicate_rows() FROM tab1; + +-- Requires 2 arguments at minimum. +SELECT replicate_rows(c1) FROM tab1; + +-- First argument should be a numeric type. +SELECT replicate_rows("abcd", c2) FROM tab1; + +-- untyped null first argument +SELECT replicate_rows(null, c2) FROM tab1; + +-- c1, c2 replicated c1 times +SELECT replicate_rows(c1, c2) FROM tab1; + +-- c1, c2, c2 replicated c1 times +SELECT replicate_rows(c1, c2, c2) FROM tab1; + +-- c1, c2, c2, c2, c3 replicated c1 times +SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1; + +-- Used as a derived table in FROM clause. 
+SELECT c2, c1 +FROM ( + SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1 +); + +-- column expression. +SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1; + +-- Clean-up +DROP VIEW IF EXISTS tab1; diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/replicateRowsCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/replicateRowsCoercion.sql.out new file mode 100644 index 000000000000..024aee350c1c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/replicateRowsCoercion.sql.out @@ -0,0 +1,43 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +SELECT replicate_rows(CAST(1 AS BYTE), 1) +-- !query 0 schema +struct<col0:bigint,col1:int> +-- !query 0 output +1 1 + + +-- !query 1 +SELECT replicate_rows(CAST(1 AS INT), 1) +-- !query 1 schema +struct<col0:bigint,col1:int> +-- !query 1 output +1 1 + + +-- !query 2 +SELECT replicate_rows(CAST(1 AS LONG), 1) +-- !query 2 schema +struct<col0:bigint,col1:int> +-- !query 2 output +1 1 + + +-- !query 3 +SELECT replicate_rows(CAST(1 AS SHORT), 1) +-- !query 3 schema +struct<col0:bigint,col1:int> +-- !query 3 output +1 1 + + +-- !query 4 +SELECT replicate_rows("abcd", 1) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows('abcd', 1)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out new file mode 100644 index 000000000000..9fae15f3741f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udtf_replicate_rows.sql.out @@ -0,0 +1,129 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES + (1, 'row1', 1.1), + (2, 'row2', 2.2), + (0, 'row3', 3.3), + (-1,'row4', 4.4), + (null,'row5', 5.5), + (3, 'row6', null) 
+ AS tab1(c1, c2, c3) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT replicate_rows() FROM tab1 +-- !query 1 schema +struct<> +-- !query 1 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows()' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7 + + +-- !query 2 +SELECT replicate_rows(c1) FROM tab1 +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows(CAST(tab1.`c1` AS BIGINT))' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7 + + +-- !query 3 +SELECT replicate_rows("abcd", c2) FROM tab1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows('abcd', tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 + + +-- !query 4 +SELECT replicate_rows(null, c2) FROM tab1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +cannot resolve 'replicaterows(NULL, tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7 + + +-- !query 5 +SELECT replicate_rows(c1, c2) FROM tab1 +-- !query 5 schema +struct<col0:bigint,col1:string> +-- !query 5 output +1 row1 +2 row2 +2 row2 +3 row6 +3 row6 +3 row6 + + +-- !query 6 +SELECT replicate_rows(c1, c2, c2) FROM tab1 +-- !query 6 schema +struct<col0:bigint,col1:string,col2:string> +-- !query 6 output +1 row1 row1 +2 row2 row2 +2 row2 row2 +3 row6 row6 +3 row6 row6 +3 row6 row6 + + +-- !query 7 +SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1 +-- !query 7 schema +struct<col0:bigint,col1:string,col2:string,col3:string,col4:decimal(2,1)> +-- !query 7 output +1 row1 row1 row1 1.1 +2 row2 row2 row2 2.2 +2 row2 row2 row2 2.2 +3 row6 row6 row6 NULL +3 row6 row6 row6 NULL +3 row6 row6 row6 NULL + + +-- !query 8 +SELECT c2, c1 +FROM ( + SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1 +) +-- !query 8 schema +struct<c2:string,c1:bigint> +-- !query 8 output +row1 1 +row2 2 +row2 2 +row6 3 +row6 3 +row6 3 + 
+ +-- !query 9 +SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1 +-- !query 9 schema +struct<col0:bigint,col1:string,col2:string> +-- !query 9 output +1 row1... row1 +2 row2... row2 +2 row2... row2 +3 row6... row6 +3 row6... row6 +3 row6... row6 + + +-- !query 10 +DROP VIEW IF EXISTS tab1 +-- !query 10 schema +struct<> +-- !query 10 output +