diff --git a/core/src/main/scala/chisel3/Data.scala b/core/src/main/scala/chisel3/Data.scala index b83fb53e61f..7975cbdfe6a 100644 --- a/core/src/main/scala/chisel3/Data.scala +++ b/core/src/main/scala/chisel3/Data.scala @@ -5,7 +5,7 @@ package chisel3 import chisel3.experimental.dataview.reify import scala.language.experimental.macros -import chisel3.experimental.{Analog, BaseModule, DataMirror, FixedPoint, Interval} +import chisel3.experimental.{Analog, BaseModule, DataMirror, EnumType, FixedPoint, Interval} import chisel3.internal.Builder.pushCommand import chisel3.internal._ import chisel3.internal.firrtl._ @@ -895,6 +895,84 @@ abstract class Data extends HasId with NamedComponent with SourceInfoDoc { def toPrintable: Printable } +object Data { + + /** + * Provides generic, recursive equality for [[Bundle]] and [[Vec]] hardware. This avoids the + * need to use workarounds such as `bundle1.asUInt === bundle2.asUInt` by allowing users + * to instead write `bundle1 === bundle2`. + * + * Static type safety of this comparison is guaranteed at compile time as the extension + * method requires the same parameterized type for both the left-hand and right-hand + * sides. It is, however, possible to get around this type safety using `Bundle` subtypes + * that can differ during runtime (e.g. through a generator). These cases are + * subsequently raised as elaboration errors. + * + * @param lhs The [[Data]] hardware on the left-hand side of the equality + */ + implicit class DataEquality[T <: Data](lhs: T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) { + + /** Dynamic recursive equality operator for generic [[Data]] + * + * @param rhs a hardware [[Data]] to compare `lhs` to + * @return a hardware [[Bool]] asserted if `lhs` is equal to `rhs` + * @throws ChiselException when `lhs` and `rhs` are different types during elaboration time + */ + def ===(rhs: T): Bool = { + (lhs, rhs) match { + case (thiz: UInt, that: UInt) => thiz === that + case (thiz: SInt, that: SInt) => thiz === that + case (thiz: AsyncReset, that: AsyncReset) => thiz.asBool === that.asBool + case (thiz: Reset, that: Reset) => thiz === that + case (thiz: Interval, that: Interval) => thiz === that + case (thiz: FixedPoint, that: FixedPoint) => thiz === that + case (thiz: EnumType, that: EnumType) => thiz === that + case (thiz: Clock, that: Clock) => thiz.asUInt === that.asUInt + case (thiz: Vec[_], that: Vec[_]) => + if (thiz.length != that.length) { + throwException(s"Cannot compare Vecs $thiz and $that: Vec sizes differ") + } else { + thiz.getElements + .zip(that.getElements) + .map { case (thisData, thatData) => thisData === thatData } + .reduce(_ && _) + } + case (thiz: Record, that: Record) => + if (thiz.elements.size != that.elements.size) { + throwException(s"Cannot compare Bundles $thiz and $that: Bundle types differ") + } else { + thiz.elements.map { + case (thisName, thisData) => + if (!that.elements.contains(thisName)) + throwException( + s"Cannot compare Bundles $thiz and $that: field $thisName (from $thiz) was not found in $that" + ) + + val thatData = that.elements(thisName) + + try { + thisData === thatData + } catch { + case e: ChiselException => + throwException( + s"Cannot compare field $thisName in Bundles $thiz and $that: ${e.getMessage.split(": ").last}" + ) + } + } + .reduce(_ && _) + } + // This should be matching to (DontCare, DontCare) but the compiler wasn't happy with that + case (_: DontCare.type, _: DontCare.type) => true.B + + case (thiz: Analog, that: Analog) => + throwException(s"Cannot compare Analog values $thiz and $that: Equality isn't defined for Analog values") + // Runtime types are different + case (thiz, that) => throwException(s"Cannot compare $thiz and $that: Runtime types differ") + } + } + } +} + trait WireFactory { /** Construct a [[Wire]] from a type template diff --git a/src/test/scala/chiselTests/DataEqualitySpec.scala b/src/test/scala/chiselTests/DataEqualitySpec.scala new file mode 100644 index 00000000000..4ac3292dcf4 --- /dev/null +++ b/src/test/scala/chiselTests/DataEqualitySpec.scala @@ -0,0 +1,257 @@ +package chiselTests + +import chisel3._ +import chisel3.experimental.VecLiterals._ +import chisel3.experimental.BundleLiterals._ +import chisel3.experimental.{Analog, ChiselEnum, ChiselRange, FixedPoint, Interval} +import chisel3.stage.ChiselStage +import chisel3.testers.BasicTester +import chisel3.util.Valid + +class EqualityModule(lhsGen: => Data, rhsGen: => Data) extends Module { + val out = IO(Output(Bool())) + + val lhs = lhsGen + val rhs = rhsGen + + out := lhs === rhs +} + +class EqualityTester(lhsGen: => Data, rhsGen: => Data) extends BasicTester { + val module = Module(new EqualityModule(lhsGen, rhsGen)) + + assert(module.out) + + stop() +} + +class AnalogBundle extends Bundle { + val analog = Analog(32.W) +} + +class AnalogExceptionModule extends Module { + class AnalogExceptionModuleIO extends Bundle { + val bundle1 = new AnalogBundle + val bundle2 = new AnalogBundle + } + + val io = IO(new AnalogExceptionModuleIO) +} + +class AnalogExceptionTester extends BasicTester { + val module = Module(new AnalogExceptionModule) + + module.io.bundle1 <> DontCare + module.io.bundle2 <> DontCare + + assert(module.io.bundle1 === module.io.bundle2) + + stop() +} + +class DataEqualitySpec extends ChiselFlatSpec with Utils { + object MyEnum extends ChiselEnum { + val sA, sB = Value + } + object MyEnumB extends ChiselEnum { + val sA, sB = Value + } + class MyBundle extends Bundle { + val a = UInt(8.W) + val b = Bool() + val c = MyEnum() + } + class LongBundle extends Bundle { + val a = UInt(48.W) + val b = SInt(32.W) + val c = FixedPoint(16.W, 4.BP) + } + class RuntimeSensitiveBundle(gen: => Bundle) extends Bundle { + val a = UInt(8.W) + val b: Bundle = gen + } + + behavior.of("UInt === UInt") + it should "pass with equal values" in { + assertTesterPasses { + new EqualityTester(0.U, 0.U) + } + } + it should "fail with differing values" in { + assertTesterFails { + new EqualityTester(0.U, 1.U) + } + } + + behavior.of("SInt === SInt") + it should "pass with equal values" in { + assertTesterPasses { + new EqualityTester(0.S, 0.S) + } + } + it should "fail with differing values" in { + assertTesterFails { + new EqualityTester(0.S, 1.S) + } + } + + behavior.of("Reset === Reset") + it should "pass with equal values" in { + assertTesterPasses { + new EqualityTester(true.B, true.B) + } + } + it should "fail with differing values" in { + assertTesterFails { + new EqualityTester(true.B, false.B) + } + } + + behavior.of("AsyncReset === AsyncReset") + it should "pass with equal values" in { + assertTesterPasses { + new EqualityTester(true.B.asAsyncReset, true.B.asAsyncReset) + } + } + it should "fail with differing values" in { + assertTesterFails { + new EqualityTester(true.B.asAsyncReset, false.B.asAsyncReset) + } + } + + behavior.of("Interval === Interval") + it should "pass with equal values" in { + assertTesterPasses { + new EqualityTester(2.I, 2.I) + } + } + it should "fail with differing values" in { + assertTesterFails { + new EqualityTester(2.I, 3.I) + } + } + + behavior.of("FixedPoint === FixedPoint") + it should "pass with equal values" in { + assertTesterPasses { + new EqualityTester(4.5.F(16.W, 4.BP), 4.5.F(16.W, 4.BP)) + } + } + it should "fail with differing values" in { + assertTesterFails { + new EqualityTester(4.5.F(16.W, 4.BP), 4.6.F(16.W, 4.BP)) + } + } + + behavior.of("ChiselEnum === ChiselEnum") + it should "pass with equal values" in { + assertTesterPasses { + new EqualityTester(MyEnum.sA, MyEnum.sA) + } + } + it should "fail with differing values" in { + assertTesterFails { + new EqualityTester(MyEnum.sA, MyEnum.sB) + } + } + + behavior.of("Vec === Vec") + it should "pass with equal sizes, equal values" in { + assertTesterPasses { + new EqualityTester( + Vec(3, UInt(8.W)).Lit(0 -> 1.U, 1 -> 2.U, 2 -> 3.U), + Vec(3, UInt(8.W)).Lit(0 -> 1.U, 1 -> 2.U, 2 -> 3.U) + ) + } + } + it should "fail with equal sizes, differing values" in { + assertTesterFails { + new EqualityTester( + Vec(3, UInt(8.W)).Lit(0 -> 1.U, 1 -> 2.U, 2 -> 3.U), + Vec(3, UInt(8.W)).Lit(0 -> 0.U, 1 -> 1.U, 2 -> 2.U) + ) + } + } + it should "throw a ChiselException with differing sizes" in { + (the[ChiselException] thrownBy extractCause[ChiselException] { + assertTesterFails { + new EqualityTester( + Vec(3, UInt(8.W)).Lit(0 -> 1.U, 1 -> 2.U, 2 -> 3.U), + Vec(4, UInt(8.W)).Lit(0 -> 1.U, 1 -> 2.U, 2 -> 3.U, 3 -> 4.U) + ) + } + }).getMessage should include("Vec sizes differ") + } + + behavior.of("Bundle === Bundle") + it should "pass with equal type, equal values" in { + assertTesterPasses { + new EqualityTester( + (new MyBundle).Lit(_.a -> 42.U, _.b -> false.B, _.c -> MyEnum.sB), + (new MyBundle).Lit(_.a -> 42.U, _.b -> false.B, _.c -> MyEnum.sB) + ) + } + } + it should "fail with equal type, differing values" in { + assertTesterFails { + new EqualityTester( + (new MyBundle).Lit(_.a -> 42.U, _.b -> false.B, _.c -> MyEnum.sB), + (new MyBundle).Lit(_.a -> 42.U, _.b -> false.B, _.c -> MyEnum.sA) + ) + } + } + it should "throw a ChiselException with differing runtime types" in { + (the[ChiselException] thrownBy extractCause[ChiselException] { + assertTesterFails { + new EqualityTester( + (new RuntimeSensitiveBundle(new MyBundle)).Lit( + _.a -> 1.U, + _.b -> (new MyBundle).Lit( + _.a -> 42.U, + _.b -> false.B, + _.c -> MyEnum.sB + ) + ), + (new RuntimeSensitiveBundle(new LongBundle)).Lit( + _.a -> 1.U, + _.b -> (new LongBundle).Lit( + _.a -> 42.U, + _.b -> 0.S, + _.c -> 4.5.F(16.W, 4.BP) + ) + ) + ) + } + }).getMessage should include("Runtime types differ") + } + + behavior.of("DontCare === DontCare") + it should "pass with two invalids" in { + assertTesterPasses { + new EqualityTester(Valid(UInt(8.W)).Lit(_.bits -> 123.U), Valid(UInt(8.W)).Lit(_.bits -> 123.U)) + } + } + it should "exhibit the same behavior as comparing two invalidated wires" in { + // Also check that two invalidated wires are equal + assertTesterPasses { + new EqualityTester(WireInit(UInt(8.W), DontCare), WireInit(UInt(8.W), DontCare)) + } + + // Compare the verilog generated from both test cases and verify that they both are equal to true + val verilog1 = ChiselStage.emitVerilog( + new EqualityModule(Valid(UInt(8.W)).Lit(_.bits -> 123.U), Valid(UInt(8.W)).Lit(_.bits -> 123.U)) + ) + val verilog2 = + ChiselStage.emitVerilog(new EqualityModule(WireInit(UInt(8.W), DontCare), WireInit(UInt(8.W), DontCare))) + + verilog1 should include("assign out = 1'h1;") + verilog2 should include("assign out = 1'h1;") + } + + behavior.of("Analog === Analog") + it should "throw a ChiselException" in { + (the[ChiselException] thrownBy extractCause[ChiselException] { + assertTesterFails { new AnalogExceptionTester } + }).getMessage should include("Equality isn't defined for Analog values") + } +}