Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TruthTable.apply and add factory method for Espresso #2612

Merged
merged 10 commits into from
Jul 6, 2022
4 changes: 4 additions & 0 deletions src/main/scala/chisel3/util/BitPat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import scala.language.experimental.macros
import chisel3._
import chisel3.internal.sourceinfo.{SourceInfo, SourceInfoTransform}
import scala.collection.mutable
import scala.util.hashing.MurmurHash3

object BitPat {

Expand Down Expand Up @@ -326,6 +327,9 @@ sealed class BitPat(val value: BigInt, val mask: BigInt, val width: Int)
def =/=(that: UInt): Bool = macro SourceInfoTransform.thatArg
def ##(that: BitPat): BitPat = macro SourceInfoTransform.thatArg

override def hashCode: Int =
MurmurHash3.seqHash(Seq(this.value, this.mask, this.width))

/** @group SourceInfoTransformMacro */
def do_apply(x: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): BitPat = {
do_apply(x, x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,6 @@ object EspressoMinimizer extends Minimizer with LazyLogging {
logger.trace(s"""espresso output table:
|$output
|""".stripMargin)
TruthTable(readTable(output), table.default)
TruthTable.fromEspressoOutput(readTable(output), table.default)
}
}
117 changes: 89 additions & 28 deletions src/main/scala/chisel3/util/experimental/decode/TruthTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package chisel3.util.experimental.decode

import chisel3.util.BitPat
import scala.collection.mutable

sealed class TruthTable private (val table: Seq[(BitPat, BitPat)], val default: BitPat, val sort: Boolean) {
def inputWidth = table.head._1.getWidth
Expand All @@ -29,40 +30,89 @@ sealed class TruthTable private (val table: Seq[(BitPat, BitPat)], val default:

object TruthTable {

/** Convert a table and default output into a [[TruthTable]]. */
def apply(table: Iterable[(BitPat, BitPat)], default: BitPat, sort: Boolean = true): TruthTable = {
/** Pad the input signals to equalize all input widths. Pads input signals
* to the maximum width found in the table.
*
* @param table the truth table whose rows will be padded
* @return the same truth table but with inputs padded
*/
private def padInputs(table: Iterable[(BitPat, BitPat)]): Iterable[(BitPat, BitPat)] = {
val inputWidth = table.map(_._1.getWidth).max
require(table.map(_._2.getWidth).toSet.size == 1, "output width not equal.")
val outputWidth = table.map(_._2.getWidth).head
val mergedTable = table.map {
// pad input signals if necessary
table.map {
case (in, out) if inputWidth > in.width =>
(BitPat.N(inputWidth - in.width) ## in, out)
case (in, out) => (in, out)
}
.groupBy(_._1.toString)
.map {
case (key, values) =>
// merge same input inputs.
values.head._1 -> BitPat(s"b${Seq
.tabulate(outputWidth) { i =>
val outputSet = values
.map(_._2)
.map(_.rawString)
.map(_(i))
.toSet
.filterNot(_ == '?')
require(
outputSet.size != 2,
s"TruthTable conflict in :\n${values.map { case (i, o) => s"${i.rawString}->${o.rawString}" }.mkString("\n")}"
)
outputSet.headOption.getOrElse('?')
}
.mkString}")
}
.toSeq
}

/** For each duplicated input, collect the outputs into a single Seq.
*
* @param table the truth table
* @return a Seq of tuple of length 2, where the first element is the
* input and the second element is a Seq of OR-ed outputs
* for the input
*/
private def mergeTableOnInputs(table: Iterable[(BitPat, BitPat)]): Seq[(BitPat, Seq[BitPat])] = {
groupByIntoSeq(table)(_._1).map {
adkian-sifive marked this conversation as resolved.
Show resolved Hide resolved
case (input, mappings) =>
input -> mappings.map(_._2)
}
}

/** Merge two BitPats by OR-ing the values and masks, and setting the
* width to the max width among the two
*/
private def merge(a: BitPat, b: BitPat): BitPat = {
new BitPat(a.value | b.value, a.mask | b.mask, a.width.max(b.width))
}

/** Public method for calling with the Espresso decoder format fd
*
* For Espresso, for each output, a 1 means this product term belongs to the ON-set,
* a 0 means this product term has no meaning for the value of this function.
* This is the same as the fd (or f) type in espresso.
*
* @param table the truth table
* @param default the default BitPat is made up of a single bit type, either "?", "0" or "1".
* A default of "?" sets Espresso to fr-format, while a "0" or "1" sets it to the
* fd-format.
* @param sort whether to sort the final truth table using BitPat.bitPatOrder
* @return a fully built TruthTable
*/
def fromEspressoOutput(table: Iterable[(BitPat, BitPat)], default: BitPat, sort: Boolean = true): TruthTable = {
apply_impl(table, default, sort, false)
}

/** Public apply method to TruthTable. Calls apply_impl with the default value true of checkCollisions */
def apply(table: Iterable[(BitPat, BitPat)], default: BitPat, sort: Boolean = true): TruthTable = {
adkian-sifive marked this conversation as resolved.
Show resolved Hide resolved
adkian-sifive marked this conversation as resolved.
Show resolved Hide resolved
apply_impl(table, default, sort, true)
}

/** Convert a table and default output into a [[TruthTable]]. */
private def apply_impl(
table: Iterable[(BitPat, BitPat)],
default: BitPat,
sort: Boolean,
checkCollisions: Boolean
): TruthTable = {
val paddedTable = padInputs(table)

require(table.map(_._2.getWidth).toSet.size == 1, "output width not equal.")

val mergedTable = mergeTableOnInputs(paddedTable)

val finalTable: Seq[(BitPat, BitPat)] = mergedTable.map {
case (input, outputs) =>
val (result, noCollisions) = outputs.tail.foldLeft((outputs.head, checkCollisions)) {
case ((acc, ok), o) => (merge(acc, o), ok && acc.overlap(o))
}
// Throw an error if checkCollisions is true but there are bits with a non-zero overlap.
require(!checkCollisions || noCollisions, s"TruthTable conflict on merged row: \n\t$input -> $outputs")
adkian-sifive marked this conversation as resolved.
Show resolved Hide resolved
(input, result)
}

import BitPat.bitPatOrder
new TruthTable(if (sort) mergedTable.sorted else mergedTable, default, sort)
new TruthTable(if (sort) finalTable.sorted else finalTable, default, sort)
}

/** Parse TruthTable from its string representation. */
Expand Down Expand Up @@ -140,4 +190,15 @@ object TruthTable {
bitPat(tables.flatMap { case (table, indexes) => table.default.rawString.zip(indexes) })
)
}

/** Similar to Seq.groupBy except that it preserves ordering of elements within each group */
private def groupByIntoSeq[A, K](xs: Iterable[A])(f: A => K): Seq[(K, Seq[A])] = {
val map = mutable.LinkedHashMap.empty[K, mutable.ListBuffer[A]]
for (x <- xs) {
val key = f(x)
val l = map.getOrElseUpdate(key, mutable.ListBuffer.empty[A])
l += x
}
map.view.map({ case (k, vs) => k -> vs.toList }).toList
}
}
17 changes: 17 additions & 0 deletions src/test/scala/chiselTests/util/experimental/TruthTableSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,21 @@ class TruthTableSpec extends AnyFlatSpec {
assert(t.toString contains "111->?")
assert(t.toString contains " 0")
}

"Using TruthTable.fromEspressoOutput" should "merge rows on conflict" in {
val mapping = List(
(BitPat("b110"), BitPat("b001")),
(BitPat("b111"), BitPat("b001")),
(BitPat("b111"), BitPat("b010")),
(BitPat("b111"), BitPat("b100"))
)

assert(
TruthTable.fromEspressoOutput(mapping, BitPat("b?")) ==
TruthTable.fromString("""110->001
|111->111
|?
|""".stripMargin)
)
}
}