diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 854b5b461bdc8..a0580f1abd2ec 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -308,7 +308,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { baseExpression * ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } | - "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } + "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } | + "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAND(e1,e2) } ) protected lazy val function: Parser[Expression] = @@ -396,7 +397,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", "." + ",", ";", "%", "{", "}", ":", "[", "]", ".", "&" ) override lazy val token: Parser[Token] = ( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1eb260efa6387..986ce1f5178a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType} +import org.apache.spark.sql.catalyst.types.{BitwiseOpsType, DataType, FractionalType, IntegralType, NumericType, NativeType} abstract class Expression extends TreeNode[Expression] { self: Product => @@ -188,6 +188,39 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Evaluation helper function for 2 children expressions which do bitwise operations. Those expressions are + * supposed to be in the same data type, and also the return type. + * Either one of the expressions result is null, the evaluation result should be null. + */ + @inline + protected final def b2( + i: Row, + e1: Expression, + e2: Expression, + f: ((BitwiseOpsType[Any], Any, Any) => Any)): Any = { + if (e1.dataType != e2.dataType) { + throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.eval(i) + if(evalE1 == null) { + null + } else { + val evalE2 = e2.eval(i) + if (evalE2 == null) { + null + } else { + e1.dataType match { + case i: IntegralType => + f.asInstanceOf[(BitwiseOpsType[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( + i.asInstanceOf[BitwiseOpsType[i.JvmType]], evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) + case other => sys.error(s"Type $other does not support bitwise operations") + } + } + } + } + /** * Evaluation helper function for 2 Comparable children expressions. Those expressions are * supposed to be in the same data type, and the return type should be Integer: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fe825fdcdae37..160cd70083c77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -100,6 +100,14 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } +case class BitwiseAND(left: Expression, right: Expression) extends BinaryArithmetic { + def symbol = "&" + + override def eval(input: Row): Any = dataType match { + case _: IntegralType => b2(input, left , right, _.bitwiseAND(_, _)) + } +} + case class MaxOf(left: Expression, right: Expression) extends Expression { type EvaluatedType = Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index ac043d4dd8eb9..ec11047e3f358 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -216,40 +216,49 @@ abstract class IntegralType extends NumericType { private[sql] val integral: Integral[JvmType] } -case object LongType extends IntegralType { +/** Matcher for any expressions that evaluate which do bitwise operations */ +trait BitwiseOpsType[T] { + def bitwiseAND(x: T, y: T): T +} + +case object LongType extends IntegralType with BitwiseOpsType[Long] { private[sql] type JvmType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[JvmType]] def simpleString: String = "long" + def bitwiseAND(x: JvmType, y: JvmType): JvmType = x & y } -case object IntegerType extends IntegralType { +case object IntegerType extends IntegralType with BitwiseOpsType[Int] { private[sql] type JvmType = Int @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[JvmType]] def simpleString: String = "integer" + def bitwiseAND(x: JvmType, y: JvmType): JvmType = x & y } -case object ShortType extends IntegralType { +case object ShortType extends IntegralType with BitwiseOpsType[Short] { private[sql] type JvmType = Short @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[JvmType]] def simpleString: String = "short" + def bitwiseAND(x: JvmType, y: JvmType): JvmType = (x & y).toShort } -case object ByteType extends IntegralType { +case object ByteType extends IntegralType with BitwiseOpsType[Byte] { private[sql] type JvmType = Byte @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[JvmType]] def simpleString: String = "byte" + def bitwiseAND(x: JvmType, y: JvmType): JvmType = (x & y).toByte } /** Matcher for any expressions that evaluate to [[FractionalType]]s */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 6fb6cb8db0c8f..549043fb6727b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -680,9 +680,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } - + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( - sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)} + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1) + } + + test("SPARK-3814 Support Bitwise & operator") { + checkAnswer(sql("SELECT key&1 FROM testData WHERE key = 1 "), 1) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 32c9175f181bb..29bcac19cddfc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -935,6 +935,7 @@ private[hive] object HiveQl { case Token(DIV(), left :: right:: Nil) => Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) + case Token("&", left :: right:: Nil) => BitwiseAND(nodeToExpr(left), nodeToExpr(right)) case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg)) /* Comparisons */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3647bb1c4ce7d..cebfc5278b8a9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -68,5 +68,11 @@ class SQLQuerySuite extends QueryTest { checkAnswer( sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"), sql("SELECT `key` FROM src").collect().toSeq) - } + } + + test("SPARK-3814 Support Bitwise & operator") { + checkAnswer( + sql("SELECT case when 1&1=1 then 1 else 0 end FROM src"), + sql("SELECT 1 FROM src").collect().toSeq) + } }