Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revive "Strict Equality" for assertEquals() #521

Merged
merged 14 commits into from
Apr 16, 2022
Merged
55 changes: 39 additions & 16 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ lazy val mimaEnable: List[Def.Setting[_]] = List(
"munit.internal.junitinterface.JUnitComputer.this"
),
// Known breaking changes for MUnit v1
ProblemFilters.exclude[DirectMissingMethodProblem](
"munit.Assertions.assertNotEquals"
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"munit.Assertions.assertEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.Assertions.assertNotEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.Assertions.assertEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.FunSuite.assertNotEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.FunSuite.assertEquals"
),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"munit.FunSuite.munitTestTransform"
),
Expand Down Expand Up @@ -194,22 +212,8 @@ lazy val junit = project
lazy val munit = crossProject(JSPlatform, JVMPlatform, NativePlatform)
.settings(
sharedSettings,
Compile / unmanagedSourceDirectories ++= {
val root = (ThisBuild / baseDirectory).value / "munit"
val base = root / "shared" / "src" / "main"
val result = mutable.ListBuffer.empty[File]
val partialVersion = CrossVersion.partialVersion(scalaVersion.value)
if (isPreScala213(partialVersion)) {
result += base / "scala-pre-2.13"
}
if (isNotScala211(partialVersion)) {
result += base / "scala-post-2.11"
}
if (isScala2(partialVersion)) {
result += base / "scala-2"
}
result.toList
},
Compile / unmanagedSourceDirectories ++=
crossBuildingDirectories("munit", "main").value,
libraryDependencies ++= List(
"org.scala-lang" % "scala-reflect" % {
if (isScala3Setting.value) scala213
Expand Down Expand Up @@ -308,6 +312,8 @@ lazy val tests = crossProject(JSPlatform, JVMPlatform, NativePlatform)
((ThisBuild / baseDirectory).value / "tests" / "shared" / "src" / "main").getAbsolutePath.toString,
scalaVersion
),
Test / unmanagedSourceDirectories ++=
crossBuildingDirectories("tests", "test").value,
publish / skip := true
)
.nativeConfigure(sharedNativeConfigure)
Expand Down Expand Up @@ -348,3 +354,20 @@ lazy val docs = project
Global / excludeLintKeys ++= Set(
mimaPreviousArtifacts
)
def crossBuildingDirectories(name: String, config: String) =
Def.setting[Seq[File]] {
val root = (ThisBuild / baseDirectory).value / name
val base = root / "shared" / "src" / config
val result = mutable.ListBuffer.empty[File]
val partialVersion = CrossVersion.partialVersion(scalaVersion.value)
if (isPreScala213(partialVersion)) {
result += base / "scala-pre-2.13"
}
if (isNotScala211(partialVersion)) {
result += base / "scala-post-2.11"
}
if (isScala2(partialVersion)) {
result += base / "scala-2"
}
result.toList
}
24 changes: 15 additions & 9 deletions docs/assertions.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,23 +101,29 @@ assertEquals(
Comparing two values of different types is a compile error.

```scala mdoc:fail
assertEquals(1, "")
assertEquals(Option("message"), "message")
```

The "expected" value (second argument) must be a subtype of the "obtained" value
(first argument).
It's a compile error even if the comparison is true at runtime.

```scala mdoc
assertEquals(Option(1), Some(1))
```scala mdoc:fail
assertEquals(List(1), Vector(1))
```

It's a compile error if you swap the order of the arguments.

```scala mdoc:fail
assertEquals(Some(1), Option(1))
assertEquals('a', 'a'.toInt)
```

It's OK to compare two types as long as one argument is a subtype of the other
type.

```scala mdoc
assertEquals(Option(1), Some(1)) // OK
assertEquals(Some(1), Option(1)) // OK
```

Use `assertEquals[Any, Any]` if you really want to compare two different types.
Use `assertEquals[Any, Any]` if you think it's OK to compare the two types at
runtime.

```scala mdoc
val right1: Either[String , Int] = Right(42)
Expand Down
79 changes: 15 additions & 64 deletions munit/shared/src/main/scala/munit/Assertions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,6 @@ trait Assertions extends MacroCompat.CompileErrorMacro {

def munitAnsiColors: Boolean = true

private def munitComparisonHandler(
actualObtained: Any,
actualExpected: Any
): ComparisonFailExceptionHandler =
new ComparisonFailExceptionHandler {
override def handle(
message: String,
unusedObtained: String,
unusedExpected: String,
loc: Location
): Nothing = failComparison(message, actualObtained, actualExpected)(loc)
}

private def munitFilterAnsi(message: String): String =
if (munitAnsiColors) message
else AnsiColors.filterAnsi(message)
Expand Down Expand Up @@ -67,20 +54,25 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
Diffs.assertNoDiff(
obtained,
expected,
munitComparisonHandler(obtained, expected),
ComparisonFailExceptionHandler.fromAssertions(this, Clues.empty),
munitPrint(clue),
printObtainedAsStripMargin = true
)
}
}

/**
* Asserts that two elements are not equal according to the `Compare[A, B]` type-class.
*
* By default, uses `==` to compare values.
*/
def assertNotEquals[A, B](
obtained: A,
expected: B,
clue: => Any = "values are the same"
)(implicit loc: Location, ev: A =:= B): Unit = {
)(implicit loc: Location, compare: Compare[A, B]): Unit = {
StackTraces.dropInside {
if (obtained == expected) {
if (compare.isEqual(obtained, expected)) {
failComparison(
s"${munitPrint(clue)} expected same: $expected was not: $obtained",
obtained,
Expand All @@ -91,32 +83,17 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
}

/**
* Asserts that two elements are equal using `==` equality.
*
* The "expected" value (second argument) must have the same type or be a
* subtype of the "obtained" value (first argument). For example:
* {{{
* assertEquals(Option(1), Some(1)) // OK
* assertEquals(Some(1), Option(1)) // Error: Option[Int] is not a subtype of Some[Int]
* }}}
* Asserts that two elements are equal according to the `Compare[A, B]` type-class.
*
* Use `assertEquals[Any, Any](a, b)` as an escape hatch to compare two
* values of different types. For example:
* {{{
* val a: Either[List[String], Int] = Right(42)
* val b: Either[String, Int] = Right(42)
* assertEquals[Any, Any](a, b) // OK
* assertEquals(a, b) // Error: Either[String, Int] is not a subtype of Either[List[String], Int]
* }}}
* By default, uses `==` to compare values.
*/
def assertEquals[A, B](
obtained: A,
expected: B,
clue: => Any = "values are not the same"
)(implicit loc: Location, ev: B <:< A): Unit = {
)(implicit loc: Location, compare: Compare[A, B]): Unit = {
StackTraces.dropInside {
if (obtained != expected) {

if (!compare.isEqual(obtained, expected)) {
(obtained, expected) match {
case (a: Array[_], b: Array[_]) if a.sameElements(b) =>
// Special-case error message when comparing arrays. See
Expand All @@ -137,34 +114,7 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
)
case _ =>
}

Diffs.assertNoDiff(
munitPrint(obtained),
munitPrint(expected),
munitComparisonHandler(obtained, expected),
munitPrint(clue),
printObtainedAsStripMargin = false
)
// try with `.toString` in case `munitPrint()` produces identical formatting for both values.
Diffs.assertNoDiff(
obtained.toString(),
expected.toString(),
munitComparisonHandler(obtained, expected),
munitPrint(clue),
printObtainedAsStripMargin = false
)
if (obtained.toString() == expected.toString())
failComparison(
s"values are not equal even if they have the same `toString()`: $obtained",
obtained,
expected
)
else
failComparison(
s"values are not equal, even if their text representation only differs in leading/trailing whitespace and ANSI escape characters: $obtained",
obtained,
expected
)
compare.failEqualsComparison(obtained, expected, clue, loc, this)
}
}
}
Expand Down Expand Up @@ -320,7 +270,8 @@ trait Assertions extends MacroCompat.CompileErrorMacro {
munitFilterAnsi(munitLines.formatLine(loc, message, clues)),
obtained,
expected,
loc
loc,
isStackTracesEnabled = false
)
}

Expand Down
4 changes: 3 additions & 1 deletion munit/shared/src/main/scala/munit/Clue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ class Clue[+T](
override def toString(): String = s"Clue($source, $value)"
}
object Clue extends MacroCompat.ClueMacro {
def empty[T](value: T): Clue[T] = new Clue("", value, "")
@deprecated("use fromValue instead", "1.0.0")
def empty[T](value: T): Clue[T] = fromValue(value)
def fromValue[T](value: T): Clue[T] = new Clue("", value, "")
}
4 changes: 4 additions & 0 deletions munit/shared/src/main/scala/munit/Clues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ import munit.internal.console.Printers
class Clues(val values: List[Clue[_]]) {
override def toString(): String = Printers.print(this)
}
object Clues {
def empty: Clues = new Clues(List())
def fromValue[T](value: T): Clues = new Clues(List(Clue.fromValue(value)))
}
122 changes: 122 additions & 0 deletions munit/shared/src/main/scala/munit/Compare.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package munit

import munit.internal.difflib.Diffs
import munit.internal.difflib.ComparisonFailExceptionHandler
import scala.annotation.implicitNotFound

/**
* A type-class that is used to compare values in MUnit assertions.
*
* By default, uses == and allows comparison between any two types as long
* they have a supertype/subtype relationship. For example:
*
* - Compare[T, T] OK
* - Compare[Some[Int], Option[Int]] OK, subtype
* - Compare[Option[Int], Some[Int]] OK, supertype
* - Compare[List[Int], collection.Seq[Int]] OK, subtype
* - Compare[List[Int], Vector[Int]] Error, requires upcast to `Seq[Int]`
*/
@implicitNotFound(
// NOTE: Dotty ignores this message if the string is formatted as a multiline string """..."""
"Can't compare these two types:\n First type: ${A}\n Second type: ${B}\nPossible ways to fix this error:\n Alternative 1: provide an implicit instance for Compare[${A}, ${B}]\n Alternative 2: upcast either type into `Any` or a shared supertype"
)
trait Compare[A, B] {

/**
* Returns true if the values are equal according to the rules of this `Compare[A, B]` instance.
*
* The default implementation of this method uses `==`.
*/
def isEqual(obtained: A, expected: B): Boolean

/**
* Throws an exception to fail this assertion when two values are not equal.
*
* Override this method to customize the error message. For example, it may
* be helpful to generate an image/HTML file if you're comparing visual
* values. Anything is possible, use your imagination!
*
* @return should ideally throw a org.junit.ComparisonFailException in order
* to support the IntelliJ diff viewer.
*/
def failEqualsComparison(
obtained: A,
expected: B,
title: Any,
loc: Location,
assertions: Assertions
): Nothing = {
val diffHandler = new ComparisonFailExceptionHandler {
override def handle(
message: String,
_obtained: String,
_expected: String,
loc: Location
): Nothing =
assertions.failComparison(
message,
obtained,
expected
)(loc)
}
// Attempt 1: custom pretty-printer that produces multiline output, which is
// optimized for line-by-line diffing.
Diffs.assertNoDiff(
assertions.munitPrint(obtained),
assertions.munitPrint(expected),
diffHandler,
title = assertions.munitPrint(title),
printObtainedAsStripMargin = false
)(loc)

// Attempt 2: try with `.toString` in case `munitPrint()` produces identical
// formatting for both values.
Diffs.assertNoDiff(
obtained.toString(),
expected.toString(),
diffHandler,
title = assertions.munitPrint(title),
printObtainedAsStripMargin = false
)(loc)

// Attempt 3: string comparison is not working, unconditionally fail the test.
if (obtained.toString() == expected.toString())
assertions.failComparison(
s"values are not equal even if they have the same `toString()`: $obtained",
obtained,
expected
)(loc)
else
assertions.failComparison(
s"values are not equal, even if their text representation only differs in leading/trailing whitespace and ANSI escape characters: $obtained",
obtained,
expected
)(loc)
}

}

object Compare extends ComparePriority1 {
private val anyEquality: Compare[Any, Any] = _ == _
def defaultCompare[A, B]: Compare[A, B] =
anyEquality.asInstanceOf[Compare[A, B]]
}

/** Allows comparison between A and B when A is a subtype of B */
trait ComparePriority1 extends ComparePriority2 {
implicit def compareSubtypeWithSupertype[A, B](implicit
ev: A <:< B
): Compare[A, B] = Compare.defaultCompare
}

/**
* Allows comparison between A and B when B is a subtype of A.
*
* This implicit is defined separately from ComparePriority1 in order to avoid
* diverging implicit search when comparing equal types.
*/
trait ComparePriority2 {
implicit def compareSupertypeWithSubtype[A, B](implicit
ev: A <:< B
): Compare[B, A] = Compare.defaultCompare
}
Loading