Skip to content

Commit

Permalink
Merge pull request #2223 from OceanS2000/bitset
Browse files Browse the repository at this point in the history
Refactor BitSet implementation with unittest and doc
  • Loading branch information
sequencer authored Oct 29, 2021
2 parents 335c11e + 40005b3 commit 6bd2b0c
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 88 deletions.
64 changes: 56 additions & 8 deletions src/main/scala/chisel3/util/BitPat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ object BitPat {
* }}}
*/
sealed class BitPat(val value: BigInt, val mask: BigInt, w: Int) extends BitSet with SourceInfoDoc {
val onSet = value & mask
val offSet = ~value & mask
override val terms = Set(this)
override lazy val width = w

def getWidth: Int = w
Expand Down Expand Up @@ -158,14 +157,63 @@ sealed class BitPat(val value: BigInt, val mask: BigInt, w: Int) extends BitSet
new BitPat((value << that.getWidth) + that.value, (mask << that.getWidth) + that.mask, this.width + that.getWidth)
}

/** Generate raw string of a BitPat. */
def rawString: String = Seq.tabulate(width) { i =>
def overlap(rhs: BitPat): Boolean = ((mask & rhs.mask) & (value ^ rhs.value)) == 0

def cover(that: BitPat): Boolean = (mask & (~that.mask | (value ^ that.value))) == 0

def intersect(that: BitPat): BitSet = {
if (!overlap(that)) {
BitSet.emptyBitSet
} else {
new BitPat(this.value | that.value, this.mask | that.mask, this.width.max(that.width))
}
}

def subtract(that: BitPat): BitSet = {
require(width == that.width)
def enumerateBits(mask: BigInt): Seq[BigInt] = {
if (mask == 0) {
Nil
} else {
// bits comes after the first '1' in a number are inverted in its two's complement.
// therefore bit is always the first '1' in x (counting from least significant bit).
val bit = mask & (-mask)
bit +: enumerateBits(mask & ~bit)
}
}

val intersection = intersect(that)
val omask = this.mask
if (intersection.isEmpty) {
this
} else {
new BitSet {
val terms =
intersection.terms.flatMap { remove =>
enumerateBits(~omask & remove.mask).map { bit =>
// Only care about higher than current bit in remove
val nmask = (omask | ~(bit - 1)) & remove.mask
val nvalue = (remove.value ^ bit) & nmask
val nwidth = remove.width
new BitPat(nvalue, nmask, nwidth)
}
}
}
}
}

override def isEmpty: Boolean = false

/** Generate raw string of a BitSat. */
def rawString: String = Seq
.tabulate(width) { i =>
(value.testBit(width - i - 1), mask.testBit(width - i - 1)) match {
case (true, true) => "1"
case (false, true) => "0"
case (_, false) => "?"
case (true, true) => "1"
case (false, true) => "0"
case (_, false) => "?"
}
}
}.mkString
.mkString

override def toString = s"BitPat($rawString)"
}
154 changes: 75 additions & 79 deletions src/main/scala/chisel3/util/BitSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,112 +2,108 @@

package chisel3.util

