Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {

delimiters += (
"@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
",", ";", "%", "{", "}", ":", "[", "]", "."
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~"
)

override lazy val token: Parser[Token] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class SqlParser extends AbstractSparkSQLParser {
(LIMIT ~> expression).? ^^ {
case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
val base = r.getOrElse(NoRelation)
val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
val withFilter = f.map(Filter(_, base)).getOrElse(base)
val withProjection = g
.map(Aggregate(_, assignAliases(p), withFilter))
.getOrElse(Project(assignAliases(p), withFilter))
Expand Down Expand Up @@ -260,6 +260,9 @@ class SqlParser extends AbstractSparkSQLParser {
( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) }
| "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) }
| "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) }
| "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) }
| "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) }
| "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) }
)

protected lazy val function: Parser[Expression] =
Expand Down Expand Up @@ -303,33 +306,74 @@ class SqlParser extends AbstractSparkSQLParser {
CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) }

protected lazy val literal: Parser[Literal] =
( numericLit ^^ {
case i if i.toLong > Int.MaxValue => Literal(i.toLong)
case i => Literal(i.toInt)
}
| NULL ^^^ Literal(null, NullType)
| floatLit ^^ {case f => Literal(f.toDouble) }
( numericLiteral
| booleanLiteral
| stringLit ^^ {case s => Literal(s, StringType) }
| NULL ^^^ Literal(null, NullType)
)

protected lazy val booleanLiteral: Parser[Literal] =
( TRUE ^^^ Literal(true, BooleanType)
| FALSE ^^^ Literal(false, BooleanType)
)

protected lazy val numericLiteral: Parser[Literal] =
signedNumericLiteral | unsignedNumericLiteral

protected lazy val sign: Parser[String] =
"+" | "-"

protected lazy val signedNumericLiteral: Parser[Literal] =
( sign ~ numericLit ^^ { case s ~ l => Literal(toNarrowestIntegerType(s + l)) }
| sign ~ floatLit ^^ { case s ~ f => Literal((s + f).toDouble) }
)

protected lazy val unsignedNumericLiteral: Parser[Literal] =
( numericLit ^^ { n => Literal(toNarrowestIntegerType(n)) }
| floatLit ^^ { f => Literal(f.toDouble) }
)

private val longMax = BigDecimal(s"${Long.MaxValue}")
private val longMin = BigDecimal(s"${Long.MinValue}")
private val intMax = BigDecimal(s"${Int.MaxValue}")
private val intMin = BigDecimal(s"${Int.MinValue}")

private def toNarrowestIntegerType(value: String) = {
val bigIntValue = BigDecimal(value)

bigIntValue match {
case v if v < longMin || v > longMax => v
case v if v < intMin || v > intMax => v.toLong
case v => v.toInt
}
}

protected lazy val floatLit: Parser[String] =
elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
( "." ~> unsignedNumericLiteral ^^ { u => "0." + u }
| elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
)

protected lazy val baseExpression: Parser[Expression] =
( "*" ^^^ Star(None)
| primary
)

protected lazy val baseExpression: PackratParser[Expression] =
( expression ~ ("[" ~> expression <~ "]") ^^
protected lazy val signedPrimary: Parser[Expression] =
sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e}

protected lazy val primary: PackratParser[Expression] =
( literal
| expression ~ ("[" ~> expression <~ "]") ^^
{ case base ~ ordinal => GetItem(base, ordinal) }
| (expression <~ ".") ~ ident ^^
{ case base ~ fieldName => GetField(base, fieldName) }
| TRUE ^^^ Literal(true, BooleanType)
| FALSE ^^^ Literal(false, BooleanType)
| cast
| "(" ~> expression <~ ")"
| function
| "-" ~> literal ^^ UnaryMinus
| dotExpressionHeader
| ident ^^ UnresolvedAttribute
| "*" ^^^ Star(None)
| literal
| signedPrimary
| "~" ~> expression ^^ BitwiseNot
)

protected lazy val dotExpressionHeader: Parser[Expression] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,16 @@ package object dsl {

def unary_- = UnaryMinus(expr)
def unary_! = Not(expr)
def unary_~ = BitwiseNot(expr)

def + (other: Expression) = Add(expr, other)
def - (other: Expression) = Subtract(expr, other)
def * (other: Expression) = Multiply(expr, other)
def / (other: Expression) = Divide(expr, other)
def % (other: Expression) = Remainder(expr, other)
def & (other: Expression) = BitwiseAnd(expr, other)
def | (other: Expression) = BitwiseOr(expr, other)
def ^ (other: Expression) = BitwiseXor(expr, other)

def && (other: Expression) = And(expr, other)
def || (other: Expression) = Or(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
left.dataType
}

override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
if(evalE1 == null) {
null
} else {
val evalE2 = right.eval(input)
if (evalE2 == null) {
null
} else {
evalInternal(evalE1, evalE2)
}
}
}

def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any =
sys.error(s"BinaryExpressions must either override eval or evalInternal")
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
Expand Down Expand Up @@ -100,6 +117,78 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}

/**
* A function that calculates bitwise and(&) of two numbers.
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "&"

override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match {
case ByteType => (evalE1.asInstanceOf[Byte] & evalE2.asInstanceOf[Byte]).toByte
case ShortType => (evalE1.asInstanceOf[Short] & evalE2.asInstanceOf[Short]).toShort
case IntegerType => evalE1.asInstanceOf[Int] & evalE2.asInstanceOf[Int]
case LongType => evalE1.asInstanceOf[Long] & evalE2.asInstanceOf[Long]
case other => sys.error(s"Unsupported bitwise & operation on ${other}")
}
}

/**
* A function that calculates bitwise or(|) of two numbers.
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "&"

override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match {
case ByteType => (evalE1.asInstanceOf[Byte] | evalE2.asInstanceOf[Byte]).toByte
case ShortType => (evalE1.asInstanceOf[Short] | evalE2.asInstanceOf[Short]).toShort
case IntegerType => evalE1.asInstanceOf[Int] | evalE2.asInstanceOf[Int]
case LongType => evalE1.asInstanceOf[Long] | evalE2.asInstanceOf[Long]
case other => sys.error(s"Unsupported bitwise | operation on ${other}")
}
}

/**
* A function that calculates bitwise xor(^) of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "^"

override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match {
case ByteType => (evalE1.asInstanceOf[Byte] ^ evalE2.asInstanceOf[Byte]).toByte
case ShortType => (evalE1.asInstanceOf[Short] ^ evalE2.asInstanceOf[Short]).toShort
case IntegerType => evalE1.asInstanceOf[Int] ^ evalE2.asInstanceOf[Int]
case LongType => evalE1.asInstanceOf[Long] ^ evalE2.asInstanceOf[Long]
case other => sys.error(s"Unsupported bitwise ^ operation on ${other}")
}
}

/**
* A function that calculates bitwise not(~) of a number.
*/
case class BitwiseNot(child: Expression) extends UnaryExpression {
type EvaluatedType = Any

def dataType = child.dataType
override def foldable = child.foldable
def nullable = child.nullable
override def toString = s"-$child"

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
dataType match {
case ByteType => (~(evalE.asInstanceOf[Byte])).toByte
case ShortType => (~(evalE.asInstanceOf[Short])).toShort
case IntegerType => ~(evalE.asInstanceOf[Int])
case LongType => ~(evalE.asInstanceOf[Long])
case other => sys.error(s"Unsupported bitwise ~ operation on ${other}")
}
}
}
}

