diff --git a/src/main/scala/chisel3/util/BitPat.scala b/src/main/scala/chisel3/util/BitPat.scala index 2cf8470605c..232b6da41a9 100644 --- a/src/main/scala/chisel3/util/BitPat.scala +++ b/src/main/scala/chisel3/util/BitPat.scala @@ -120,8 +120,7 @@ object BitPat { * }}} */ sealed class BitPat(val value: BigInt, val mask: BigInt, w: Int) extends BitSet with SourceInfoDoc { - val onSet = value & mask - val offSet = ~value & mask + override val terms = Set(this) override lazy val width = w def getWidth: Int = w @@ -158,14 +157,63 @@ sealed class BitPat(val value: BigInt, val mask: BigInt, w: Int) extends BitSet new BitPat((value << that.getWidth) + that.value, (mask << that.getWidth) + that.mask, this.width + that.getWidth) } - /** Generate raw string of a BitPat. */ - def rawString: String = Seq.tabulate(width) { i => + def overlap(rhs: BitPat): Boolean = ((mask & rhs.mask) & (value ^ rhs.value)) == 0 + + def cover(that: BitPat): Boolean = (mask & (~that.mask | (value ^ that.value))) == 0 + + def intersect(that: BitPat): BitSet = { + if (!overlap(that)) { + BitSet.emptyBitSet + } else { + new BitPat(this.value | that.value, this.mask | that.mask, this.width.max(that.width)) + } + } + + def subtract(that: BitPat): BitSet = { + require(width == that.width) + def enumerateBits(mask: BigInt): Seq[BigInt] = { + if (mask == 0) { + Nil + } else { + // bits comes after the first '1' in a number are inverted in its two's complement. + // therefore bit is always the first '1' in x (counting from least significant bit). + val bit = mask & (-mask) + bit +: enumerateBits(mask & ~bit) + } + } + + val intersection = intersect(that) + val omask = this.mask + if (intersection.isEmpty) { + this + } else { + new BitSet { + val terms = + intersection.terms.flatMap { remove => + enumerateBits(~omask & remove.mask).map { bit => + // Only care about higher than current bit in remove + val nmask = (omask | ~(bit - 1)) & remove.mask + val nvalue = (remove.value ^ bit) & nmask + val nwidth = remove.width + new BitPat(nvalue, nmask, nwidth) + } + } + } + } + } + + override def isEmpty: Boolean = false + + /** Generate raw string of a BitSat. */ + def rawString: String = Seq + .tabulate(width) { i => (value.testBit(width - i - 1), mask.testBit(width - i - 1)) match { - case (true, true) => "1" - case (false, true) => "0" - case (_, false) => "?" + case (true, true) => "1" + case (false, true) => "0" + case (_, false) => "?" + } } - }.mkString + .mkString override def toString = s"BitPat($rawString)" } diff --git a/src/main/scala/chisel3/util/BitSet.scala b/src/main/scala/chisel3/util/BitSet.scala index 26ba20322c5..3fc0ef1bf26 100644 --- a/src/main/scala/chisel3/util/BitSet.scala +++ b/src/main/scala/chisel3/util/BitSet.scala @@ -2,112 +2,108 @@ package chisel3.util -trait BitSetFamily { bsf => - val terms: Set[BitSet] - - // set width = 0 if terms is empty. +trait BitSet { bsf => + val terms: Set[BitPat] + + /** + * Get specified width of said BitSetFamily + * + * All BitPat contained in one family should have the same width + */ lazy val width: Int = { - assert(terms.map(_.width).size < 1) + assert(terms.map(_.width).size <= 1) + // set width = 0 if terms is empty. terms.headOption.map(_.width).getOrElse(0) } - override def toString: String = terms.toSeq.sortBy((t: BitSet) => (t.onSet, t.offSet)).mkString("\n") + override def toString: String = terms.toSeq.sortBy((t: BitPat) => (t.mask, t.value)).mkString("\n") + /** + * @return whether the BitSetFamily is empty (i.e. no value matches) + */ def isEmpty: Boolean = terms.forall(_.isEmpty) - def cover(that: BitSetFamily): Boolean = - if (terms.isEmpty) - that.terms.isEmpty - else - terms.flatMap(a => that.terms.map(b => (a, b))).foldLeft(true) { - case (left, (a, b)) => a.cover(b) & left - } - - def intersect(that: BitSetFamily): BitSetFamily = new BitSetFamily { - val terms = bsf.terms.flatMap(a => that.terms.map(b => a.intersect(b))).filterNot(_.isEmpty) + /** + * @return whether this BitSetFamily overlap with that BitSetFamily, i.e. !(intersect.isEmpty) + */ + def overlap(that: BitSet): Boolean = { + !bsf.terms.flatMap(a => that.terms.map(b => (a, b))).forall { case (a, b) => !a.overlap(b) } } - def subtract(that: BitSetFamily): BitSetFamily = new BitSetFamily { - val terms = bsf.terms.flatMap(a => that.terms.map(b => a.subtract(b))).filterNot(_.isEmpty) - } - - def union(that: BitSetFamily): BitSetFamily = new BitSetFamily { + /** + * @param that BitSetFamily to be covered + * @return whether this BitSetFamily covers that (i.e. forall b matches that, b also matches this) + */ + def cover(that: BitSet): Boolean = + that.subtract(this).isEmpty + + /** + * @return a BitSetFamily that only match a value when both operand match + */ + def intersect(that: BitSet): BitSet = + bsf.terms + .flatMap(a => that.terms.map(b => a.intersect(b))) + .filterNot(_.isEmpty) + .fold(BitSet.emptyBitSet)(_.union(_)) + + /** + * Subtract that from this BitSetFamily + * @param that subtrahend + */ + def subtract(that: BitSet): BitSet = + bsf.terms.map { a => + that.terms.map(b => a.subtract(b)).fold(a)(_.intersect(_)) + }.filterNot(_.isEmpty).fold(BitSet.emptyBitSet)(_.union(_)) + + /** + * Union of two BitSetFamily + */ + def union(that: BitSet): BitSet = new BitSet { val terms = bsf.terms ++ that.terms } + /** + * Test whether two BitSetFamily matches the same set of value + * + * Caution: This method can be very expensive compared to ordinary == operator between two Objects + */ override def equals(obj: Any): Boolean = { obj match { - case that: BitSetFamily => if (width == that.width) this.cover(that) && that.cover(this) else false + case that: BitSet => if (width == that.width) this.cover(that) && that.cover(this) else false case _ => false } } } -trait BitSet extends BitSetFamily { b => - override val terms = Set(this) - val onSet: BigInt - val offSet: BigInt - val width: Int - - assert(width > 0) - - def cover(that: BitSet) = ((onSet & that.onSet) == that.onSet) & ((offSet & that.offSet) == that.offSet) - - def intersect(that: BitSet): BitSet = { - require(width == that.width) - new BitSet { - val onSet: BigInt = b.onSet & that.onSet - val offSet: BigInt = b.offSet & that.offSet - override lazy val width: Int = b.width - } +object BitSet { + val emptyBitSet: BitSet = new BitSet { + override val terms = Set() } - def subtract(that: BitSet): BitSet = { - require(width == that.width) + def apply(str: String): BitSet = { new BitSet { - val onSet: BigInt = b.onSet & ~that.onSet - val offSet: BigInt = b.offSet & ~that.offSet - override lazy val width: Int = b.width - } - } - - override def isEmpty: Boolean = (onSet | offSet) == 0 - override def toString: String = Seq - .tabulate(width) { i => - (onSet.testBit(i), offSet.testBit(i)) match { - case (true, true) => "-" - case (true, false) => "1" - case (false, true) => "0" - case (false, false) => "~" - } - } - .mkString -} - -object BitSet { - def apply(str: String): BitSetFamily = { - new BitSetFamily { val terms = str .split('\n') - .map(str => - new BitSet { - val onSet: BigInt = str.zipWithIndex.map { - case ('-' | '1', i) => BigInt(1) << i - case ('~' | '0', _) => BigInt(0) - case _ => throw new Exception("bitset parse error") - }.sum - val offSet: BigInt = str.zipWithIndex.map { - case ('-' | '0', i) => BigInt(1) << i - case ('~' | '1', _) => BigInt(0) - case _ => throw new Exception("bitset parse error") - }.sum - override lazy val width: Int = str.length - } - ).toSet + .map(str => BitPat(str)) + .toSet + assert(terms.map(_.width).size <= 1) } } - def decode(input: chisel3.UInt, bitSets: Seq[BitSetFamily], errorBit: Boolean = false) = + /** + * Generate a decoder circuit that matches the input to each bitSet. + * + * The resulting circuit functions like the following but is optimized with a logic minifier + * when(input === bitSets(0)) { output := b000001 } + * .elsewhen (input === bitSets(1)) { output := b000010 } + * .... + * .otherwise { if (errorBit) output := b100000 else output := DontCare } + * + * @param input input to the decoder circuit, width should be equal to bitSets.width + * @param bitSets set of ports to be matched, all width should be the equal + * @param errorBit whether generate an additional decode error bit + */ + def decode(input: chisel3.UInt, bitSets: Seq[BitSet], errorBit: Boolean = false): chisel3.UInt = chisel3.util.experimental.decode.decoder( input, chisel3.util.experimental.decode.TruthTable( diff --git a/src/test/scala/chiselTests/util/BitPatSpec.scala b/src/test/scala/chiselTests/util/BitPatSpec.scala index 0c83493fe35..78d1ad75a0e 100644 --- a/src/test/scala/chiselTests/util/BitPatSpec.scala +++ b/src/test/scala/chiselTests/util/BitPatSpec.scala @@ -24,7 +24,7 @@ class BitPatSpec extends AnyFlatSpec with Matchers { intercept[IllegalArgumentException]{BitPat("b")} } - it should "contact BitPat via ##" in { + it should "concat BitPat via ##" in { (BitPat.Y(4) ## BitPat.dontCare(3) ## BitPat.N(2)).toString should be (s"BitPat(1111???00)") } diff --git a/src/test/scala/chiselTests/util/BitSetSpec.scala b/src/test/scala/chiselTests/util/BitSetSpec.scala new file mode 100644 index 00000000000..730182cf013 --- /dev/null +++ b/src/test/scala/chiselTests/util/BitSetSpec.scala @@ -0,0 +1,94 @@ +package chiselTests.util + +import chisel3.util.{BitPat, BitSet} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class BitSetSpec extends AnyFlatSpec with Matchers { + behavior of classOf[BitSet].toString + + it should "reject unequal width when constructing a BitSet" in { + intercept[AssertionError] { + BitSet( + """b0010 + |b00010 + |""".stripMargin) + } + } + + it should "return empty subtraction result correctly" in { + val aBitPat = BitPat("b10?") + val bBitPat = BitPat("b1??") + + aBitPat.subtract(bBitPat).isEmpty should be (true) + } + + it should "return nonempty subtraction result correctly" in { + val aBitPat = BitPat("b10?") + val bBitPat = BitPat("b1??") + val cBitPat = BitPat("b11?") + val dBitPat = BitPat("b100") + + val diffBitPat = bBitPat.subtract(aBitPat) + bBitPat.cover(diffBitPat) should be (true) + diffBitPat.equals(cBitPat) should be (true) + + val largerdiffBitPat = bBitPat.subtract(dBitPat) + aBitPat.cover(dBitPat) should be (true) + largerdiffBitPat.cover(diffBitPat) should be (true) + } + + it should "be able to handle complex subtract between BitSet" in { + val aBitSet = BitSet( + """b?01?0 + |b11111 + |b00000 + |""".stripMargin) + val bBitSet = BitSet( + """b?1111 + |b?0000 + |""".stripMargin + ) + val expected = BitPat("b?01?0") + + expected.equals(aBitSet.subtract(bBitSet)) should be (true) + } + + it should "be generated from BitPat union" in { + val aBitSet = BitSet( + """b001?0 + |b000??""".stripMargin) + val aBitPat = BitPat("b000??") + val bBitPat = BitPat("b001?0") + val cBitPat = BitPat("b00000") + aBitPat.cover(cBitPat) should be (true) + aBitSet.cover(bBitPat) should be (true) + + aBitSet.equals(aBitPat.union(bBitPat)) should be (true) + } + + it should "be generated from BitPat substraction" in { + val aBitSet = BitSet( + """b001?0 + |b000??""".stripMargin) + val aBitPat = BitPat("b00???") + val bBitPat = BitPat("b001?1") + + aBitSet.equals(aBitPat.subtract(bBitPat)) should be (true) + } + + it should "union two BitSet together" in { + val aBitSet = BitSet( + """b001?0 + |b001?1 + |""".stripMargin) + val bBitSet = BitSet( + """b000?? + |b01??? + |""".stripMargin + ) + val cBitPat = BitPat("b0????") + cBitPat.equals(aBitSet.union(bBitSet)) should be (true) + } + +}