diff --git a/build.sbt b/build.sbt index b5ce6bfe..822eba7d 100644 --- a/build.sbt +++ b/build.sbt @@ -29,7 +29,6 @@ lazy val buildSettings = Seq( lazy val commonSettings = Seq( resolvers += Resolver.sonatypeRepo("releases"), libraryDependencies ++= Seq( - "org.typelevel" %%% "cats-core" % "0.8.1", "org.typelevel" %%% "cats-free" % "0.8.1", "org.scalatest" %%% "scalatest" % "3.0.0" % "test", compilerPlugin( diff --git a/shared/src/main/scala/fetch.scala b/shared/src/main/scala/fetch.scala index fad19f99..2819b1cc 100644 --- a/shared/src/main/scala/fetch.scala +++ b/shared/src/main/scala/fetch.scala @@ -61,11 +61,14 @@ object Query { } } -/** Requests in Fetch Free monad. - */ -sealed trait FetchRequest extends Product with Serializable { - def fullfilledBy(cache: DataSourceCache): Boolean -} +trait FetchException extends Throwable with Product with Serializable +case class NotFound(env: Env, request: FetchOne[_, _]) extends FetchException +case class MissingIdentities(env: Env, missing: Map[DataSourceName, List[Any]]) + extends FetchException +case class UnhandledException(err: Throwable) extends FetchException + +/** Requests in Fetch Free monad. */ +sealed trait FetchRequest extends Product with Serializable sealed trait FetchQuery[I, A] extends FetchRequest { def missingIdentities(cache: DataSourceCache): List[I] @@ -73,24 +76,14 @@ sealed trait FetchQuery[I, A] extends FetchRequest { def identities: NonEmptyList[I] } -trait FetchException extends Throwable with Product with Serializable -case class NotFound(env: Env, request: FetchOne[_, _]) extends FetchException -case class MissingIdentities(env: Env, missing: Map[DataSourceName, List[Any]]) - extends FetchException -case class UnhandledException(err: Throwable) extends FetchException - /** * Primitive operations in the Fetch Free monad. */ sealed abstract class FetchOp[A] extends Product with Serializable -final case class Fetched[A](a: A) extends FetchOp[A] final case class FetchOne[I, A](id: I, ds: DataSource[I, A]) extends FetchOp[A] with FetchQuery[I, A] { - override def fullfilledBy(cache: DataSourceCache): Boolean = { - cache.contains(ds.identity(id)) - } override def missingIdentities(cache: DataSourceCache): List[I] = { cache.get[A](ds.identity(id)).fold(List(id))(_ => Nil) } @@ -101,9 +94,6 @@ final case class FetchOne[I, A](id: I, ds: DataSource[I, A]) final case class FetchMany[I, A](ids: NonEmptyList[I], ds: DataSource[I, A]) extends FetchOp[List[A]] with FetchQuery[I, A] { - override def fullfilledBy(cache: DataSourceCache): Boolean = { - ids.forall(i => cache.contains(ds.identity(i))) - } override def missingIdentities(cache: DataSourceCache): List[I] = { ids.toList.distinct.filterNot(i => cache.contains(ds.identity(i))) @@ -113,12 +103,10 @@ final case class FetchMany[I, A](ids: NonEmptyList[I], ds: DataSource[I, A]) } final case class Concurrent(queries: NonEmptyList[FetchQuery[Any, Any]]) extends FetchOp[InMemoryCache] - with FetchRequest { - override def fullfilledBy(cache: DataSourceCache): Boolean = { - queries.forall(_.fullfilledBy(cache)) - } -} -final case class Thrown[A](err: Throwable) extends FetchOp[A] + with FetchRequest + +final case class Join[A, B](fl: Fetch[A], fr: Fetch[B]) extends FetchOp[(A, B)] +final case class Thrown[A](err: Throwable) extends FetchOp[A] object `package` { type DataSourceName = String @@ -139,6 +127,9 @@ object `package` { override def product[A, B](fa: Fetch[A], fb: Fetch[B]): Fetch[(A, B)] = Fetch.join(fa, fb) + override def map2[A, B, Z](fa: Fetch[A], fb: Fetch[B])(f: (A, B) => Z): Fetch[Z] = + Fetch.join(fa, fb).map { case (a, b) => f(a, b) } + override def tuple2[A, B](fa: Fetch[A], fb: Fetch[B]): Fetch[(A, B)] = Fetch.join(fa, fb) } @@ -174,7 +165,7 @@ object `package` { Free.liftF(FetchMany(NonEmptyList(i, is.toList), DS)) /** - * Given a non empty list of `FetchRequest`s, lift it to the `Fetch` monad. When executing + * Given a non empty list of `FetchQuery`s, lift it to the `Fetch` monad. When executing * the fetch, data sources will be queried and the fetch will return an `InMemoryCache` * containing the results. */ @@ -205,109 +196,8 @@ object `package` { * Join two fetches from any data sources and return a Fetch that returns a tuple with the two * results. It implies concurrent execution of fetches. */ - def join[A, B](fl: Fetch[A], fr: Fetch[B]): Fetch[(A, B)] = { - def parallelizableQueries(fa: Fetch[_], fb: Fetch[_]): List[FetchQuery[_, _]] = - combineQueries(independentQueries(fa) ++ independentQueries(fb)) - - def parallelizableQueriesAny(fa: Fetch[_], fb: Fetch[_]): List[FetchQuery[Any, Any]] = - parallelizableQueries(fa, fb).asInstanceOf[List[FetchQuery[Any, Any]]] - - def joinWithQueries( - fl: Fetch[A], - fr: Fetch[B], - queries: List[FetchQuery[Any, Any]] - ): Fetch[(A, B)] = { - queries.toNel.fold(Monad[Fetch].tuple2(fl, fr)) { queriesNel => - concurrently(queriesNel).flatMap { cache => - val sfl = fl.compile(simplify(cache)) - val sfr = fr.compile(simplify(cache)) - - val deps = parallelizableQueriesAny(sfl, sfr) - // joinWithQueries(sfl, sfr, deps diff fetches) - joinWithQueries(sfl, sfr, deps) - } - } - } - - joinWithQueries(fl, fr, parallelizableQueriesAny(fl, fr)) - } - - /** - * Use a `DataSourceCache` to optimize a `FetchOp`. - * If the cache contains all the fetch identities, the fetch doesn't need to be - * executed and can be replaced by cached results. - */ - private[this] def simplify(cache: InMemoryCache): (FetchOp ~> FetchOp) = { - new (FetchOp ~> FetchOp) { - def apply[B](fetchOp: FetchOp[B]): FetchOp[B] = fetchOp match { - case one @ FetchOne(id, ds) => - cache.get[B](ds.identity(id)).fold(fetchOp)(b => Fetched(b)) - case many @ FetchMany(ids, ds) => - val fetched = ids.traverse(id => cache.get(ds.identity(id))) - fetched.fold(fetchOp)(results => Fetched(results.toList)) - case conc @ Concurrent(manies) => - val newManies = manies.toList.filterNot(_.fullfilledBy(cache)) - newManies.toNel.fold[FetchOp[B]](Fetched(cache))(Concurrent(_)) - case other => other - } - } - } - - /** - * Combine multiple queries so the resulting `List` only contains one `FetchQuery` - * per `DataSource`. - */ - private[this] def combineQueries(qs: List[FetchQuery[_, _]]): List[FetchQuery[_, _]] = - qs.foldMap[Map[DataSource[_, _], NonEmptyList[Any]]] { - case FetchOne(id, ds) => Map(ds -> NonEmptyList.of[Any](id)) - case FetchMany(ids, ds) => Map(ds -> ids.widen[Any]) - } - .mapValues { nel => - // workaround because NEL[Any].distinct would need Order[Any] - nel.unsafeListOp(_.distinct) - } - .toList - .map { - case (ds, NonEmptyList(id, Nil)) => FetchOne(id, ds.castDS[Any, Any]) - case (ds, ids) => FetchMany(ids, ds.castDS[Any, Any]) - } - - private[this] type FetchOps = List[FetchOp[_]] - private[this] type KeepFetches[A] = Writer[FetchOps, A] - private[this] type AnalyzeTop[A] = EitherT[KeepFetches, Unit, A] - - private[this] object AnalyzeTop { - def stopWith[R](list: FetchOps): AnalyzeTop[R] = - AnalyzeTop.stop(Writer.tell(list)) - - def stopEmpty[R]: AnalyzeTop[R] = - AnalyzeTop.stop(Writer.value(())) - - def stop[R](k: KeepFetches[Unit]): AnalyzeTop[R] = - EitherT.left[KeepFetches, Unit, R](k) - - def go[X](k: KeepFetches[X]): AnalyzeTop[X] = - EitherT.right[KeepFetches, Unit, X](k) - } - - /** - * Get a list of independent `FetchQuery`s for a given `Fetch`. - */ - private[this] def independentQueries(f: Fetch[_]): List[FetchQuery[_, _]] = { - val analyzeTop: FetchOp ~> AnalyzeTop = new (FetchOp ~> AnalyzeTop) { - def apply[A](op: FetchOp[A]): AnalyzeTop[A] = op match { - case fetc @ Fetched(c) => AnalyzeTop.go(Writer(List(), c)) - case one @ FetchOne(_, _) => AnalyzeTop.stopWith(List(one)) - case conc @ Concurrent(as) => AnalyzeTop.stopWith(as.toList.asInstanceOf[FetchOps]) - case _ => AnalyzeTop.stopEmpty - } - } - - f.foldMap[AnalyzeTop](analyzeTop).value.written.collect { - case one @ FetchOne(_, _) => one - case many @ FetchMany(_, _) => many - } - } + def join[A, B](fl: Fetch[A], fr: Fetch[B]): Fetch[(A, B)] = + Free.liftF(Join(fl, fr)) class FetchRunner[M[_]] { def apply[A]( diff --git a/shared/src/main/scala/freeinspect.scala b/shared/src/main/scala/freeinspect.scala new file mode 100644 index 00000000..2e3cd705 --- /dev/null +++ b/shared/src/main/scala/freeinspect.scala @@ -0,0 +1,61 @@ +/* + * Copyright 2016 47 Degrees, LLC. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.free + +import cats.{Id, ~>} +import cats.data.Coproduct + +import scala.annotation.tailrec + +object FreeTopExt { + + @tailrec + def inspect[F[_]](free: Free[F, _]): Option[F[_]] = free match { + case Free.Pure(p) => None + case Free.Suspend(fa) => Some(fa) + case Free.FlatMapped(free2, _) => inspect(free2) + } + + /** + * Is the first step a `Pure` ? + * + * Could be implemented using `resume` as : + * {{{ + * val liftCoyoneda: F ~> Coyoneda[F, ?] = λ[(F ~> Coyoneda[F, ?])](Coyoneda.lift(_)) + * free.compile[Coyoneda[F, ?]](liftCoyoneda).resume.toOption + * }}} + */ + def inspectPure[F[_], A](free: Free[F, A]): Option[A] = free match { + case Free.Pure(p) => Some(p) + case Free.Suspend(fa) => None + case Free.FlatMapped(free2, _) => None + } + + def modify[F[_], A](free: Free[F, A])(f: F ~> Coproduct[F, Id, ?]): Free[F, A] = + free match { + case pure @ Free.Pure(_) => pure + case Free.Suspend(fa) => f(fa).run.fold(Free.liftF(_), Free.Pure(_)) + case Free.FlatMapped(free2, cont) => Free.FlatMapped(modify(free2)(f), cont).step + } + + def print[F[_], A](free: Free[F, A]): String = free match { + case Free.Pure(p) => s"Pure($p)" + case Free.Suspend(fa) => s"Suspend($fa)" + case Free.FlatMapped(free2, _) => s"FlatMapped(${print(free2)}, Free>)" + } + +} diff --git a/shared/src/main/scala/interpreters.scala b/shared/src/main/scala/interpreters.scala index 31b48faf..a1eb522a 100644 --- a/shared/src/main/scala/interpreters.scala +++ b/shared/src/main/scala/interpreters.scala @@ -18,13 +18,14 @@ package fetch import scala.collection.immutable._ -import cats.{Applicative, ApplicativeError, MonadError, Semigroup, ~>} -import cats.data.{EitherT, OptionT, NonEmptyList, StateT, Validated, ValidatedNel} +import cats.{Applicative, ApplicativeError, Id, Monad, MonadError, Semigroup, ~>} +import cats.data.{Coproduct, EitherT, NonEmptyList, OptionT, StateT, Validated, ValidatedNel} import cats.free.Free import cats.instances.option._ import cats.instances.list._ import cats.instances.map._ import cats.instances.tuple._ +import cats.syntax.cartesian._ import cats.syntax.either._ import cats.syntax.flatMap._ import cats.syntax.foldable._ @@ -37,25 +38,37 @@ import cats.syntax.semigroup._ import cats.syntax.traverse._ import cats.syntax.validated._ +import cats.free.FreeTopExt + trait FetchInterpreters { def interpreter[M[_]: FetchMonadError]: FetchOp ~> FetchInterpreter[M]#f = - maxBatchSizePhase.andThen[FetchInterpreter[M]#f]( - Free.foldMap[FetchOp, FetchInterpreter[M]#f](coreInterpreter[M])) + parallelJoinPhase + .andThen[Fetch](Free.foldMap(maxBatchSizePhase)) + .andThen[FetchInterpreter[M]#f]( + Free.foldMap[FetchOp, FetchInterpreter[M]#f](coreInterpreter[M])) def coreInterpreter[M[_]]( implicit M: FetchMonadError[M] ): FetchOp ~> FetchInterpreter[M]#f = { new (FetchOp ~> FetchInterpreter[M]#f) { def apply[A](fa: FetchOp[A]): FetchInterpreter[M]#f[A] = - StateT[M, FetchEnv, A] { env: FetchEnv => - fa match { - case Thrown(e) => M.raiseError(UnhandledException(e)) - case Fetched(a) => M.pure((env, a)) - case one @ FetchOne(_, _) => processOne(one, env) - case many @ FetchMany(_, _) => processMany(many, env) - case conc @ Concurrent(_) => processConcurrent(conc, env) - } + fa match { + case Join(fl, fr) => + Monad[FetchInterpreter[M]#f].tuple2( + fl.foldMap[FetchInterpreter[M]#f](coreInterpreter[M]), + fr.foldMap[FetchInterpreter[M]#f](coreInterpreter[M])) + + case other => + StateT[M, FetchEnv, A] { env: FetchEnv => + other match { + case Thrown(e) => M.raiseError(UnhandledException(e)) + case one @ FetchOne(_, _) => processOne(one, env) + case many @ FetchMany(_, _) => processMany(many, env) + case conc @ Concurrent(_) => processConcurrent(conc, env) + case Join(_, _) => throw new Exception("join already handled") + } + } } } } @@ -251,7 +264,7 @@ trait FetchInterpreters { private[this] def batchMany[I, A](many: FetchMany[I, A]): Fetch[List[A]] = { val batchedFetches = manyInBatches(many) - batchedFetches.reduceLeftM[Fetch, List[A]](Free.liftF) { + batchedFetches.reduceLeftM[Fetch, List[A]](Free.liftF[FetchOp, List[A]]) { case (results, fetchMany) => Free.liftF(fetchMany).map(results ++ _) } @@ -304,4 +317,88 @@ trait FetchInterpreters { either.fold[EitherT[M, A, B]](error => EitherT.left[M, A, B](M.raiseError(error)), EitherT.pure[M, A, B](_)) } + + val parallelJoinPhase: FetchOp ~> Fetch = + new (FetchOp ~> Fetch) { + def apply[A](op: FetchOp[A]): Fetch[A] = op match { + case join @ Join(fl, fr) => + val fetchJoin = Free.liftF(join) + val indepQueries = combineQueries(independentQueries(fetchJoin)) + parallelJoin(fetchJoin, indepQueries) + case other => Free.liftF(other) + } + } + + private[this] def parallelJoin[A, B]( + fetchJoin: Fetch[(A, B)], + queries: List[FetchQuery[_, _]] + ): Fetch[(A, B)] = { + combineQueries(queries).asInstanceOf[List[FetchQuery[Any, Any]]].toNel.fold(fetchJoin) { + queriesNel => + Free.liftF(Concurrent(queriesNel)).flatMap { cache => + val simplerFetchJoin = simplify(cache)(fetchJoin) + val indepQueries = independentQueries(simplerFetchJoin) + indepQueries.toNel.fold(simplerFetchJoin) { queries => + parallelJoin(simplerFetchJoin, queries.toList) + } + } + } + } + + private[this] def independentQueries(f: Fetch[_]): List[FetchQuery[_, _]] = + // we need the `.step` below to ignore pure values when we search for + // independent queries, but this also has the consequence that pure + // values can be executed multiple times. + // eg : Fetch.pure(5).map { i => println("hello"); i * 2 } + FreeTopExt.inspect(f.step).foldMap { + case Join(ffl, ffr) => independentQueries(ffl) ++ independentQueries(ffr) + case one @ FetchOne(_, _) => one :: Nil + case many @ FetchMany(_, _) => many :: Nil + case _ => Nil + } + + /** + * Use a `DataSourceCache` to optimize a `FetchOp`. + * If the cache contains all the fetch identities, the fetch doesn't need to be + * executed and can be replaced by cached results. + */ + private[this] def simplify[A](cache: InMemoryCache)(fetch: Fetch[A]): Fetch[A] = + FreeTopExt.modify(fetch)( + new (FetchOp ~> Coproduct[FetchOp, Id, ?]) { + def apply[X](fetchOp: FetchOp[X]): Coproduct[FetchOp, Id, X] = fetchOp match { + case one @ FetchOne(id, ds) => + Coproduct[FetchOp, Id, X](cache.get[X](ds.identity(id)).toRight(one)) + case many @ FetchMany(ids, ds) => + val fetched = ids.traverse(id => cache.get(ds.identity(id))) + Coproduct[FetchOp, Id, X](fetched.map(_.toList).toRight(many)) + case join @ Join(fl, fr) => + val sfl = simplify(cache)(fl) + val sfr = simplify(cache)(fr) + val optTuple = (FreeTopExt.inspectPure(sfl) |@| FreeTopExt.inspectPure(sfr)).tupled + Coproduct[FetchOp, Id, X](optTuple.toRight(Join(sfl, sfr))) + case other => + Coproduct.leftc(other) + } + } + ) + + /** + * Combine multiple queries so the resulting `List` only contains one `FetchQuery` + * per `DataSource`. + */ + private[this] def combineQueries(qs: List[FetchQuery[_, _]]): List[FetchQuery[_, _]] = + qs.foldMap[Map[DataSource[_, _], NonEmptyList[Any]]] { + case FetchOne(id, ds) => Map(ds -> NonEmptyList.of[Any](id)) + case FetchMany(ids, ds) => Map(ds -> ids.widen[Any]) + } + .mapValues { nel => + // workaround because NEL[Any].distinct would need Order[Any] + nel.unsafeListOp(_.distinct) + } + .toList + .map { + case (ds, NonEmptyList(id, Nil)) => FetchOne(id, ds.castDS[Any, Any]) + case (ds, ids) => FetchMany(ids, ds.castDS[Any, Any]) + } + } diff --git a/shared/src/test/scala/FetchTests.scala b/shared/src/test/scala/FetchTests.scala index 9a6540a4..16aaa318 100644 --- a/shared/src/test/scala/FetchTests.scala +++ b/shared/src/test/scala/FetchTests.scala @@ -14,6 +14,8 @@ * limitations under the License. */ +package fetch + import scala.concurrent._ import scala.concurrent.duration._ @@ -777,6 +779,23 @@ class FetchTests extends AsyncFreeSpec with Matchers { env.rounds.size shouldEqual 1 } } + + "Pure Fetches should be ignored in the parallel optimization" in { + val fetch: Fetch[(Int, Int)] = Fetch.join( + one(1), + for { + a <- Fetch.pure(2) + b <- one(3) + } yield a + b + ) + + Fetch.runFetch[Future](fetch).map { + case (env, res) => + res shouldEqual (1, 5) + totalFetched(env.rounds) shouldEqual 2 + env.rounds.size shouldEqual 1 + } + } } class FetchReportingTests extends AsyncFreeSpec with Matchers {