case class MaxOf(left: Expression, right: Expression) extends Expression {
type EvaluatedType = Any

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,4 +674,36 @@ class ExpressionEvaluationSuite extends FunSuite {

checkEvaluation(Sqrt(Literal(null, DoubleType)), null, new GenericRow(Array[Any](null)))
}

test("Bitwise operations") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(BitwiseAnd(c1, c4), null, row)
checkEvaluation(BitwiseAnd(c1, c2), 0, row)
checkEvaluation(BitwiseAnd(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(BitwiseAnd(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)

checkEvaluation(BitwiseOr(c1, c4), null, row)
checkEvaluation(BitwiseOr(c1, c2), 3, row)
checkEvaluation(BitwiseOr(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(BitwiseOr(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)

checkEvaluation(BitwiseXor(c1, c4), null, row)
checkEvaluation(BitwiseXor(c1, c2), 3, row)
checkEvaluation(BitwiseXor(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(BitwiseXor(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)

checkEvaluation(BitwiseNot(c4), null, row)
checkEvaluation(BitwiseNot(c1), -2, row)
checkEvaluation(BitwiseNot(Literal(null, IntegerType)), null, row)

checkEvaluation(c1 & c2, 0, row)
checkEvaluation(c1 | c2, 3, row)
checkEvaluation(c1 ^ c2, 3, row)
checkEvaluation(~c1, -2, row)
}
}
Loading