Skip to content

Commit

Permalink
Make ZQuery#run reentrant safe (#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Jul 10, 2024
1 parent a345a27 commit 83f28d8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
24 changes: 18 additions & 6 deletions zio-query/shared/src/main/scala/zio/query/ZQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -539,23 +539,35 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
* Returns an effect that models executing this query with the specified
* cache.
*/
def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] =
def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] = {
import ZQuery.{currentCache, currentScope}

def setRef[V](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], newValue: V): V = {
val oldValue = state.getFiberRefOrNull(fiberRef)
state.setFiberRef(fiberRef, newValue)
oldValue
}

def resetRef[V <: AnyRef](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], oldValue: V): Unit =
if (oldValue ne null) state.setFiberRef(fiberRef, oldValue) else state.deleteFiberRef(fiberRef)

asExitOrElse(null) match {
case null =>
ZIO.uninterruptibleMask { restore =>
ZIO.withFiberRuntime[R, E, A] { (state, _) =>
val scope = QueryScope.make()
state.setFiberRef(ZQuery.currentCache, Some(cache))
state.setFiberRef(ZQuery.currentScope, scope)
val scope = QueryScope.make()
val oldCache = setRef(state, currentCache, Some(cache))
val oldScope = setRef(state, currentScope, scope)
restore(runToZIO).exitWith { exit =>
state.deleteFiberRef(ZQuery.currentCache)
state.deleteFiberRef(ZQuery.currentScope)
resetRef(state, currentCache, oldCache)
resetRef(state, currentScope, oldScope)
scope.closeAndExitWith(exit)
}
}
}
case exit => exit
}
}

/**
* Returns an effect that models executing this query, returning the query
Expand Down
25 changes: 24 additions & 1 deletion zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package zio.query

import zio._
import zio.query.QueryAspect._
import zio.query.internal.QueryScope
import zio.test.Assertion._
import zio.test.TestAspect.{after, nonFlaky, silent}
import zio.test.{TestClock, TestConsole, TestEnvironment, _}
Expand Down Expand Up @@ -270,7 +271,7 @@ object ZQuerySpec extends ZIOBaseSpec {
assert(log)(hasAt(0)(containsString("GetNameById(1)"))) &&
assert(log)(hasAt(0)(containsString("GetNameById(2)"))) &&
assert(log)(hasAt(1)(containsString("GetNameById(1)")))
} @@ nonFlaky,
} @@ nonFlaky(10),
suite("race")(
test("race with never") {
val query = ZQuery.never.race(ZQuery.succeed(()))
Expand Down Expand Up @@ -370,6 +371,28 @@ object ZQuerySpec extends ZIOBaseSpec {
value <- ref.get
} yield assertTrue(value == 1, results.forall(_.isLeft))
}
),
suite("run")(
test("cache is reentrant safe") {
val q =
for {
c1 <- ZQuery.fromZIO(ZQuery.currentCache.get)
_ <- ZQuery.fromZIO(ZQuery.succeed("foo").run)
c2 <- ZQuery.fromZIO(ZQuery.currentCache.get)
} yield (c1, c2)

q.run.map { case (c1, c2) => assertTrue(c1.isDefined, c1 == c2) }
},
test("scope is reentrant safe") {
val q =
for {
c1 <- ZQuery.fromZIO(ZQuery.currentScope.get)
_ <- ZQuery.fromZIO(ZQuery.succeed("foo").run)
c2 <- ZQuery.fromZIO(ZQuery.currentScope.get)
} yield (c1, c2)

q.run.map { case (c1, c2) => assertTrue(c1 != QueryScope.NoOp, c1 == c2) }
}
)
) @@ silent

Expand Down

0 comments on commit 83f28d8

Please sign in to comment.