From 6c95e94d1969c21cd330c7c7ae582b4113a7ae48 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Tue, 29 Aug 2017 15:00:01 -0400 Subject: [PATCH 1/2] Clean up Dagon API a bit. * Introduce Cache analog to HCache * Rename ExpressionDag to Dag * Move DependantGraph to tests, rename to SimpleDag * A bit of other code clean up * Some extra docs --- README.md | 6 +- .../main/scala/com/stripe/dagon/Cache.scala | 70 +++++++++++++++ .../dagon/{ExpressionDag.scala => Dag.scala} | 85 +++++++++++-------- .../main/scala/com/stripe/dagon/HCache.scala | 57 +++++++++++-- .../main/scala/com/stripe/dagon/HMap.scala | 17 +++- core/src/main/scala/com/stripe/dagon/Id.scala | 2 +- .../main/scala/com/stripe/dagon/Memoize.scala | 18 +--- .../scala/com/stripe/dagon/PartialRule.scala | 4 +- .../main/scala/com/stripe/dagon/Rule.scala | 33 +++---- .../scala/com/stripe/dagon/DataFlowTest.scala | 62 +++++++------- .../com/stripe/dagon/ExpressionDagTests.scala | 22 ++--- .../scala/com/stripe/dagon/ReadmeTest.scala | 6 +- .../scala/com/stripe/dagon/SimpleDag.scala} | 8 +- 13 files changed, 257 insertions(+), 133 deletions(-) create mode 100644 core/src/main/scala/com/stripe/dagon/Cache.scala rename core/src/main/scala/com/stripe/dagon/{ExpressionDag.scala => Dag.scala} (85%) rename core/src/{main/scala/com/stripe/dagon/DependantGraph.scala => test/scala/com/stripe/dagon/SimpleDag.scala} (92%) diff --git a/README.md b/README.md index 07e0ba6..b2b2de9 100644 --- a/README.md +++ b/README.md @@ -102,14 +102,14 @@ object Example { // 3. 
set up rewrite rules object SimplifyNegation extends PartialRule[Eqn] { - def applyWhere[T](on: ExpressionDag[Eqn]) = { + def applyWhere[T](on: Dag[Eqn]) = { case Negate(Negate(e)) => e case Negate(Const(x)) => Const(-x) } } object SimplifyAddition extends PartialRule[Eqn] { - def applyWhere[T](on: ExpressionDag[Eqn]) = { + def applyWhere[T](on: Dag[Eqn]) = { case Add(Const(x), Const(y)) => Const(x + y) case Add(Add(e, Const(x)), Const(y)) => Add(e, Const(x + y)) case Add(Add(Const(x), e), Const(y)) => Add(e, Const(x + y)) @@ -128,7 +128,7 @@ object Example { val rules = SimplifyNegation.orElse(SimplifyAddition) val simplified: Eqn[Unit] = - ExpressionDag.applyRule(c, toLiteral, rules) + Dag.applyRule(c, toLiteral, rules) } ``` diff --git a/core/src/main/scala/com/stripe/dagon/Cache.scala b/core/src/main/scala/com/stripe/dagon/Cache.scala new file mode 100644 index 0000000..c662401 --- /dev/null +++ b/core/src/main/scala/com/stripe/dagon/Cache.scala @@ -0,0 +1,70 @@ +package com.stripe.dagon + +/** + * This is a useful cache for memoizing functions. + * + * The cache is implemented using a mutable pointer to an immutable + * map value. In the worst-case, race conditions might cause us to + * lose cache values (i.e. compute some keys twice), but we will never + * produce incorrect values. + */ +sealed class Cache[K, V] private (init: Map[K, V]) { + + private[this] var map: Map[K, V] = init + + /** + * Given a key, either return a cached value, or compute, store, and + * return a new value. + * + * This method is what justifies the existence of Cache. Its second + * parameter (`v`) is by-name: it will only be evaluated in cases + * where the key is not cached. 
+ * + * For example: + * + * def greet(i: Int): Int = { + * println("hi") + * i + 1 + * } + * + * val c = Cache.empty[Int, Int] + * c.getOrElseUpdate(1, greet(1)) // says hi, returns 2 + * c.getOrElseUpdate(1, greet(1)) // just returns 2 + */ + def getOrElseUpdate(k: K, v: => V): V = + map.get(k) match { + case Some(exists) => exists + case None => + val res = v + map = map.updated(k, res) + res + } + + /** + * Create a second cache with the same values as this one. + * + * The two caches will start with the same values, but will be + * independently updated. + */ + def duplicate: Cache[K, V] = + new Cache(map) + + /** + * Access the currently-cached keys and values as a map. + */ + def toMap: Map[K, V] = + map + + /** + * Forget all cached keys and values. + * + * After calling this method, the resulting cache is equivalent to + * Cache.empty[K, V]. + */ + def reset(): Unit = + map = Map.empty +} + +object Cache { + def empty[K, V]: Cache[K, V] = new Cache(Map.empty) +} diff --git a/core/src/main/scala/com/stripe/dagon/ExpressionDag.scala b/core/src/main/scala/com/stripe/dagon/Dag.scala similarity index 85% rename from core/src/main/scala/com/stripe/dagon/ExpressionDag.scala rename to core/src/main/scala/com/stripe/dagon/Dag.scala index 31403dd..730aeb8 100644 --- a/core/src/main/scala/com/stripe/dagon/ExpressionDag.scala +++ b/core/src/main/scala/com/stripe/dagon/Dag.scala @@ -17,7 +17,12 @@ package com.stripe.dagon -sealed abstract class ExpressionDag[N[_]] { self => +/** + * Represents a directed acyclic graph (DAG). + * + * The type N[_] represents the type of nodes in the graph. + */ +sealed abstract class Dag[N[_]] { self => /** * These have package visibility to test @@ -25,53 +30,71 @@ sealed abstract class ExpressionDag[N[_]] { self => * evaluate to is unique */ protected def idToExp: HMap[Id, Expr[N, ?]] + /** * The set of roots that were added by addRoot. 
* These are Ids that will always evaluate * such that roots.forall(evaluateOption(_).isDefined) */ protected def roots: Set[Id[_]] + /** * This is the next Id value which will be allocated */ protected def nextId: Int /** - * Convert a N[T] to a Literal[T, N] + * Convert a N[T] to a Literal[N, T]. */ def toLiteral: FunctionK[N, Literal[N, ?]] + // Caches polymorphic functions of type Id[T] => Option[N[T]] + private val idToN: HCache[Id, Lambda[t => Option[N[t]]]] = + HCache.empty[Id, Lambda[t => Option[N[t]]]] + + // Caches polymorphic functions of type N[T] => Option[Id[T]] + private val nodeToId: HCache[N, Lambda[t => Option[Id[t]]]] = + HCache.empty[N, Lambda[t => Option[Id[t]]]] + + // Convenient method to produce new, modified DAGs based on this + // one. private def copy( id2Exp: HMap[Id, Expr[N, ?]] = self.idToExp, node2Literal: FunctionK[N, Literal[N, ?]] = self.toLiteral, gcroots: Set[Id[_]] = self.roots, id: Int = self.nextId - ): ExpressionDag[N] = new ExpressionDag[N] { + ): Dag[N] = new Dag[N] { def idToExp = id2Exp def roots = gcroots def toLiteral = node2Literal def nextId = id } - override def toString: String = - s"ExpressionDag(idToExp = $idToExp, roots = $roots)" + // Produce a new DAG that is equivalent to this one, but which frees + // orphaned nodes and other internal state which may no longer be + // needed. + private def gc: Dag[N] = { + val keepers = reachableIds + val kept = idToExp.filter { case (id, _) => keepers(id) } + if (idToExp.size == kept.size) this else copy(id2Exp = kept) + } - // This is a cache of Id[T] => Option[N[T]] - private val idToN = - HCache.empty[Id, Lambda[t => Option[N[t]]]] - private val nodeToId = - HCache.empty[N, Lambda[t => Option[Id[t]]]] + /** + * String representation of this DAG. + */ + override def toString: String = + s"Dag(idToExp = $idToExp, roots = $roots)" /** * Add a GC root, or tail in the DAG, that can never be deleted. 
*/ - def addRoot[T](node: N[T]): (ExpressionDag[N], Id[T]) = { + def addRoot[T](node: N[T]): (Dag[N], Id[T]) = { val (dag, id) = ensure(node) (dag.copy(gcroots = roots + id), id) } /** - * Which ids are reachable from the roots + * Which ids are reachable from the roots? */ def reachableIds: Set[Id[_]] = { @@ -86,22 +109,14 @@ sealed abstract class ExpressionDag[N[_]] { self => Graphs.reflexiveTransitiveClosure(roots.toList)(neighbors _).toSet } - private def gc: ExpressionDag[N] = { - val goodIds = reachableIds - val toKeepI2E = idToExp.filter(new FunctionK[HMap[Id, Expr[N, ?]]#Pair, BoolT] { - def toFunction[T] = { case (id, _) => goodIds(id) } - }) - copy(id2Exp = toKeepI2E) - } - /** * Apply the given rule to the given dag until * the graph no longer changes. */ - def apply(rule: Rule[N]): ExpressionDag[N] = { + def apply(rule: Rule[N]): Dag[N] = { @annotation.tailrec - def loop(d: ExpressionDag[N]): ExpressionDag[N] = { + def loop(d: Dag[N]): Dag[N] = { val next = d.applyOnce(rule) if (next eq d) next else loop(next) @@ -114,8 +129,8 @@ sealed abstract class ExpressionDag[N[_]] { self => * apply the rule at the first place that satisfies * it, and return from there. */ - def applyOnce(rule: Rule[N]): ExpressionDag[N] = { - type DagT[T] = ExpressionDag[N] + def applyOnce(rule: Rule[N]): Dag[N] = { + type DagT[T] = Dag[N] val f = new FunctionK[HMap[Id, Expr[N, ?]]#Pair, Lambda[x => Option[DagT[x]]]] { def toFunction[U] = { (kv: (Id[U], Expr[N, U])) => @@ -134,7 +149,7 @@ sealed abstract class ExpressionDag[N[_]] { self => // publicly, and the ids may be embedded in many // nodes. Instead we remap 'id' to be a pointer // to 'newid'. - dag.copy(id2Exp = dag.idToExp + (id -> Expr.Var[N, U](newId))).gc + dag.copy(id2Exp = dag.idToExp.updated(id, Expr.Var[N, U](newId))).gc } } } @@ -146,10 +161,10 @@ sealed abstract class ExpressionDag[N[_]] { self => /** * Apply a rule at most cnt times. 
*/ - def applyMax(rule: Rule[N], cnt: Int): ExpressionDag[N] = { + def applyMax(rule: Rule[N], cnt: Int): Dag[N] = { @annotation.tailrec - def loop(d: ExpressionDag[N], cnt: Int): ExpressionDag[N] = + def loop(d: Dag[N], cnt: Int): Dag[N] = if (cnt <= 0) d else { val next = d.applyOnce(rule) @@ -165,10 +180,10 @@ sealed abstract class ExpressionDag[N[_]] { self => * * Note, Expr must never be a Var */ - private def addExp[T](node: N[T], exp: Expr[N, T]): (ExpressionDag[N], Id[T]) = { + private def addExp[T](node: N[T], exp: Expr[N, T]): (Dag[N], Id[T]) = { require(!exp.isVar) val nodeId = Id[T](nextId) - (copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId) + (copy(id2Exp = idToExp.updated(nodeId, exp), id = nextId + 1), nodeId) } /** @@ -264,7 +279,7 @@ sealed abstract class ExpressionDag[N[_]] { self => * at most one id in the graph. Put another way, for all * Id[T] in the graph evaluate(id) is distinct. */ - protected def ensure[T](node: N[T]): (ExpressionDag[N], Id[T]) = + protected def ensure[T](node: N[T]): (Dag[N], Id[T]) = find(node) match { case Some(id) => (this, id) case None => @@ -377,10 +392,10 @@ sealed abstract class ExpressionDag[N[_]] { self => } } -object ExpressionDag { +object Dag { - def empty[N[_]](n2l: FunctionK[N, Literal[N, ?]]): ExpressionDag[N] = - new ExpressionDag[N] { + def empty[N[_]](n2l: FunctionK[N, Literal[N, ?]]): Dag[N] = + new Dag[N] { val idToExp = HMap.empty[Id, Expr[N, ?]] val toLiteral = n2l val roots = Set.empty[Id[_]] @@ -388,9 +403,9 @@ object ExpressionDag { } /** - * This creates a new ExpressionDag rooted at the given tail node + * This creates a new Dag rooted at the given tail node */ - def apply[T, N[_]](n: N[T], nodeToLit: FunctionK[N, Literal[N, ?]]): (ExpressionDag[N], Id[T]) = + def apply[T, N[_]](n: N[T], nodeToLit: FunctionK[N, Literal[N, ?]]): (Dag[N], Id[T]) = empty(nodeToLit).addRoot(n) /** diff --git a/core/src/main/scala/com/stripe/dagon/HCache.scala 
b/core/src/main/scala/com/stripe/dagon/HCache.scala index 9106bd5..2a18083 100644 --- a/core/src/main/scala/com/stripe/dagon/HCache.scala +++ b/core/src/main/scala/com/stripe/dagon/HCache.scala @@ -1,21 +1,36 @@ package com.stripe.dagon /** - * This is a useful cache for memoizing heterogenously types functions + * This is a useful cache for memoizing natural transformations. + * + * The cache is implemented using a mutable pointer to an immutable + * map value. In the worst-case, race conditions might cause us to + * lose cache values (i.e. compute some keys twice), but we will never + * produce incorrect values. */ sealed class HCache[K[_], V[_]] private (init: HMap[K, V]) { - private var hmap: HMap[K, V] = init - /** - * Get an immutable snapshot of the current state - */ - def snapshot: HMap[K, V] = hmap + private[this] var hmap: HMap[K, V] = init /** - * Get a mutable copy of the current state + * Given a key, either return a cached value, or compute, store, and + * return a new value. + * + * This method is what justifies the existence of HCache. Its second + * parameter (`v`) is by-name: it will only be evaluated in cases + * where the key is not cached. + * + * For example: + * + * def greet(i: Int): Option[Int] = { + * println("hi") + * Option(i + 1) + * } + * + * val c = HCache.empty[Option, Option] + * c.getOrElseUpdate(Some(1), greet(1)) // says hi, returns Some(2) + * c.getOrElseUpdate(Some(1), greet(1)) // just returns Some(2) */ - def duplicate: HCache[K, V] = new HCache(hmap) - def getOrElseUpdate[T](k: K[T], v: => V[T]): V[T] = hmap.get(k) match { case Some(exists) => exists @@ -24,6 +39,30 @@ sealed class HCache[K[_], V[_]] private (init: HMap[K, V]) { hmap = hmap + (k -> res) res } + + /** + * Create a second cache with the same values as this one. + * + * The two caches will start with the same values, but will be + * independently updated. 
+ */ + def duplicate: HCache[K, V] = + new HCache(hmap) + + /** + * Access the currently-cached keys and values as a map. + */ + def toHMap: HMap[K, V] = + hmap + + /** + * Forget all cached keys and values. + * + * After calling this method, the resulting cache is equivalent to + * HCache.empty[K, V]. + */ + def reset(): Unit = + hmap = HMap.empty[K, V] } object HCache { diff --git a/core/src/main/scala/com/stripe/dagon/HMap.scala b/core/src/main/scala/com/stripe/dagon/HMap.scala index 17bf926..094b447 100644 --- a/core/src/main/scala/com/stripe/dagon/HMap.scala +++ b/core/src/main/scala/com/stripe/dagon/HMap.scala @@ -40,6 +40,9 @@ final class HMap[K[_], V[_]](protected val map: Map[K[_], V[_]]) { override def hashCode: Int = map.hashCode + def updated[T](k: K[T], v: V[T]): HMap[K, V] = + HMap.from[K, V](map.updated(k, v)) + def +[T](kv: (K[T], V[T])): HMap[K, V] = HMap.from[K, V](map + kv) @@ -55,10 +58,16 @@ final class HMap[K[_], V[_]](protected val map: Map[K[_], V[_]]) { def contains[T](id: K[T]): Boolean = get(id).isDefined - def filter(pred: FunctionK[Pair, BoolT]): HMap[K, V] = { - val filtered = map.asInstanceOf[Map[K[Any], V[Any]]].filter(pred.apply[Any]) - HMap.from[K, V](filtered.asInstanceOf[Map[K[_], V[_]]]) - } + def size: Int = map.size + + def exists(p: ((K[_], V[_])) => Boolean): Boolean = + map.exists(p) + + def forall(p: ((K[_], V[_])) => Boolean): Boolean = + map.forall(p) + + def filter(p: ((K[_], V[_])) => Boolean): HMap[K, V] = + HMap.from[K, V](map.filter(p)) def keysOf[T](v: V[T]): Set[K[T]] = map.collect { diff --git a/core/src/main/scala/com/stripe/dagon/Id.scala b/core/src/main/scala/com/stripe/dagon/Id.scala index c538df6..7afd3ea 100644 --- a/core/src/main/scala/com/stripe/dagon/Id.scala +++ b/core/src/main/scala/com/stripe/dagon/Id.scala @@ -4,7 +4,7 @@ package com.stripe.dagon * The Expressions are assigned Ids. Each Id is associated with * an expression of inner type T. 
* - * This is done to put an indirection in the ExpressionDag that + * This is done to put an indirection in the Dag that * allows us to rewrite nodes by simply replacing the expressions * associated with given Ids. * diff --git a/core/src/main/scala/com/stripe/dagon/Memoize.scala b/core/src/main/scala/com/stripe/dagon/Memoize.scala index 584e0a0..ed35d64 100644 --- a/core/src/main/scala/com/stripe/dagon/Memoize.scala +++ b/core/src/main/scala/com/stripe/dagon/Memoize.scala @@ -22,17 +22,8 @@ object Memoize { // but mutating the Map inside of the call-by-name value causes // some issues in some versions of scala. It is // safer to use a mutable pointer to an immutable Map. - var cache = Map.empty[A, B] - def getOrElseUpdate(a: A, b: => B): B = - cache.get(a) match { - case Some(res) => res - case None => - val res = b - cache = cache.updated(a, res) - res - } - - lazy val g: A => B = (a: A) => getOrElseUpdate(a, f(a, g)) + val cache = Cache.empty[A, B] + lazy val g: A => B = (a: A) => cache.getOrElseUpdate(a, f(a, g)) g } @@ -44,9 +35,8 @@ object Memoize { def functionK[A[_], B[_]](f: RecursiveK[A, B]): FunctionK[A, B] = { val hcache = HCache.empty[A, B] lazy val hg: FunctionK[A, B] = new FunctionK[A, B] { - def toFunction[T] = { at => - hcache.getOrElseUpdate(at, f((at, hg))) - } + def toFunction[T]: A[T] => B[T] = + at => hcache.getOrElseUpdate(at, f((at, hg))) } hg } diff --git a/core/src/main/scala/com/stripe/dagon/PartialRule.scala b/core/src/main/scala/com/stripe/dagon/PartialRule.scala index 1fcf479..fe9b958 100644 --- a/core/src/main/scala/com/stripe/dagon/PartialRule.scala +++ b/core/src/main/scala/com/stripe/dagon/PartialRule.scala @@ -4,8 +4,8 @@ package com.stripe.dagon * Often a partial function is an easier way to express rules */ trait PartialRule[N[_]] extends Rule[N] { - final def apply[T](on: ExpressionDag[N]): N[T] => Option[N[T]] = + final def apply[T](on: Dag[N]): N[T] => Option[N[T]] = applyWhere[T](on).lift - def applyWhere[T](on: 
ExpressionDag[N]): PartialFunction[N[T], N[T]] + def applyWhere[T](on: Dag[N]): PartialFunction[N[T], N[T]] } diff --git a/core/src/main/scala/com/stripe/dagon/Rule.scala b/core/src/main/scala/com/stripe/dagon/Rule.scala index 80e14ce..a28e455 100644 --- a/core/src/main/scala/com/stripe/dagon/Rule.scala +++ b/core/src/main/scala/com/stripe/dagon/Rule.scala @@ -1,7 +1,7 @@ package com.stripe.dagon /** - * This implements a simplification rule on ExpressionDags + * This implements a simplification rule on Dags */ trait Rule[N[_]] { self => @@ -12,29 +12,30 @@ trait Rule[N[_]] { self => * If it is convenient, you might write a partial function * and then call .lift to get the correct Function type */ - def apply[T](on: ExpressionDag[N]): N[T] => Option[N[T]] + def apply[T](on: Dag[N]): N[T] => Option[N[T]] // If the current rule cannot apply, then try the argument here - def orElse(that: Rule[N]): Rule[N] = new Rule[N] { - def apply[T](on: ExpressionDag[N]) = { n => - self.apply(on)(n) match { - case Some(n1) if n1 == n => - // If the rule emits the same as input fall through - that.apply(on)(n) - case None => - that.apply(on)(n) - case s@Some(_) => s + def orElse(that: Rule[N]): Rule[N] = + new Rule[N] { + def apply[T](on: Dag[N]) = { n => + self.apply(on)(n) match { + case Some(n1) if n1 == n => + // If the rule emits the same as input fall through + that.apply(on)(n) + case None => + that.apply(on)(n) + case s@Some(_) => s + } } - } - override def toString: String = - s"$self.orElse($that)" - } + override def toString: String = + s"$self.orElse($that)" + } } object Rule { def empty[N[_]]: Rule[N] = new Rule[N] { - def apply[T](on: ExpressionDag[N]) = { _ => None } + def apply[T](on: Dag[N]) = { _ => None } } } diff --git a/core/src/test/scala/com/stripe/dagon/DataFlowTest.scala b/core/src/test/scala/com/stripe/dagon/DataFlowTest.scala index dfe8af0..b5afcf3 100644 --- a/core/src/test/scala/com/stripe/dagon/DataFlowTest.scala +++ 
b/core/src/test/scala/com/stripe/dagon/DataFlowTest.scala @@ -88,7 +88,7 @@ object DataFlowTest { * we use object to get good toString for debugging */ object composeOptionMapped extends PartialRule[Flow] { - def applyWhere[T](on: ExpressionDag[Flow]) = { + def applyWhere[T](on: Dag[Flow]) = { case (OptionMapped(inner @ OptionMapped(s, fn0), fn1)) if on.fanOut(inner) == 1 => OptionMapped(s, ComposedOM(fn0, fn1)) } @@ -98,7 +98,7 @@ object DataFlowTest { * f.concatMap(fn1).concatMap(fn2) == f.concatMap { t => fn1(t).flatMap(fn2) } */ object composeConcatMap extends PartialRule[Flow] { - def applyWhere[T](on: ExpressionDag[Flow]) = { + def applyWhere[T](on: Dag[Flow]) = { case (ConcatMapped(inner @ ConcatMapped(s, fn0), fn1)) if on.fanOut(inner) == 1 => ConcatMapped(s, ComposedCM(fn0, fn1)) } @@ -109,7 +109,7 @@ object DataFlowTest { * (a ++ b).optionMap(fn) == (a.optionMap(fn) ++ b.optionMap(fn)) */ object mergePullDown extends PartialRule[Flow] { - def applyWhere[T](on: ExpressionDag[Flow]) = { + def applyWhere[T](on: Dag[Flow]) = { case (ConcatMapped(merge @ Merge(a, b), fn)) if on.fanOut(merge) == 1 => a.concatMap(fn) ++ b.concatMap(fn) case (OptionMapped(merge @ Merge(a, b), fn)) if on.fanOut(merge) == 1 => @@ -122,7 +122,7 @@ object DataFlowTest { * the knowledge about which fns potentially expand the size */ object optionMapToConcatMap extends PartialRule[Flow] { - def applyWhere[T](on: ExpressionDag[Flow]) = { + def applyWhere[T](on: Dag[Flow]) = { case OptionMapped(of, fn) => ConcatMapped(of, OptionToConcatFn(fn)) } } @@ -131,7 +131,7 @@ object DataFlowTest { * right associate merges */ object rightMerge extends PartialRule[Flow] { - def applyWhere[T](on: ExpressionDag[Flow]) = { + def applyWhere[T](on: Dag[Flow]) = { case Merge(left@Merge(a, b), c) if on.fanOut(left) == 1 => Merge(a, Merge(b, c)) } @@ -141,7 +141,7 @@ object DataFlowTest { * evaluate single fanout sources */ object evalSource extends PartialRule[Flow] { - def applyWhere[T](on: 
ExpressionDag[Flow]) = { + def applyWhere[T](on: Dag[Flow]) = { case OptionMapped(src @ IteratorSource(it), fn) if on.fanOut(src) == 1 => IteratorSource(it.flatMap(fn(_).toIterator)) case ConcatMapped(src @ IteratorSource(it), fn) if on.fanOut(src) == 1 => @@ -156,7 +156,7 @@ object DataFlowTest { } object removeTag extends PartialRule[Flow] { - def applyWhere[T](on: ExpressionDag[Flow]) = { + def applyWhere[T](on: Dag[Flow]) = { case Tagged(in, _) => in } } @@ -229,13 +229,13 @@ object DataFlowTest { Arbitrary(genFlow[T](implicitly[Arbitrary[T]].arbitrary)) - def expDagGen[T: Cogen](g: Gen[T]): Gen[ExpressionDag[Flow]] = { - val empty = ExpressionDag.empty[Flow](toLiteral) + def expDagGen[T: Cogen](g: Gen[T]): Gen[Dag[Flow]] = { + val empty = Dag.empty[Flow](toLiteral) Gen.frequency((1, Gen.const(empty)), (10, genFlow(g).map { f => empty.addRoot(f)._1 })) } - def arbExpDag[T: Arbitrary: Cogen]: Arbitrary[ExpressionDag[Flow]] = + def arbExpDag[T: Arbitrary: Cogen]: Arbitrary[Dag[Flow]] = Arbitrary(expDagGen[T](implicitly[Arbitrary[T]].arbitrary)) } } @@ -258,7 +258,7 @@ class DataFlowTest extends FunSuite { import Flow._ - val res = ExpressionDag.applyRule(tail, toLiteral, mergePullDown.orElse(composeOptionMapped)) + val res = Dag.applyRule(tail, toLiteral, mergePullDown.orElse(composeOptionMapped)) res match { case Merge(OptionMapped(s1, fn1), OptionMapped(s2, fn2)) => @@ -273,7 +273,7 @@ class DataFlowTest extends FunSuite { val f = Flow(it1).map(_ * 2) ++ Flow(it2).filter(_ % 7 == 0) - ExpressionDag.applyRule(f, Flow.toLiteral, Flow.allRules) match { + Dag.applyRule(f, Flow.toLiteral, Flow.allRules) match { case Flow.IteratorSource(it) => assert(it.toList == (it1.map(_ * 2) ++ (it2.filter(_ % 7 == 0))).toList) case nonSrc => @@ -285,12 +285,12 @@ class DataFlowTest extends FunSuite { test("fanOut matches") { def law(f: Flow[Int], rule: Rule[Flow], maxApplies: Int) = { - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val 
optimizedDag = dag.applyMax(rule, maxApplies) val optF = optimizedDag.evaluate(id) - val depGraph = DependantGraph[Flow[Any]](Flow.transitiveDeps(optF))(Flow.dependenciesOf _) + val depGraph = SimpleDag[Flow[Any]](Flow.transitiveDeps(optF))(Flow.dependenciesOf _) def fanOut(f: Flow[Any]): Int = { val internal = depGraph.fanOut(f).getOrElse(0) @@ -327,14 +327,14 @@ class DataFlowTest extends FunSuite { test("we either totally evaluate or have Iterators with fanOut") { def law(f: Flow[Int]) = { - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val optDag = dag(Flow.allRules) val optF = optDag.evaluate(id) optF match { case Flow.IteratorSource(_) => succeed case nonEval => - val depGraph = DependantGraph[Flow[Any]](Flow.transitiveDeps(nonEval))(Flow.dependenciesOf _) + val depGraph = SimpleDag[Flow[Any]](Flow.transitiveDeps(nonEval))(Flow.dependenciesOf _) val fansOut = depGraph .nodes @@ -353,7 +353,7 @@ class DataFlowTest extends FunSuite { test("addRoot adds roots") { implicit val dag = Flow.arbExpDag[Int] - forAll { (d: ExpressionDag[Flow], f: Flow[Int]) => + forAll { (d: Dag[Flow], f: Flow[Int]) => val (next, id) = d.addRoot(f) assert(next.isRoot(f)) @@ -362,9 +362,9 @@ class DataFlowTest extends FunSuite { } } - test("all ExpressionDag.allNodes agrees with Flow.transitiveDeps") { + test("all Dag.allNodes agrees with Flow.transitiveDeps") { forAll { (f: Flow[Int], rule: Rule[Flow], max: Int) => - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val optimizedDag = dag.applyMax(rule, max) @@ -373,9 +373,9 @@ class DataFlowTest extends FunSuite { } } - test("ExpressionDag: findAll(n).forall(evaluate(_) == n)") { + test("Dag: findAll(n).forall(evaluate(_) == n)") { forAll { (f: Flow[Int], rule: Rule[Flow], max: Int) => - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val optimizedDag = dag.applyMax(rule, max) @@ -390,7 +390,7 @@ class 
DataFlowTest extends FunSuite { test("apply the empty rule returns eq dag") { implicit val dag = Flow.arbExpDag[Int] - forAll { (d: ExpressionDag[Flow]) => + forAll { (d: Dag[Flow]) => assert(d(Rule.empty[Flow]) eq d) } } @@ -398,11 +398,11 @@ class DataFlowTest extends FunSuite { test("rules are idempotent") { def law(f: Flow[Int], rule: Rule[Flow]) = { - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val optimizedDag = dag(rule) val optF = optimizedDag.evaluate(id) - val (dag2, id2) = ExpressionDag(optF, Flow.toLiteral) + val (dag2, id2) = Dag(optF, Flow.toLiteral) val optimizedDag2 = dag2(rule) val optF2 = optimizedDag2.evaluate(id2) @@ -412,12 +412,12 @@ class DataFlowTest extends FunSuite { forAll(law _) } - test("dependentsOf matches DependantGraph.dependantsOf") { + test("dependentsOf matches SimpleDag.dependantsOf") { forAll { (f: Flow[Int], rule: Rule[Flow], max: Int) => - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val optimizedDag = dag.applyMax(rule, max) - val depGraph = DependantGraph[Flow[Any]](Flow.transitiveDeps(optimizedDag.evaluate(id)))(Flow.dependenciesOf _) + val depGraph = SimpleDag[Flow[Any]](Flow.transitiveDeps(optimizedDag.evaluate(id)))(Flow.dependenciesOf _) optimizedDag.allNodes.foreach { n => assert(optimizedDag.dependentsOf(n) == depGraph.dependantsOf(n).fold(Set.empty[Flow[Any]])(_.toSet)) @@ -429,7 +429,7 @@ class DataFlowTest extends FunSuite { test("contains(n) is the same as allNodes.contains(n)") { forAll { (f: Flow[Int], rule: Rule[Flow], max: Int, check: List[Flow[Int]]) => - val (dag, _) = ExpressionDag(f, Flow.toLiteral) + val (dag, _) = Dag(f, Flow.toLiteral) val optimizedDag = dag.applyMax(rule, max) @@ -441,7 +441,7 @@ class DataFlowTest extends FunSuite { test("all roots can be evaluated") { forAll { (roots: List[Flow[Int]], rule: Rule[Flow], max: Int) => - val dag = ExpressionDag.empty[Flow](Flow.toLiteral) + val dag = 
Dag.empty[Flow](Flow.toLiteral) // This is pretty slow with tons of roots, take 10 val (finalDag, allRoots) = roots.take(10).foldLeft((dag, Set.empty[Id[Int]])) { case ((d, s), f) => @@ -459,7 +459,7 @@ class DataFlowTest extends FunSuite { test("removeTag removes all .tagged") { forAll { f: Flow[Int] => - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val optDag = dag(Flow.allRules) // includes removeTagged optDag.allNodes.foreach { @@ -471,7 +471,7 @@ class DataFlowTest extends FunSuite { test("reachableIds are only the set of nodes") { forAll { (f: Flow[Int], rule: Rule[Flow], max: Int) => - val (dag, id) = ExpressionDag(f, Flow.toLiteral) + val (dag, id) = Dag(f, Flow.toLiteral) val optimizedDag = dag.applyMax(rule, max) diff --git a/core/src/test/scala/com/stripe/dagon/ExpressionDagTests.scala b/core/src/test/scala/com/stripe/dagon/ExpressionDagTests.scala index 877ac46..aa7b26f 100644 --- a/core/src/test/scala/com/stripe/dagon/ExpressionDagTests.scala +++ b/core/src/test/scala/com/stripe/dagon/ExpressionDagTests.scala @@ -20,7 +20,7 @@ package com.stripe.dagon import org.scalacheck.Prop._ import org.scalacheck.{Gen, Prop, Properties} -object ExpressionDagTests extends Properties("ExpressionDag") { +object DagTests extends Properties("Dag") { /* * Here we test with a simple algebra optimizer @@ -60,7 +60,7 @@ object ExpressionDagTests extends Properties("ExpressionDag") { } def testRule[T](start: Formula[T], expected: Formula[T], rule: Rule[Formula]): Prop = { - val got = ExpressionDag.applyRule(start, toLiteral, rule) + val got = Dag.applyRule(start, toLiteral, rule) (got == expected) :| s"$got == $expected" } @@ -107,14 +107,14 @@ object ExpressionDagTests extends Properties("ExpressionDag") { * Inc(Inc(a, b), c) = Inc(a, b + c) */ object CombineInc extends Rule[Formula] { - def apply[T](on: ExpressionDag[Formula]) = { + def apply[T](on: Dag[Formula]) = { case Inc(i @ Inc(a, b), c) if on.fanOut(i) == 1 => 
Some(Inc(a, b + c)) case _ => None } } object RemoveInc extends PartialRule[Formula] { - def applyWhere[T](on: ExpressionDag[Formula]) = { + def applyWhere[T](on: Dag[Formula]) = { case Inc(f, by) => Sum(f, Constant(by)) } } @@ -125,18 +125,18 @@ object ExpressionDagTests extends Properties("ExpressionDag") { toLiteral.apply(form).evaluate == form } - property("Going to ExpressionDag round trips") = forAll(genForm) { form => - val (dag, id) = ExpressionDag(form, toLiteral) + property("Going to Dag round trips") = forAll(genForm) { form => + val (dag, id) = Dag(form, toLiteral) dag.evaluate(id) == form } property("CombineInc does not change results") = forAll(genForm) { form => - val simplified = ExpressionDag.applyRule(form, toLiteral, CombineInc) + val simplified = Dag.applyRule(form, toLiteral, CombineInc) form.evaluate == simplified.evaluate } property("RemoveInc removes all Inc") = forAll(genForm) { form => - val noIncForm = ExpressionDag.applyRule(form, toLiteral, RemoveInc) + val noIncForm = Dag.applyRule(form, toLiteral, RemoveInc) def noInc(f: Formula[Int]): Boolean = f match { case Constant(_) => true case Inc(_, _) => false @@ -156,7 +156,7 @@ object ExpressionDagTests extends Properties("ExpressionDag") { def genChain: Gen[Formula[Int]] = Gen.frequency((1, genConst), (3, genChainInc)) property("CombineInc compresses linear Inc chains") = forAll(genChain) { chain => - ExpressionDag.applyRule(chain, toLiteral, CombineInc) match { + Dag.applyRule(chain, toLiteral, CombineInc) match { case Constant(n) => true case Inc(Constant(n), b) => true case _ => false // All others should have been compressed @@ -167,7 +167,7 @@ object ExpressionDagTests extends Properties("ExpressionDag") { * We should be able to totally evaluate these formulas */ object EvaluationRule extends Rule[Formula] { - def apply[T](on: ExpressionDag[Formula]) = { + def apply[T](on: Dag[Formula]) = { case Sum(Constant(a), Constant(b)) => Some(Constant(a + b)) case Product(Constant(a), 
Constant(b)) => Some(Constant(a * b)) case Inc(Constant(a), b) => Some(Constant(a + b)) @@ -197,7 +197,7 @@ object ExpressionDagTests extends Properties("ExpressionDag") { val tails = (n1 :: ns).zipWithIndex.map { case (i, idx) => Formula(i).inc(idx) } val (dag, roots) = - tails.foldLeft((ExpressionDag.empty[Formula](toLiteral), Set.empty[Id[_]])) { + tails.foldLeft((Dag.empty[Formula](toLiteral), Set.empty[Id[_]])) { case ((d, s), f) => val (dnext, id) = d.addRoot(f) (dnext, s + id) diff --git a/core/src/test/scala/com/stripe/dagon/ReadmeTest.scala b/core/src/test/scala/com/stripe/dagon/ReadmeTest.scala index dc8c4be..a601549 100644 --- a/core/src/test/scala/com/stripe/dagon/ReadmeTest.scala +++ b/core/src/test/scala/com/stripe/dagon/ReadmeTest.scala @@ -31,14 +31,14 @@ object Example { }) object SimplifyNegation extends PartialRule[Eqn] { - def applyWhere[T](on: ExpressionDag[Eqn]) = { + def applyWhere[T](on: Dag[Eqn]) = { case Negate(Negate(e)) => e case Negate(Const(x)) => Const(-x) } } object SimplifyAddition extends PartialRule[Eqn] { - def applyWhere[T](on: ExpressionDag[Eqn]) = { + def applyWhere[T](on: Dag[Eqn]) = { case Add(Const(x), Const(y)) => Const(x + y) case Add(Add(e, Const(x)), Const(y)) => Add(e, Const(x + y)) case Add(Add(Const(x), e), Const(y)) => Add(e, Const(x + y)) @@ -54,5 +54,5 @@ object Example { val b2 = a + Const(5) + Var("y") val c = b1 - b2 - ExpressionDag.applyRule(c, toLiteral, rules) + Dag.applyRule(c, toLiteral, rules) } diff --git a/core/src/main/scala/com/stripe/dagon/DependantGraph.scala b/core/src/test/scala/com/stripe/dagon/SimpleDag.scala similarity index 92% rename from core/src/main/scala/com/stripe/dagon/DependantGraph.scala rename to core/src/test/scala/com/stripe/dagon/SimpleDag.scala index fd1b378..bd00148 100644 --- a/core/src/main/scala/com/stripe/dagon/DependantGraph.scala +++ b/core/src/test/scala/com/stripe/dagon/SimpleDag.scala @@ -23,7 +23,7 @@ import Graphs._ * Given Dag and a List of immutable nodes, and a 
function to get * dependencies, compute the dependants (reverse the graph) */ -abstract class DependantGraph[T] { +abstract class SimpleDag[T] { def nodes: List[T] def dependenciesOf(t: T): Iterable[T] @@ -58,9 +58,9 @@ abstract class DependantGraph[T] { def transitiveDependantsOf(p: T): List[T] = depthFirstOf(p)(graph) } -object DependantGraph { - def apply[T](nodes0: List[T])(nfn: T => Iterable[T]): DependantGraph[T] = - new DependantGraph[T] { +object SimpleDag { + def apply[T](nodes0: List[T])(nfn: T => Iterable[T]): SimpleDag[T] = + new SimpleDag[T] { def nodes = nodes0 def dependenciesOf(t: T) = nfn(t) } From b6c967bddb19380b1d00b3d2125ceea2b08a5960 Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Tue, 29 Aug 2017 15:40:13 -0400 Subject: [PATCH 2/2] Respond to in-person review comments. --- core/src/main/scala/com/stripe/dagon/Dag.scala | 4 ++-- core/src/main/scala/com/stripe/dagon/HMap.scala | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/com/stripe/dagon/Dag.scala b/core/src/main/scala/com/stripe/dagon/Dag.scala index 730aeb8..620232a 100644 --- a/core/src/main/scala/com/stripe/dagon/Dag.scala +++ b/core/src/main/scala/com/stripe/dagon/Dag.scala @@ -75,8 +75,8 @@ sealed abstract class Dag[N[_]] { self => // needed. 
private def gc: Dag[N] = { val keepers = reachableIds - val kept = idToExp.filter { case (id, _) => keepers(id) } - if (idToExp.size == kept.size) this else copy(id2Exp = kept) + if (idToExp.forallKeys(keepers)) this + else copy(id2Exp = idToExp.filterKeys(keepers)) } /** diff --git a/core/src/main/scala/com/stripe/dagon/HMap.scala b/core/src/main/scala/com/stripe/dagon/HMap.scala index 094b447..7511a0a 100644 --- a/core/src/main/scala/com/stripe/dagon/HMap.scala +++ b/core/src/main/scala/com/stripe/dagon/HMap.scala @@ -60,14 +60,11 @@ final class HMap[K[_], V[_]](protected val map: Map[K[_], V[_]]) { def size: Int = map.size - def exists(p: ((K[_], V[_])) => Boolean): Boolean = - map.exists(p) - - def forall(p: ((K[_], V[_])) => Boolean): Boolean = - map.forall(p) - - def filter(p: ((K[_], V[_])) => Boolean): HMap[K, V] = - HMap.from[K, V](map.filter(p)) + def forallKeys(p: K[_] => Boolean): Boolean = + map.forall { case (k, _) => p(k) } + + def filterKeys(p: K[_] => Boolean): HMap[K, V] = + HMap.from[K, V](map.filter { case (k, _) => p(k) }) def keysOf[T](v: V[T]): Set[K[T]] = map.collect {