From 09a03ffd06a6dc8fb46934c76d98b83410c85e83 Mon Sep 17 00:00:00 2001 From: Schuyler Eldridge Date: Tue, 11 Dec 2018 22:48:38 -0500 Subject: [PATCH] Add Dependency API to firrtl.options package This adds a Dependency API (#948) to the firrtl.options package. This adds two new methods to Phase: prerequisites and invalidates. The former defines what Phases must have run before this Phase. The latter is a function that can be used to determine if this Phase invalidates another Phase. Additionally, this introduces a PhaseManager that subclasses Phase that determines a sequence of Phases to perform a requested lowering (from a given initial state ensure that some set of Phases are executed). This follows the original suggestion of @azidar in #446 for the PhaseManager algorithm with one modification. The (DFS-based) topological sort of dependencies is seeded using a topological sort of the invalidations. This ensures that the number of repeated Phases is kept to a minimum. (I am not sure if this is actually optimal, however.) DiGraph is updated with a seeded topological sort which the original topological sort method (linearize) now extends. Signed-off-by: Schuyler Eldridge --- src/main/scala/firrtl/graph/DiGraph.scala | 19 +- src/main/scala/firrtl/options/Phase.scala | 20 +- .../scala/firrtl/options/PhaseManager.scala | 117 +++++++++ .../options/PhaseManagerSpec.scala | 223 ++++++++++++++++++ 4 files changed, 371 insertions(+), 8 deletions(-) create mode 100644 src/main/scala/firrtl/options/PhaseManager.scala create mode 100644 src/test/scala/firrtlTests/options/PhaseManagerSpec.scala diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 9cfcad520f..972fcab9a8 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -64,13 +64,14 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link */ def findSinks: Set[T] = reverse.findSources - /** Linearizes (topologically sorts) a DAG - * + /** Linearizes (topologically sorts) a DAG using a DFS. This can be seeded with an order to use for the DFS if the user + * wants to tease out a special ordering of the DAG. + * @param seed an optional sequence of vertices to use. This will default to the vertices ordering provided by getVertices. * @throws CyclicException if the graph is cyclic * @return a Map[T,T] from each visited node to its predecessor in the * traversal */ - def linearize: Seq[T] = { + def seededLinearize(seed: Option[Seq[T]] = None): Seq[T] = { // permanently marked nodes are implicitly held in order val order = new mutable.ArrayBuffer[T] // invariant: no intersection between unmarked and tempMarked @@ -80,7 +81,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link case class LinearizeFrame[T](v: T, expanded: Boolean) val callStack = mutable.Stack[LinearizeFrame[T]]() - unmarked ++= getVertices + unmarked ++= seed.getOrElse(getVertices) while (unmarked.nonEmpty) { callStack.push(LinearizeFrame(unmarked.head, false)) while (callStack.nonEmpty) { @@ -109,6 +110,14 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link order.reverse.toSeq } + /** Linearizes (topologically sorts) a DAG + * + * @throws CyclicException if the graph is cyclic + * @return a Map[T,T] from each visited node to its predecessor in the + * traversal + */ + def linearize: Seq[T] = seededLinearize(None) + /** Performs breadth-first search on the directed graph * * @param root the start node @@ -163,7 +172,7 @@ class DiGraph[T] private[graph] (private[graph] val edges: LinkedHashMap[T, Link * @return a Seq[T] of nodes defining an arbitrary valid path */ def path(start: T, end: T): Seq[T] = path(start, end, Set.empty[T]) - + /** Finds a path (if one exists) from one node to another, with a blacklist * * @param start the start node diff --git a/src/main/scala/firrtl/options/Phase.scala b/src/main/scala/firrtl/options/Phase.scala index 3473905388..d393fb4e05 100644 --- a/src/main/scala/firrtl/options/Phase.scala +++ b/src/main/scala/firrtl/options/Phase.scala @@ -7,8 +7,6 @@ import firrtl.annotations.DeletedAnnotation import logger.LazyLogging -import scala.collection.mutable - /** A polymorphic mathematical transform * @tparam A the transformed type */ @@ -25,12 +23,28 @@ trait TransformLike[A] extends LazyLogging { } +trait DependencyAPI { this: Phase => + + /** All [[Phase]]s that must run before this [[Phase]] */ + def prerequisites: Set[Phase] = Set.empty + + /** A function that, given some other [[Phase]], will return [[true]] if this [[Phase]] invalidates the other [[Phase]]. + * By default, this invalidates everything except itself. + * @note Can a [[Phase]] ever invalidate itself? + */ + def invalidates(phase: Phase): Boolean = phase match { + case _: this.type => false + case _ => true + } + +} + /** A mathematical transformation of an [[AnnotationSeq]]. * * A [[Phase]] forms one unit in the Chisel/FIRRTL Hardware Compiler Framework (HCF). The HCF is built from a sequence * of [[Phase]]s applied to an [[AnnotationSeq]]. Note that a [[Phase]] may consist of multiple phases internally. */ -abstract class Phase extends TransformLike[AnnotationSeq] { +abstract class Phase extends TransformLike[AnnotationSeq] with DependencyAPI { /** The name of this [[Phase]]. This will be used to generate debug/error messages or when deleting annotations. This * will default to the `simpleName` of the class. diff --git a/src/main/scala/firrtl/options/PhaseManager.scala b/src/main/scala/firrtl/options/PhaseManager.scala new file mode 100644 index 0000000000..c83185ff77 --- /dev/null +++ b/src/main/scala/firrtl/options/PhaseManager.scala @@ -0,0 +1,117 @@ +// See LICENSE for license details. + +package firrtl.options + +import firrtl.AnnotationSeq +import firrtl.graph.{DiGraph, CyclicException} + +import scala.collection.mutable + +case class PhaseManagerException(message: String, cause: Throwable = null) extends RuntimeException(message, cause) + +/** A [[Phase]] that will ensure that some other [[Phase]]s and their prerequisites are executed. + * + * This tries to determine a phase ordering such that an [[AnnotationSeq]] ''output'' is produced that has had all of + * the requested [[Phase]] target transforms run without having them be invalidated. + * @param phaseTargets the [[Phase]]s you want to run + */ +case class PhaseManager(phaseTargets: Set[Phase], currentState: Set[Phase] = Set.empty) extends Phase { + + /** Modified breadth-first search that supports multiple starting nodes and a custom extractor that can be used to + * generate/filter the edges to explore. Additionally, this will include edges to previously discovered nodes. + */ + private def bfs(start: Set[Phase], blacklist: Set[Phase], extractor: Phase => Set[Phase]): Map[Phase, Set[Phase]] = { + val queue: mutable.Queue[Phase] = mutable.Queue(start.toSeq:_*) + val edges: mutable.HashMap[Phase, Set[Phase]] = mutable.HashMap[Phase, Set[Phase]](start.map((_ -> Set[Phase]())).toSeq:_*) + while (queue.nonEmpty) { + val u = queue.dequeue + for (v <- extractor(u)) { + if (!blacklist.contains(v) && !edges.contains(v)) { queue.enqueue(v) } + if (!edges.contains(v)) { edges(v) = Set.empty } + edges(u) = edges(u) + v + } + } + edges.toMap + } + + /** Pull in all registered phases once phase registration is integrated + * @todo implement this + */ + private lazy val registeredPhases: Set[Phase] = Set.empty + + /** A directed graph consisting of prerequisite edges */ + lazy val dependencyGraph: DiGraph[Phase] = + DiGraph( + bfs( + start = phaseTargets, + blacklist = currentState, + extractor = (p: Phase) => p.prerequisites)) + + /** A directed graph consisting of edges derived from invalidation */ + lazy val invalidateGraph: DiGraph[Phase] = { + val v = dependencyGraph.getVertices + DiGraph( + bfs( + start = phaseTargets, + blacklist = currentState, + extractor = (p: Phase) => v.filter(p.invalidates).toSet)) + .reverse + } + + /** Wrap a possible [[CyclicException]] thrown by a thunk in a [[PhaseManagerException]] */ + private def cyclePossible[A](a: String, thunk: => A): A = try { thunk } catch { + case e: CyclicException => + throw new PhaseManagerException( + s"No Phase ordering possible due to cyclic dependency in $a at node '${e.node}'.", e) + } + + /** The ordering of phases to run that respects prerequisites and reduces the number of required re-lowerings resulting + * from invalidations. + */ + lazy val phaseOrder: Seq[Phase] = { + + /* Topologically sort the dependency graph using the invalidate graph topological sort as a seed. This has the effect + * of minimizing the number of repeated [[Phase]]s */ + val sorted = { + val seed = cyclePossible("invalidates", invalidateGraph.linearize).reverse + cyclePossible("prerequisites", + dependencyGraph + .seededLinearize(Some(seed)) + .reverse + .dropWhile(currentState.contains)) + } + + val (state, lowerers) = { + val (s, l) = sorted.foldLeft((currentState, Array[Phase]())){ case ((state, out), in) => + val missing = (in.prerequisites -- state) + val preprocessing: Option[Phase] = { + if (missing.nonEmpty) { Some(PhaseManager(missing, state)) } + else { None } + } + ((state ++ missing + in).filterNot(in.invalidates), out ++ preprocessing :+ in) + } + val missing = (phaseTargets -- s) + val postprocessing: Option[Phase] = { + if (missing.nonEmpty) { Some(PhaseManager(missing, s)) } + else { None } + } + + (s ++ missing, l ++ postprocessing) + } + + if (!phaseTargets.subsetOf(state)) { + throw new PhaseException(s"The final state ($state) did not include the requested targets (${phaseTargets})!") + } + lowerers + } + + def flattenedPhaseOrder: Seq[Phase] = phaseOrder.flatMap { + case p: PhaseManager => p.flattenedPhaseOrder + case p: Phase => Some(p) + } + + final def transform(annotations: AnnotationSeq): AnnotationSeq = + phaseOrder + .foldLeft(annotations){ case (a, p) => p.transform(a) } + +} diff --git a/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala b/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala new file mode 100644 index 0000000000..f2c97926fb --- /dev/null +++ b/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala @@ -0,0 +1,223 @@ +// See LICENSE for license details. + +package firrtlTests.options + +import org.scalatest.{FlatSpec, Matchers} + +import firrtl.AnnotationSeq +import firrtl.options.{Phase, PhaseManager, PhaseManagerException} +import firrtl.annotations.{Annotation, NoTargetAnnotation} + +trait IdentityPhase extends Phase { + def transform(annotations: AnnotationSeq): AnnotationSeq = annotations +} + +/** Default [[Phase]] that has no prerequisites and invalidates nothing */ +case object A extends IdentityPhase { + override def invalidates(phase: Phase): Boolean = false +} + +/** [[Phase]] that requires [[A]] and invalidates nothing */ +case object B extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(A) + override def invalidates(phase: Phase): Boolean = false +} + +/** [[Phase]] that requires [[B]] and invalidates nothing */ +case object C extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(A) + override def invalidates(phase: Phase): Boolean = false +} + +/** [[Phase]] that requires [[A]] and invalidates [[A]] */ +case object D extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(A) + override def invalidates(phase: Phase): Boolean = phase match { + case _: A.type => true + case _ => false + } +} + +/** [[Phase]] that requires [[B]] and invalidates nothing */ +case object E extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(B) + override def invalidates(phase: Phase): Boolean = false +} + +/** [[Phase]] that requires [[B]] and [[C]] and invalidates [[E]] */ +case object F extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(B, C) + override def invalidates(phase: Phase): Boolean = phase match { + case _: E.type => true + case _ => false + } +} + + +/** [[Phase]] that requires [[C]] and invalidates [[F]] */ +case object G extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(C) + override def invalidates(phase: Phase): Boolean = phase match { + case _: F.type => true + case _ => false + } +} + +case object CyclicA extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(CyclicB) +} + +case object CyclicB extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(CyclicA) +} + +class CyclicInvalidateFixture { + + case object A extends IdentityPhase { + override def invalidates(phase: Phase): Boolean = false + } + case object B extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(A) + override def invalidates(phase: Phase): Boolean = false + } + case object C extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(A) + override def invalidates(phase: Phase): Boolean = phase match { + case _: B.type => true + case _ => false + } + } + case object D extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(B) + override def invalidates(phase: Phase): Boolean = phase match { + case _: C.type | _: E.type => true + case _ => false + } + } + case object E extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(B) + override def invalidates(phase: Phase): Boolean = false + } + +} + +trait AnalysisFixture { + + case object Analysis extends IdentityPhase { + override def invalidates(phase: Phase): Boolean = false + } + +} + +class RepeatedAnalysisFixture extends AnalysisFixture { + + trait InvalidatesAnalysis extends IdentityPhase { + override def invalidates(phase: Phase): Boolean = phase match { + case _: Analysis.type => true + case _ => false + } + } + + case object A extends InvalidatesAnalysis { + override def prerequisites: Set[Phase] = Set(Analysis) + } + case object B extends InvalidatesAnalysis { + override def prerequisites: Set[Phase] = Set(A, Analysis) + } + case object C extends InvalidatesAnalysis { + override def prerequisites: Set[Phase] = Set(B, Analysis) + } + +} + +class InvertedAnalysisFixture extends AnalysisFixture { + + case object A extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(Analysis) + override def invalidates(phase: Phase): Boolean = phase match { + case _: Analysis.type => true + case _ => false + } + } + case object B extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(Analysis) + override def invalidates(phase: Phase): Boolean = phase match { + case _: Analysis.type | _: A.type => true + case _ => false + } + } + case object C extends IdentityPhase { + override def prerequisites: Set[Phase] = Set(Analysis) + override def invalidates(phase: Phase): Boolean = phase match { + case _: Analysis.type | _: B.type => true + case _ => false + } + } + +} + +class PhaseManagerSpec extends FlatSpec with Matchers { + + behavior of this.getClass.getName + + it should "do nothing if all targets are reached" in { + val targets: Set[Phase] = Set(A, B, C, D) + val pm = new PhaseManager(targets, targets) + + pm.flattenedPhaseOrder should be (empty) + } + + it should "handle a simple dependency" in { + val targets: Set[Phase] = Set(B) + val pm = new PhaseManager(targets) + + pm.flattenedPhaseOrder should be (Seq(A, B)) + } + + it should "handle a simple dependency with an invalidation" in { + val targets: Set[Phase] = Set(A, B, C, D) + val pm = new PhaseManager(targets) + + pm.flattenedPhaseOrder should be (Seq(A, D, A, C, B)) + } + + it should "handle a dependency with two invalidates optimally" in { + val targets: Set[Phase] = Set(A, B, C, E, F, G) + val pm = new PhaseManager(targets) + + pm.flattenedPhaseOrder.size should be (targets.size) + } + + it should "throw an exception for cyclic prerequisites" in { + val targets: Set[Phase] = Set(CyclicA, CyclicB) + val pm = new PhaseManager(targets) + + intercept[PhaseManagerException]{ pm.flattenedPhaseOrder } + .getMessage should startWith ("No Phase ordering possible") + } + + it should "handle invalidates that form a cycle" in { + val f = new CyclicInvalidateFixture + val targets: Set[Phase] = Set(f.A, f.B, f.C, f.D, f.E) + val pm = new PhaseManager(targets) + + info("only one phase was recomputed") + pm.flattenedPhaseOrder.size should be (targets.size + 1) + } + + it should "handle repeated recomputed analyses" in { + val f = new RepeatedAnalysisFixture + val targets: Set[Phase] = Set(f.A, f.B, f.C) + val pm = new PhaseManager(targets) + + pm.flattenedPhaseOrder should be (Seq(f.Analysis, f.A, f.Analysis, f.B, f.Analysis, f.C)) + } + + it should "handle inverted repeated recomputed analyses" in { + val f = new InvertedAnalysisFixture + val targets: Set[Phase] = Set(f.A, f.B, f.C) + val pm = new PhaseManager(targets) + + pm.flattenedPhaseOrder should be (Seq(f.Analysis, f.C, f.Analysis, f.B, f.Analysis, f.A)) + } +}