trait BitSetFamily { bsf =>
val terms: Set[BitSet]

// set width = 0 if terms is empty.
trait BitSet { bsf =>
val terms: Set[BitPat]

/**
* Get specified width of said BitSetFamily
*
* All BitPat contained in one family should have the same width
*/
lazy val width: Int = {
assert(terms.map(_.width).size < 1)
assert(terms.map(_.width).size <= 1)
// set width = 0 if terms is empty.
terms.headOption.map(_.width).getOrElse(0)
}

override def toString: String = terms.toSeq.sortBy((t: BitSet) => (t.onSet, t.offSet)).mkString("\n")
override def toString: String = terms.toSeq.sortBy((t: BitPat) => (t.mask, t.value)).mkString("\n")

/**
* @return whether the BitSetFamily is empty (i.e. no value matches)
*/
def isEmpty: Boolean = terms.forall(_.isEmpty)

def cover(that: BitSetFamily): Boolean =
if (terms.isEmpty)
that.terms.isEmpty
else
terms.flatMap(a => that.terms.map(b => (a, b))).foldLeft(true) {
case (left, (a, b)) => a.cover(b) & left
}

def intersect(that: BitSetFamily): BitSetFamily = new BitSetFamily {
val terms = bsf.terms.flatMap(a => that.terms.map(b => a.intersect(b))).filterNot(_.isEmpty)
/**
* @return whether this BitSetFamily overlap with that BitSetFamily, i.e. !(intersect.isEmpty)
*/
def overlap(that: BitSet): Boolean = {
!bsf.terms.flatMap(a => that.terms.map(b => (a, b))).forall { case (a, b) => !a.overlap(b) }
}

def subtract(that: BitSetFamily): BitSetFamily = new BitSetFamily {
val terms = bsf.terms.flatMap(a => that.terms.map(b => a.subtract(b))).filterNot(_.isEmpty)
}

def union(that: BitSetFamily): BitSetFamily = new BitSetFamily {
/**
* @param that BitSetFamily to be covered
* @return whether this BitSetFamily covers that (i.e. forall b matches that, b also matches this)
*/
def cover(that: BitSet): Boolean =
that.subtract(this).isEmpty

/**
* @return a BitSetFamily that only match a value when both operand match
*/
def intersect(that: BitSet): BitSet =
bsf.terms
.flatMap(a => that.terms.map(b => a.intersect(b)))
.filterNot(_.isEmpty)
.fold(BitSet.emptyBitSet)(_.union(_))

/**
* Subtract that from this BitSetFamily
* @param that subtrahend
*/
def subtract(that: BitSet): BitSet =
bsf.terms.map { a =>
that.terms.map(b => a.subtract(b)).fold(a)(_.intersect(_))
}.filterNot(_.isEmpty).fold(BitSet.emptyBitSet)(_.union(_))

/**
* Union of two BitSetFamily
*/
def union(that: BitSet): BitSet = new BitSet {
val terms = bsf.terms ++ that.terms
}

/**
* Test whether two BitSetFamily matches the same set of value
*
* Caution: This method can be very expensive compared to ordinary == operator between two Objects
*/
override def equals(obj: Any): Boolean = {
obj match {
case that: BitSetFamily => if (width == that.width) this.cover(that) && that.cover(this) else false
case that: BitSet => if (width == that.width) this.cover(that) && that.cover(this) else false
case _ => false
}
}
}

trait BitSet extends BitSetFamily { b =>
override val terms = Set(this)
val onSet: BigInt
val offSet: BigInt
val width: Int

assert(width > 0)

def cover(that: BitSet) = ((onSet & that.onSet) == that.onSet) & ((offSet & that.offSet) == that.offSet)

def intersect(that: BitSet): BitSet = {
require(width == that.width)
new BitSet {
val onSet: BigInt = b.onSet & that.onSet
val offSet: BigInt = b.offSet & that.offSet
override lazy val width: Int = b.width
}
object BitSet {
val emptyBitSet: BitSet = new BitSet {
override val terms = Set()
}

def subtract(that: BitSet): BitSet = {
require(width == that.width)
def apply(str: String): BitSet = {
new BitSet {
val onSet: BigInt = b.onSet & ~that.onSet
val offSet: BigInt = b.offSet & ~that.offSet
override lazy val width: Int = b.width
}
}

override def isEmpty: Boolean = (onSet | offSet) == 0
override def toString: String = Seq
.tabulate(width) { i =>
(onSet.testBit(i), offSet.testBit(i)) match {
case (true, true) => "-"
case (true, false) => "1"
case (false, true) => "0"
case (false, false) => "~"
}
}
.mkString
}

object BitSet {
def apply(str: String): BitSetFamily = {
new BitSetFamily {
val terms = str
.split('\n')
.map(str =>
new BitSet {
val onSet: BigInt = str.zipWithIndex.map {
case ('-' | '1', i) => BigInt(1) << i
case ('~' | '0', _) => BigInt(0)
case _ => throw new Exception("bitset parse error")
}.sum
val offSet: BigInt = str.zipWithIndex.map {
case ('-' | '0', i) => BigInt(1) << i
case ('~' | '1', _) => BigInt(0)
case _ => throw new Exception("bitset parse error")
}.sum
override lazy val width: Int = str.length
}
).toSet
.map(str => BitPat(str))
.toSet
assert(terms.map(_.width).size <= 1)
}
}

def decode(input: chisel3.UInt, bitSets: Seq[BitSetFamily], errorBit: Boolean = false) =
/**
* Generate a decoder circuit that matches the input to each bitSet.
*
* The resulting circuit functions like the following but is optimized with a logic minifier
* when(input === bitSets(0)) { output := b000001 }
* .elsewhen (input === bitSets(1)) { output := b000010 }
* ....
* .otherwise { if (errorBit) output := b100000 else output := DontCare }
*
* @param input input to the decoder circuit, width should be equal to bitSets.width
* @param bitSets set of ports to be matched, all width should be the equal
* @param errorBit whether generate an additional decode error bit
*/
def decode(input: chisel3.UInt, bitSets: Seq[BitSet], errorBit: Boolean = false): chisel3.UInt =
chisel3.util.experimental.decode.decoder(
input,
chisel3.util.experimental.decode.TruthTable(
Expand Down
2 changes: 1 addition & 1 deletion src/test/scala/chiselTests/util/BitPatSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BitPatSpec extends AnyFlatSpec with Matchers {
intercept[IllegalArgumentException]{BitPat("b")}
}

it should "contact BitPat via ##" in {
it should "concat BitPat via ##" in {
(BitPat.Y(4) ## BitPat.dontCare(3) ## BitPat.N(2)).toString should be (s"BitPat(1111???00)")
}

Expand Down
94 changes: 94 additions & 0 deletions src/test/scala/chiselTests/util/BitSetSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package chiselTests.util

import chisel3.util.{BitPat, BitSet}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class BitSetSpec extends AnyFlatSpec with Matchers {
behavior of classOf[BitSet].toString

it should "reject unequal width when constructing a BitSet" in {
intercept[AssertionError] {
BitSet(
"""b0010
|b00010
|""".stripMargin)
}
}

it should "return empty subtraction result correctly" in {
val aBitPat = BitPat("b10?")
val bBitPat = BitPat("b1??")

aBitPat.subtract(bBitPat).isEmpty should be (true)
}

it should "return nonempty subtraction result correctly" in {
val aBitPat = BitPat("b10?")
val bBitPat = BitPat("b1??")
val cBitPat = BitPat("b11?")
val dBitPat = BitPat("b100")

val diffBitPat = bBitPat.subtract(aBitPat)
bBitPat.cover(diffBitPat) should be (true)
diffBitPat.equals(cBitPat) should be (true)

val largerdiffBitPat = bBitPat.subtract(dBitPat)
aBitPat.cover(dBitPat) should be (true)
largerdiffBitPat.cover(diffBitPat) should be (true)
}

it should "be able to handle complex subtract between BitSet" in {
val aBitSet = BitSet(
"""b?01?0
|b11111
|b00000
|""".stripMargin)
val bBitSet = BitSet(
"""b?1111
|b?0000
|""".stripMargin
)
val expected = BitPat("b?01?0")

expected.equals(aBitSet.subtract(bBitSet)) should be (true)
}

it should "be generated from BitPat union" in {
val aBitSet = BitSet(
"""b001?0
|b000??""".stripMargin)
val aBitPat = BitPat("b000??")
val bBitPat = BitPat("b001?0")
val cBitPat = BitPat("b00000")
aBitPat.cover(cBitPat) should be (true)
aBitSet.cover(bBitPat) should be (true)

aBitSet.equals(aBitPat.union(bBitPat)) should be (true)
}

it should "be generated from BitPat substraction" in {
val aBitSet = BitSet(
"""b001?0
|b000??""".stripMargin)
val aBitPat = BitPat("b00???")
val bBitPat = BitPat("b001?1")

aBitSet.equals(aBitPat.subtract(bBitPat)) should be (true)
}

it should "union two BitSet together" in {
val aBitSet = BitSet(
"""b001?0
|b001?1
|""".stripMargin)
val bBitSet = BitSet(
"""b000??
|b01???
|""".stripMargin
)
val cBitPat = BitPat("b0????")
cBitPat.equals(aBitSet.union(bBitSet)) should be (true)
}

}

0 comments on commit 6bd2b0c

Please sign in to comment.