diff --git a/build.sbt b/build.sbt
index 915985f31e..c45efa488a 100644
--- a/build.sbt
+++ b/build.sbt
@@ -26,9 +26,9 @@ val coreDependencies = Seq(
   "org.apache.spark" %% "spark-mllib" % sparkVersion % "compile",
   "org.apache.spark" %% "spark-avro" % sparkVersion % "provided",
   "org.apache.spark" %% "spark-tags" % sparkVersion % "test",
-  "org.scalatest" %% "scalatest" % "3.0.5" % "test")
+  "org.scalatest" %% "scalatest" % "3.2.14" % "test")
 val extraDependencies = Seq(
-  "org.scalactic" %% "scalactic" % "3.0.5",
+  "org.scalactic" %% "scalactic" % "3.2.14",
   "io.spray" %% "spray-json" % "1.3.5",
   "com.jcraft" % "jsch" % "0.1.54",
   "org.apache.httpcomponents.client5" % "httpclient5" % "5.1.3",
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TranslatorSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TranslatorSuite.scala
index ee1abeb02c..59f78a2085 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TranslatorSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/cognitive/split1/TranslatorSuite.scala
@@ -9,7 +9,12 @@ import com.microsoft.azure.synapse.ml.core.test.base.{Flaky, TestBase}
 import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
 import org.apache.spark.ml.util.MLReadable
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions.{col, flatten}
+import org.apache.spark.sql.functions.{col, flatten, udf}
+import org.scalactic.Equality
+import scala.collection.Traversable
+
+import java.sql.Struct
+import scala.collection.mutable
 
 trait TranslatorKey {
   lazy val translatorKey: String = sys.env.getOrElse("TRANSLATOR_KEY", Secrets.TranslatorKey)
@@ -188,8 +193,9 @@ class TransliterateSuite extends TransformerFuzzing[Transliterate]
       .withColumn("text", col("result.text"))
       .withColumn("script", col("result.script"))
       .select("text", "script").collect()
-    assert(results.head.getSeq(0).mkString("\n") === "Kon'nichiwa\nsayonara")
-    assert(results.head.getSeq(1).mkString("\n") === "Latn\nLatn")
+
+    assert(TransliterateSuite.stripInvalid(results.head.getSeq(0).mkString("\n")) === "Kon'nichiwa\nsayonara")
+    assert(TransliterateSuite.stripInvalid(results.head.getSeq(1).mkString("\n")) === "Latn\nLatn")
   }
 
   test("Throw errors if required fields not set") {
@@ -206,12 +212,30 @@ class TransliterateSuite extends TransformerFuzzing[Transliterate]
     assert(caught.getMessage.contains("toScript"))
   }
 
+  val stripUdf = udf {
+    (o: Seq[(String, String)]) => {
+      o.map(t => (TransliterateSuite.stripInvalid(t._1), t._2))
+    }
+  }
+  override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
+    val column = "result"
+    super.assertDFEq(
+      df1.withColumn(column, stripUdf(col(column))),
+      df2.withColumn(column, stripUdf(col(column))))(eq)
+  }
+
   override def testObjects(): Seq[TestObject[Transliterate]] =
     Seq(new TestObject(transliterate, transDf))
 
   override def reader: MLReadable[_] = Transliterate
 }
 
+object TransliterateSuite {
+  private def stripInvalid(str: String): String = {
+    "[^\n'A-Za-z]".r.replaceAllIn(str, "")
+  }
+}
+
 class DetectSuite extends TransformerFuzzing[Detect]
   with TranslatorKey with Flaky with TranslatorUtils {
 
diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/utils/SlicerFunctionsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/utils/SlicerFunctionsSuite.scala
index 21bf935f2d..31920ea9c6 100644
--- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/utils/SlicerFunctionsSuite.scala
+++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/core/utils/utils/SlicerFunctionsSuite.scala
@@ -9,7 +9,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types.{FloatType, IntegerType}
-import org.scalatest.Matchers.{a, thrownBy}
+import org.scalatest.matchers.should.Matchers.{a, thrownBy}
 
 class SlicerFunctionsSuite extends TestBase {
   test("SlicerFunctions UDFs can handle different types of inputs") {
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/test/base/TestBase.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/test/base/TestBase.scala
index 3d8c4d0c9c..2d57508088 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/core/test/base/TestBase.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/core/test/base/TestBase.scala
@@ -14,6 +14,7 @@ import org.apache.spark.streaming.{StreamingContext, Seconds => SparkSeconds}
 import org.scalactic.Equality
 import org.scalactic.source.Position
 import org.scalatest._
+import org.scalatest.funsuite.AnyFunSuite
 import org.scalatest.concurrent.TimeLimits
 import org.scalatest.time.{Seconds, Span}
 
@@ -142,7 +143,7 @@ object TestBase extends SparkSessionManagement {
 
 }
 
-abstract class TestBase extends FunSuite with BeforeAndAfterEachTestData with BeforeAndAfterAll {
+abstract class TestBase extends AnyFunSuite with BeforeAndAfterEachTestData with BeforeAndAfterAll {
 
   lazy val sparkProvider: SparkSessionManagement = TestBase
 
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/KernelSHAPSamplerSupportSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/KernelSHAPSamplerSupportSuite.scala
index 0d8fabe2fb..2ee4850627 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/KernelSHAPSamplerSupportSuite.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/KernelSHAPSamplerSupportSuite.scala
@@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.explainers.split1
 import breeze.linalg.sum
 import com.microsoft.azure.synapse.ml.core.test.base.TestBase
 import com.microsoft.azure.synapse.ml.explainers.KernelSHAPSamplerSupport
-import org.scalatest.Matchers._
+import org.scalatest.matchers.should.Matchers._
 
 class KernelSHAPSamplerSupportSuite extends TestBase {
 
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/SamplerSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/SamplerSuite.scala
index 77eaf99be9..5b245f6cb9 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/SamplerSuite.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/SamplerSuite.scala
@@ -16,7 +16,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types._
 import org.scalactic.{Equality, TolerantNumerics}
-import org.scalatest.Matchers._
+import org.scalatest.matchers.should.Matchers._
 
 import java.nio.file.{Files, Paths}
 import javax.imageio.ImageIO
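For reference, a minimal sketch of the ScalaTest 3.2 style this diff migrates to: org.scalatest.FunSuite becomes org.scalatest.funsuite.AnyFunSuite and org.scalatest.Matchers becomes org.scalatest.matchers.should.Matchers. The suite name and assertions below are illustrative only and are not part of the PR.

// Illustrative only: exercises the post-upgrade ScalaTest 3.2 import locations used in this diff.
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers._

class ExampleMigratedSuite extends AnyFunSuite {
  test("should-style matchers compile against the new import paths") {
    // Same character-stripping idea as TransliterateSuite.stripInvalid:
    // drop everything except newlines, apostrophes, and ASCII letters.
    val stripped = "[^\n'A-Za-z]".r.replaceAllIn("Kon'nichiwa!", "")
    stripped should be ("Kon'nichiwa")

    // The a/thrownBy matchers now come from matchers.should.Matchers.
    a [NumberFormatException] should be thrownBy "not a number".toInt
  }
}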