Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
arg
}
case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
val newChild1 = f(arg1.asInstanceOf[BaseType])
val newChild2 = f(arg2.asInstanceOf[BaseType])
val newChild1 = if (containsChild(arg1)) {
f(arg1.asInstanceOf[BaseType])
} else {
arg1.asInstanceOf[BaseType]
}

val newChild2 = if (containsChild(arg2)) {
f(arg2.asInstanceOf[BaseType])
} else {
arg2.asInstanceOf[BaseType]
}

if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
changed = true
(newChild1, newChild2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]])
override def output: Seq[Attribute] = Nil
}

case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable {
case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable {
override def children: Seq[Expression] = map.values.toSeq
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}

case class SeqTupleExpression(sons: Seq[(Expression, Expression)],
nonSons: Seq[(Expression, Expression)]) extends Unevaluable {
override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2))
override def nullable: Boolean = true
override def dataType: NullType = NullType
override lazy val resolved = true
}

case class JsonTestTreeNode(arg: Any) extends LeafNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
}
Expand Down Expand Up @@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite {
assert(actual === Dummy(None))
}

test("mapChildren should only works on children") {
val children = Seq((Literal(1), Literal(2)))
val nonChildren = Seq((Literal(3), Literal(4)))
val before = SeqTupleExpression(children, nonChildren)
val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) }
val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren)

val actual = before mapChildren toZero
assert(actual === expect)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe better to use .equals? Although it will call Object's .equals which is ==.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I understand it wrongly. === compared to == could provide more information about the error. You can see the follow examples:

scala> assert(1 == 2)
java.lang.AssertionError: assertion failed
  at scala.Predef$.assert(Predef.scala:156)
  ... 32 elided

scala> assert(1 === 2)
<console>:12: error: value === is not a member of Int
       assert(1 === 2)

And also you can check it in https://stackoverflow.com/questions/10489548/what-is-the-triple-equals-operator-in-scala-koans

}

test("preserves origin") {
CurrentOrigin.setPosition(1, 1)
val add = Add(Literal(1), Literal(1))
Expand Down