@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
@@ -33,13 +33,24 @@ case class CreateArray(children: Seq[Expression]) extends Expression {

   override def foldable: Boolean = children.forall(_.foldable)

-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.map(_.dataType).forall(_.isInstanceOf[DecimalType])) {
+      TypeCheckResult.TypeCheckSuccess
Contributor

I think we cannot just make the check pass. We need to actually cast those elements to the same precision and scale.
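
For reference, a minimal, self-contained sketch of the widening rule being asked for here. The `Dec` case class and `widen` helper are illustrative stand-ins, not Catalyst's actual `DecimalType` handling (Spark applies the equivalent rule in its decimal-precision coercion):

```scala
// Illustrative sketch of decimal widening: keep enough integer digits and
// enough fractional digits to represent values of either input type exactly.
case class Dec(precision: Int, scale: Int)

def widen(a: Dec, b: Dec): Dec = {
  val scale = math.max(a.scale, b.scale)
  val integerDigits = math.max(a.precision - a.scale, b.precision - b.scale)
  Dec(integerDigits + scale, scale)
}

// widen(Dec(3, 3), Dec(2, 2)) == Dec(3, 3)
// widen(Dec(3, 1), Dec(2, 2)) == Dec(4, 2) -- a type that matches neither
// input, so each element must be cast to it explicitly.
```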

Contributor

For example, if we access a single element, its actual data type may not be the one shown as the array's element type.

Member Author

Thank you for the review, @yhuai.
I see. I'll check that further.

Member Author (@dongjoon-hyun, Jul 26, 2016)

Hi, @yhuai. I checked the following.

scala> sql("select a[0], a[1] from (select array(0.001, 0.02) a) T")
res4: org.apache.spark.sql.DataFrame = [a[0]: decimal(3,3), a[1]: decimal(3,3)]

scala> sql("select a[0], a[1] from (select array(0.001, 0.02) a) T").show()
+-----+-----+
| a[0]| a[1]|
+-----+-----+
|0.001|0.020|
+-----+-----+

scala> sql("select a[0], a[1] from (select array(0.001, 0.02) a) T").explain(true)
== Parsed Logical Plan ==
'Project [unresolvedalias('a[0], None), unresolvedalias('a[1], None)]
+- 'SubqueryAlias T
   +- 'Project ['array(0.001, 0.02) AS a#54]
      +- OneRowRelation$

== Analyzed Logical Plan ==
a[0]: decimal(3,3), a[1]: decimal(3,3)
Project [a#54[0] AS a[0]#61, a#54[1] AS a[1]#62]
+- SubqueryAlias T
   +- Project [array(0.001, 0.02) AS a#54]
      +- OneRowRelation$
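
As a sketch of why both columns resolve to decimal(3,3) here (assuming the standard widening rule; the variable names are illustrative):

```scala
// 0.001 is decimal(3,3) and 0.02 is decimal(2,2).
val (p1, s1) = (3, 3)
val (p2, s2) = (2, 2)
val widenedScale = math.max(s1, s2)                              // 3
val widenedPrecision = math.max(p1 - s1, p2 - s2) + widenedScale // 0 + 3 = 3
// => decimal(3,3); that is why 0.02 is displayed as 0.020 above.
```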

Member Author (@dongjoon-hyun, Jul 26, 2016)

scala> sql("create table d1(a DECIMAL(3,2))")
scala> sql("create table d2(a DECIMAL(2,1))")
scala> sql("insert into d1 values(1.0)")
scala> sql("insert into d2 values(1.0)")
scala> sql("select * from d1, d2").show()
+----+---+
|   a|  a|
+----+---+
|1.00|1.0|
+----+---+

scala> sql("select array(d1.a,d2.a),array(d2.a,d1.a),* from d1, d2")
res5: org.apache.spark.sql.DataFrame = [array(a, a): array<decimal(3,2)>, array(a, a): array<decimal(3,2)> ... 2 more fields]

scala> sql("select array(d1.a,d2.a),array(d2.a,d1.a),* from d1, d2").show()
+------------+------------+----+---+
| array(a, a)| array(a, a)|   a|  a|
+------------+------------+----+---+
|[1.00, 1.00]|[1.00, 1.00]|1.00|1.0|
+------------+------------+----+---+

scala> sql("select array(d1.a,d2.a)[0],array(d2.a,d1.a)[0],* from d1, d2").show()
+--------------+--------------+----+---+
|array(a, a)[0]|array(a, a)[0]|   a|  a|
+--------------+--------------+----+---+
|          1.00|          1.00|1.00|1.0|
+--------------+--------------+----+---+

scala> sql("select array(d1.a,d2.a)[1],array(d2.a,d1.a)[1],* from d1, d2").show()
+--------------+--------------+----+---+
|array(a, a)[1]|array(a, a)[1]|   a|  a|
+--------------+--------------+----+---+
|          1.00|          1.00|1.00|1.0|
+--------------+--------------+----+---+
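
The same arithmetic explains the array<decimal(3,2)> element type above, regardless of argument order (again an illustrative calculation, not the analyzer's code):

```scala
// d1.a is decimal(3,2) and d2.a is decimal(2,1).
val widenedScale = math.max(2, 1)                            // 2
val widenedPrecision = math.max(3 - 2, 2 - 1) + widenedScale // 1 + 2 = 3
// => decimal(3,2), so 1.0 from d2 is padded to 1.00 inside the array.
```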

Member Author

And finally, the following is the codegen result. Please see line 29.

scala> sql("explain codegen select array(0.001, 0.02)[1]").collect().foreach(println)
[Found 1 WholeStageCodegen subtrees.
== Subtree 1 / 1 ==
*Project [0.02 AS array(0.001, 0.02)[1]#75]
+- Scan OneRowRelation[]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator inputadapter_input;
/* 008 */   private UnsafeRow project_result;
/* 009 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
/* 011 */
/* 012 */   public GeneratedIterator(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator inputs[]) {
/* 017 */     partitionIndex = index;
/* 018 */     inputadapter_input = inputs[0];
/* 019 */     project_result = new UnsafeRow(1);
/* 020 */     this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 0);
/* 021 */     this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 1);
/* 022 */   }
/* 023 */
/* 024 */   protected void processNext() throws java.io.IOException {
/* 025 */     while (inputadapter_input.hasNext()) {
/* 026 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 027 */       Object project_obj = ((Expression) references[0]).eval(null);
/* 028 */       Decimal project_value = (Decimal) project_obj;
/* 029 */       project_rowWriter.write(0, project_value, 3, 3);
/* 030 */       append(project_result);
/* 031 */       if (shouldStop()) return;
/* 032 */     }
/* 033 */   }

Member Author

In short, the element types are resolved correctly in the Analyzed Logical Plan. As a result, the generated code writes them with the unified precision and scale.

== Analyzed Logical Plan ==
a[0]: decimal(3,3), a[1]: decimal(3,3)

Member Author

Is there anything more to check?

Member Author

Hi, @yhuai.
Could you give me some advice?

+    } else {
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
+    }
+  }

   override def dataType: DataType = {
-    ArrayType(
-      children.headOption.map(_.dataType).getOrElse(NullType),
-      containsNull = children.exists(_.nullable))
+    var elementType: DataType = children.headOption.map(_.dataType).getOrElse(NullType)
+    if (elementType.isInstanceOf[DecimalType]) {
+      children.foreach { child =>
+        if (elementType.asInstanceOf[DecimalType].isTighterThan(child.dataType)) {
Contributor

I think this suffers from the same issue as the map PR.

Member Author

Thank you, @rxin.
Yep, I've read your comment about the precision loss.
I'll check that and revise.

+          elementType = child.dataType
+        }
+      }
+    }
+    ArrayType(elementType, containsNull = children.exists(_.nullable))
   }

   override def nullable: Boolean = false
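To make the concern about isTighterThan concrete: scanning the children for the least tight existing type only works when one child's type can hold all the others. A hedged, self-contained sketch of a case where it cannot (the values and types here are hypothetical):

```scala
// decimal(3,1) has two integer digits; decimal(2,2) has two fractional digits.
// Neither can represent both 12.5 and 0.25 exactly:
//   decimal(3,1) would round 0.25, and decimal(2,2) cannot hold 12.5 at all.
// The widened type, decimal(4,2), is not among the children, so a scan with
// isTighterThan can never find it.
val widenedScale = math.max(1, 2)                            // 2
val widenedPrecision = math.max(3 - 1, 2 - 2) + widenedScale // 2 + 2 = 4
// => decimal(4,2) holds both 12.50 and 0.25 exactly.
```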
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -134,6 +134,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }

test("SPARK-16714: CreateArray with Decimals") {
val array1 = CreateArray(Seq(Literal(Decimal(0.001)), Literal(Decimal(0.02))))
val array2 = CreateArray(Seq(Literal(Decimal(0.02)), Literal(Decimal(0.001))))

assert(array1.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)
assert(array2.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)
assert(array1.dataType == array2.dataType)

checkEvaluation(array1, Seq(Decimal(0.001), Decimal(0.02)))
checkEvaluation(array2, Seq(Decimal(0.02), Decimal(0.001)))
}

test("CreateMap") {
def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
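
A hedged sketch of a follow-up test for the mixed precision/scale case discussed in the review. It encodes the widening behavior the reviewers are asking for, so it would not pass against the diff as written; `DecimalType(4, 2)` is the widened type of decimal(3,1) and decimal(2,2):

```scala
test("CreateArray with Decimals of different precision and scale") {
  val d1 = Literal.create(Decimal(BigDecimal("12.5"), 3, 1), DecimalType(3, 1))
  val d2 = Literal.create(Decimal(BigDecimal("0.25"), 2, 2), DecimalType(2, 2))
  val array = CreateArray(Seq(d1, d2))

  assert(array.checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess)
  // Under the widening rule, the common element type should be decimal(4,2).
  assert(array.dataType == ArrayType(DecimalType(4, 2), containsNull = false))
}
```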