Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent Degenerate Data Source Implementations #453

Merged
merged 1 commit into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/creating-data-sources.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Let's consider `getUserNameById` from the previous example.
We need to define a corresponding request type that extends `Request` for a given response type:

```scala mdoc:silent
case class GetUserName(id: Int) extends Request[Nothing, String]
case class GetUserName(id: Int) extends Request[Throwable, String]
```

Now let's build the corresponding `DataSource`. We will create a `Batched` data source that executes requests that can be performed in parallel in batches but does not further optimize batches of requests that must be performed sequentially. We need to implement the following functions:
Expand All @@ -47,13 +47,13 @@ def run(requests: Chunk[GetUserName]): ZIO[Any, Nothing, CompletedRequestMap] =
case request :: Nil =>
// get user by ID e.g. SELECT name FROM users WHERE id = $id
val result: Task[String] = ???
result.exit.map(resultMap.insert(request))
result.exit.map(resultMap.insert(request, _))
case batch =>
// get multiple users at once e.g. SELECT id, name FROM users WHERE id IN ($ids)
val result: Task[List[(Int, String)]] = ???
result.fold(
err => requests.foldLeft(resultMap) { case (map, req) => map.insert(req)(Exit.fail(err)) },
_.foldLeft(resultMap) { case (map, (id, name)) => map.insert(GetUserName(id))(Exit.succeed(name)) }
err => requests.foldLeft(resultMap) { case (map, req) => map.insert(req, Exit.fail(err)) },
_.foldLeft(resultMap) { case (map, (id, name)) => map.insert(GetUserName(id), Exit.succeed(name)) }
)
}
}
Expand All @@ -62,7 +62,7 @@ def run(requests: Chunk[GetUserName]): ZIO[Any, Nothing, CompletedRequestMap] =
Now to build a `ZQuery`, we can use `ZQuery.fromRequest` and just pass the request and the data source:

```scala mdoc:silent
def getUserNameById(id: Int): ZQuery[Any, Nothing, String] =
def getUserNameById(id: Int): ZQuery[Any, Throwable, String] =
ZQuery.fromRequest(GetUserName(id))(UserDataSource)
```

Expand Down
12 changes: 6 additions & 6 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ import zio._
import zio.query._

