From c7db03425bd70d354dcab67c7578c4d6d9b905c7 Mon Sep 17 00:00:00 2001 From: philo Date: Fri, 22 Jul 2022 10:27:36 +0800 Subject: [PATCH] Cast short type to int32 --- .../intel/oap/expression/CodeGeneration.scala | 16 +++++++++++++++- .../intel/oap/expression/ColumnarLiterals.scala | 7 +++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/CodeGeneration.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/CodeGeneration.scala index d68a40e3c..36f42987b 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/CodeGeneration.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/CodeGeneration.scala @@ -32,7 +32,19 @@ object CodeGeneration { def getResultType(left: ArrowType, right: ArrowType): ArrowType = { //TODO(): remove this API - left + // Use left type except that left is int16. If both right & left are int16, + // int32 will be used. + left match { + case intLeft: ArrowType.Int if (intLeft.getBitWidth == 16) => + right match { + case intRight: ArrowType.Int if (intRight.getBitWidth == 16) => + new ArrowType.Int(32, true) + case _ => + right + } + case _ => + left + } } def getResultType(dataType: DataType): ArrowType = { @@ -81,6 +93,8 @@ object CodeGeneration { dataType match { case t: ArrowType.FloatingPoint => s"castFLOAT${4 * dataType.asInstanceOf[ArrowType.FloatingPoint].getPrecision().getFlatbufID()}" + case i: ArrowType.Int if i.getBitWidth == 32 => + "castINT" case _ => throw new UnsupportedOperationException(s"getCastFuncName(${dataType}) is not supported.") } diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarLiterals.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarLiterals.scala index f19b49c94..5e5070823 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarLiterals.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarLiterals.scala @@ -75,6 +75,8 @@ class ColumnarLiteral(lit: Literal) throw new UnsupportedOperationException( s"can't support CalendarIntervalType with microseconds yet") } + case ShortType => + new ArrowType.Int(32, true) case _ => CodeGeneration.getResultType(dataType) } @@ -97,12 +99,13 @@ class ColumnarLiteral(lit: Literal) case _ => (TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType) } - case t: ShortType => + case _: ShortType => value match { case null => (TreeBuilder.makeNull(resultType), resultType) case _ => - (TreeBuilder.makeLiteral(value.asInstanceOf[Integer]), resultType) + (TreeBuilder.makeLiteral(new Integer( + value.asInstanceOf[java.lang.Short].toInt)), resultType) } case t: LongType => value match {