From 9816402ffddaf9d7c28a5d734ef4411728089dd2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 13 Nov 2020 12:27:38 -0800 Subject: [PATCH] Prevent flaky test. --- .../expressions/SubexpressionEliminationSuite.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 4725a40781c6..0147c6c6a826 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -162,8 +162,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite { assert(equivalence1.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add)) // one-time expressions: only ifExpr and its predicate expression assert(equivalence1.getAllEquivalentExprs.count(_.size == 1) == 2) - assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1)) - assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).last == Seq(condition)) + assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1))) + assert(equivalence1.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(condition))) // Repeated `add` is only in one branch, so we don't count it. val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add)) @@ -179,13 +179,14 @@ class SubexpressionEliminationSuite extends SparkFunSuite { // `add`: 2, `condition`: 2 assert(equivalence3.getAllEquivalentExprs.count(_.size == 2) == 2) - assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).head == Seq(add, add)) - assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).last == Seq(condition, condition)) + assert(equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(add, add))) + assert( + equivalence3.getAllEquivalentExprs.filter(_.size == 2).contains(Seq(condition, condition))) // `ifExpr1`, `ifExpr3` assert(equivalence3.getAllEquivalentExprs.count(_.size == 1) == 2) - assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).head == Seq(ifExpr1)) - assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).last == Seq(ifExpr3)) + assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr1))) + assert(equivalence3.getAllEquivalentExprs.filter(_.size == 1).contains(Seq(ifExpr3))) } test("Children of conditional expressions: CaseWhen") {