object ZQueryExample extends ZIOAppDefault {
case class GetUserName(id: Int) extends Request[Nothing, String]
case class GetUserName(id: Int) extends Request[Throwable, String]

lazy val UserDataSource: DataSource.Batched[Any, GetUserName] =
new DataSource.Batched[Any, GetUserName] {
Expand All @@ -98,7 +98,7 @@ object ZQueryExample extends ZIOAppDefault {
ZIO.succeed(???)
}

result.exit.map(resultMap.insert(request))
result.exit.map(resultMap.insert(request, _))

case batch: Seq[GetUserName] =>
val result: Task[List[(Int, String)]] = {
Expand All @@ -109,21 +109,21 @@ object ZQueryExample extends ZIOAppDefault {
result.fold(
err =>
requests.foldLeft(resultMap) { case (map, req) =>
map.insert(req)(Exit.fail(err))
map.insert(req, Exit.fail(err))
},
_.foldLeft(resultMap) { case (map, (id, name)) =>
map.insert(GetUserName(id))(Exit.succeed(name))
map.insert(GetUserName(id), Exit.succeed(name))
}
)
}
}

}

def getUserNameById(id: Int): ZQuery[Any, Nothing, String] =
def getUserNameById(id: Int): ZQuery[Any, Throwable, String] =
ZQuery.fromRequest(GetUserName(id))(UserDataSource)

val query: ZQuery[Any, Nothing, List[String]] =
val query: ZQuery[Any, Throwable, List[String]] =
for {
ids <- ZQuery.succeed(1 to 10)
names <- ZQuery.foreachPar(ids)(id => getUserNameById(id)).map(_.toList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ final class CompletedRequestMap private (private val map: Map[Any, Exit[Any, Any
/**
* Appends the specified result to the completed requests map.
*/
def insert[E, A](request: Request[E, A])(result: Exit[E, A]): CompletedRequestMap =
def insert[E, A](request: Request[E, A], result: Exit[E, A]): CompletedRequestMap =
new CompletedRequestMap(self.map + (request -> result))

/**
* Appends the specified optional result to the completed request map.
*/
def insertOption[E, A](request: Request[E, A])(result: Exit[E, Option[A]]): CompletedRequestMap =
def insertOption[E, A](request: Request[E, A], result: Exit[E, Option[A]]): CompletedRequestMap =
result match {
case Exit.Failure(e) => insert(request)(Exit.failCause(e))
case Exit.Success(Some(a)) => insert(request)(Exit.succeed(a))
case Exit.Failure(e) => insert(request, Exit.failCause(e))
case Exit.Success(Some(a)) => insert(request, Exit.succeed(a))
case Exit.Success(None) => self
}

Expand All @@ -64,8 +64,8 @@ final class CompletedRequestMap private (private val map: Map[Any, Exit[Any, Any
/**
* Collects all requests in a set.
*/
def requests: Set[Request[Any, Any]] =
map.keySet.asInstanceOf[Set[Request[Any, Any]]]
def requests: Set[Request[_, _]] =
map.keySet.asInstanceOf[Set[Request[_, _]]]

override def toString: String =
s"CompletedRequestMap(${map.mkString(", ")})"
Expand Down
14 changes: 7 additions & 7 deletions zio-query/shared/src/main/scala/zio/query/DataSource.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ object DataSource {
new DataSource.Batched[Any, A] {
val identifier: String = name
def run(requests: Chunk[A])(implicit trace: Trace): ZIO[Any, Nothing, CompletedRequestMap] =
ZIO.succeed(requests.foldLeft(CompletedRequestMap.empty)((map, k) => map.insert(k)(Exit.succeed(f(k)))))
ZIO.succeed(requests.foldLeft(CompletedRequestMap.empty)((map, k) => map.insert(k, Exit.succeed(f(k)))))
}

/**
Expand Down Expand Up @@ -242,7 +242,7 @@ object DataSource {
bs => requests.zip(bs.map(Exit.succeed(_)))
)
.map(_.foldLeft(CompletedRequestMap.empty) { case (map, (k, v)) =>
map.insertOption(k)(v)
map.insertOption(k, v)
})
}

Expand All @@ -258,7 +258,7 @@ object DataSource {
)(f: Chunk[A] => Chunk[B], g: B => Request[Nothing, B])(implicit
ev: A <:< Request[Nothing, B]
): DataSource[Any, A] =
fromFunctionBatchedWithZIO(name)(as => Exit.succeed(f(as)), g)
fromFunctionBatchedWithZIO[Any, Nothing, A, B](name)(as => Exit.succeed(f(as)), g)

/**
* Constructs a data source from an effectual function that takes a list of
Expand All @@ -281,7 +281,7 @@ object DataSource {
bs => bs.map(b => (g(b), Exit.succeed(b)))
)
.map(_.foldLeft(CompletedRequestMap.empty) { case (map, (k, v)) =>
map.insert(k)(v)
map.insert(k, v)
})
}

Expand All @@ -303,7 +303,7 @@ object DataSource {
bs => requests.zip(bs.map(Exit.succeed(_)))
)
.map(_.foldLeft(CompletedRequestMap.empty) { case (map, (k, v)) =>
map.insert(k)(v)
map.insert(k, v)
})
}

Expand All @@ -318,7 +318,7 @@ object DataSource {
def run(requests: Chunk[A])(implicit trace: Trace): ZIO[R, Nothing, CompletedRequestMap] =
ZIO
.foreachPar(requests)(a => f(a).exit.map((a, _)))
.map(_.foldLeft(CompletedRequestMap.empty) { case (map, (k, v)) => map.insert(k)(v) })
.map(_.foldLeft(CompletedRequestMap.empty) { case (map, (k, v)) => map.insert(k, v) })
}

/**
Expand All @@ -343,7 +343,7 @@ object DataSource {
ZIO
.foreachPar(requests)(a => f(a).exit.map((a, _)))
.map(_.foldLeft(CompletedRequestMap.empty) { case (map, (k, v)) =>
map.insertOption(k)(v)
map.insertOption(k, v)
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace
/**
* `QueryFailure` keeps track of details relevant to query failures.
*/
final case class QueryFailure(dataSource: DataSource[Nothing, Nothing], request: Request[Any, Any])
final case class QueryFailure(dataSource: DataSource[Nothing, Nothing], request: Request[_, _])
extends Throwable(null, null, true, false) {
override def getMessage: String =
s"Data source ${dataSource.identifier} did not complete request ${request.toString}."
Expand Down
4 changes: 2 additions & 2 deletions zio-query/shared/src/main/scala/zio/query/Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace
* that may fail with an `E`.
*
* {{{
* sealed trait UserRequest[+A] extends Request[Nothing, A]
* sealed trait UserRequest[A] extends Request[Nothing, A]
*
* case object GetAllIds extends UserRequest[List[Int]]
* final case class GetNameById(id: Int) extends UserRequest[String]
*
* }}}
*/
trait Request[+E, +A]
trait Request[E, A]
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package zio.query.internal

import zio.query.internal.BlockedRequests._
import zio.query.{Cache, CompletedRequestMap, DataSource, DataSourceAspect, Described, QueryFailure, ZQuery}
import zio.query.{Cache, CompletedRequestMap, DataSource, DataSourceAspect, Described, QueryFailure, Request, ZQuery}
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Exit, Promise, Trace, ZEnvironment, ZIO}

Expand Down Expand Up @@ -108,7 +108,7 @@ private[query] sealed trait BlockedRequests[-R] { self =>
completedRequests <- dataSource.runAll(sequential.map(_.map(_.request))).catchAllCause { cause =>
ZIO.succeed {
sequential.map(_.map(_.request)).flatten.foldLeft(CompletedRequestMap.empty) {
case (map, request) => map.insert(request)(Exit.failCause(cause))
case (map, request) => map.insert(request, Exit.failCause(cause))
}
}
}
Expand All @@ -123,7 +123,10 @@ private[query] sealed trait BlockedRequests[-R] { self =>
}
_ <- ZIO.foreachDiscard(leftovers) { request =>
ZIO.foreach(completedRequests.lookup(request)) { response =>
Promise.make[Any, Any].tap(_.done(response)).flatMap(cache.put(request, _))
Promise
.make[Any, Any]
.tap(_.done(response))
.flatMap(cache.put(request.asInstanceOf[Request[Any, Any]], _))
}
}
} yield ()
Expand Down
51 changes: 26 additions & 25 deletions zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -306,22 +306,22 @@ object ZQuerySpec extends ZIOBaseSpec {
val userIds: List[Int] = (1 to 26).toList
val userNames: Map[Int, String] = userIds.zip(('a' to 'z').map(_.toString)).toMap

sealed trait UserRequest[+A] extends Request[Nothing, A]
sealed trait UserRequest[A] extends Request[Nothing, A]

case object GetAllIds extends UserRequest[List[Int]]
final case class GetNameById(id: Int) extends UserRequest[String]

val UserRequestDataSource: DataSource[Any, UserRequest[Any]] =
DataSource.Batched.make[Any, UserRequest[Any]]("UserRequestDataSource") { requests =>
val UserRequestDataSource: DataSource[Any, UserRequest[_]] =
DataSource.Batched.make[Any, UserRequest[_]]("UserRequestDataSource") { requests =>
ZIO.when(requests.toSet.size != requests.size)(ZIO.dieMessage("Duplicate requests)")) *>
Console.printLine(requests.toString).orDie *>
ZIO.succeed {
requests.foldLeft(CompletedRequestMap.empty) {
case (completedRequests, GetAllIds) => completedRequests.insert(GetAllIds)(Exit.succeed(userIds))
case (completedRequests, GetAllIds) => completedRequests.insert(GetAllIds, Exit.succeed(userIds))
case (completedRequests, GetNameById(id)) =>
userNames
.get(id)
.fold(completedRequests)(name => completedRequests.insert(GetNameById(id))(Exit.succeed(name)))
.fold(completedRequests)(name => completedRequests.insert(GetNameById(id), Exit.succeed(name)))
}
}
}
Expand All @@ -340,12 +340,12 @@ object ZQuerySpec extends ZIOBaseSpec {

case object GetFoo extends Request[Nothing, String]
val getFoo: ZQuery[Any, Nothing, String] = ZQuery.fromRequest(GetFoo)(
DataSource.fromFunctionZIO("foo")(_ => Console.printLine("Running foo query") *> ZIO.succeed("foo"))
DataSource.fromFunctionZIO("foo")(_ => Console.printLine("Running foo query").orDie *> ZIO.succeed("foo"))
)

case object GetBar extends Request[Nothing, String]
val getBar: ZQuery[Any, Nothing, String] = ZQuery.fromRequest(GetBar)(
DataSource.fromFunctionZIO("bar")(_ => Console.printLine("Running bar query") *> ZIO.succeed("bar"))
DataSource.fromFunctionZIO("bar")(_ => Console.printLine("Running bar query").orDie *> ZIO.succeed("bar"))
)

case object NeverRequest extends Request[Nothing, Nothing]
Expand Down Expand Up @@ -384,7 +384,7 @@ object ZQuerySpec extends ZIOBaseSpec {
val dieQuery: ZQuery[Any, Nothing, Nothing] =
ZQuery.fromRequest(DieRequest)(dieDataSource)

sealed trait CacheRequest[+A] extends Request[Nothing, A]
sealed trait CacheRequest[A] extends Request[Nothing, A]

final case class Get(key: Int) extends CacheRequest[Option[Int]]
case object GetAll extends CacheRequest[Map[Int, Int]]
Expand All @@ -394,38 +394,39 @@ object ZQuerySpec extends ZIOBaseSpec {

object Cache {

trait Service extends DataSource[Any, CacheRequest[Any]] {
trait Service extends DataSource[Any, CacheRequest[_]] {
val clear: ZIO[Any, Nothing, Unit]
val log: ZIO[Any, Nothing, List[List[Set[CacheRequest[Any]]]]]
val log: ZIO[Any, Nothing, List[List[Set[CacheRequest[_]]]]]
}

val live: ZLayer[Any, Nothing, Cache] =
ZLayer.fromZIO {
for {
cache <- Ref.make(Map.empty[Int, Int])
ref <- Ref.make[List[List[Set[CacheRequest[Any]]]]](Nil)
ref <- Ref.make[List[List[Set[CacheRequest[_]]]]](Nil)
} yield new Service {
val clear: ZIO[Any, Nothing, Unit] =
cache.set(Map.empty) *> ref.set(List.empty)
val log: ZIO[Any, Nothing, List[List[Set[CacheRequest[Any]]]]] =
val log: ZIO[Any, Nothing, List[List[Set[CacheRequest[_]]]]] =
ref.get
val identifier: String =
"CacheDataSource"
def runAll(requests: Chunk[Chunk[CacheRequest[Any]]])(implicit
def runAll(requests: Chunk[Chunk[CacheRequest[_]]])(implicit
trace: Trace
): ZIO[Any, Nothing, CompletedRequestMap] =
ref.update(requests.map(_.toSet).toList :: _) *>
ZIO
.foreach(requests) { requests =>
ZIO
.foreachPar(requests) {
case Get(key) => cache.get.map(_.get(key))
case GetAll => cache.get
case Put(key, value) => cache.update(_ + (key -> value))
case Get(key) =>
cache.get.map(_.get(key)).exit.map(CompletedRequestMap.empty.insert(Get(key), _))
case GetAll =>
cache.get.exit.map(CompletedRequestMap.empty.insert(GetAll, _))
case Put(key, value) =>
cache.update(_ + (key -> value)).exit.map(CompletedRequestMap.empty.insert(Put(key, value), _))
}
.map(requests.zip(_).foldLeft(CompletedRequestMap.empty) { case (map, (k, v)) =>
map.insert(k)(Exit.succeed(v))
})
.map(_.foldLeft(CompletedRequestMap.empty)(_ ++ _))
}
.map(_.foldLeft(CompletedRequestMap.empty)(_ ++ _))
}
Expand All @@ -452,7 +453,7 @@ object ZQuerySpec extends ZIOBaseSpec {
val clear: ZIO[Cache, Nothing, Unit] =
ZIO.serviceWithZIO(_.clear)

val log: ZIO[Cache, Nothing, List[List[Set[CacheRequest[Any]]]]] =
val log: ZIO[Cache, Nothing, List[List[Set[CacheRequest[_]]]]] =
ZIO.serviceWithZIO(_.log)
}

Expand Down Expand Up @@ -509,7 +510,7 @@ object ZQuerySpec extends ZIOBaseSpec {
sealed trait DataSourceErrors
case class NotFound(id: Int) extends DataSourceErrors

sealed trait Req[+A] extends Request[DataSourceErrors, A]
sealed trait Req[A] extends Request[DataSourceErrors, A]
object Req {
case object GetAll extends Req[Map[Int, String]]
final case class Get(id: Int) extends Req[String]
Expand All @@ -528,9 +529,9 @@ object ZQuerySpec extends ZIOBaseSpec {
backendGetAll.map { allItems =>
allItems
.foldLeft(CompletedRequestMap.empty) { case (result, (id, value)) =>
result.insert(Req.Get(id))(Exit.succeed(value))
result.insert(Req.Get(id), Exit.succeed(value))
}
.insert(Req.GetAll)(Exit.succeed(allItems))
.insert(Req.GetAll, Exit.succeed(allItems))
}
} else {
for {
Expand All @@ -542,8 +543,8 @@ object ZQuerySpec extends ZIOBaseSpec {
case (result, Req.GetAll) => result
case (result, req @ Req.Get(id)) =>
items.get(id) match {
case Some(value) => result.insert(req)(Exit.succeed(value))
case None => result.insert(req)(Exit.fail(NotFound(id)))
case Some(value) => result.insert(req, Exit.succeed(value))
case None => result.insert(req, Exit.fail(NotFound(id)))
}
}
}
Expand Down