Skip to content

Commit

Permalink
Optimize FiberRef init during ZQuery#run (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Aug 25, 2024
1 parent c2c08a1 commit 70ea00d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 12 deletions.
8 changes: 8 additions & 0 deletions benchmarks/src/main/scala/zio/query/ZQueryBenchmark.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@ class ZQueryBenchmark {
def zQueryRunSucceedNowBenchmark() =
unsafeRunZIO(ZIO.collectAllDiscard(qs1))

@Benchmark
def zQuerySingleRunSucceedNowBenchmark() =
unsafeRunZIO(qs1.head)

@Benchmark
@OperationsPerInvocation(1000)
def zQueryRunSucceedBenchmark() =
unsafeRunZIO(ZIO.collectAllDiscard(qs2))

@Benchmark
def zQuerySingleRunSucceedBenchmark() =
unsafeRunZIO(qs2.head)
}
46 changes: 34 additions & 12 deletions zio-query/shared/src/main/scala/zio/query/ZQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package zio.query

import zio._
import zio.query.ZQuery.disabledCache
import zio.query.internal._
import zio.stacktracer.TracingImplicits.disableAutoTrace

Expand Down Expand Up @@ -550,25 +551,46 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
def runCache(cache: => Cache)(implicit trace: Trace): ZIO[R, E, A] = {
import ZQuery.{currentCache, currentScope}

def setRef[V](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], newValue: V): V = {
val oldValue = state.getFiberRefOrNull(fiberRef)
state.setFiberRef(fiberRef, newValue)
oldValue
def resetRef[V <: AnyRef](
fid: FiberId.Runtime,
oldRefs: FiberRefs,
newRefs: FiberRefs
)(
fiberRef: FiberRef[V]
): FiberRefs = {
val oldValue = oldRefs.getOrNull(fiberRef)
if (oldValue ne null) newRefs.updatedAs(fid)(fiberRef, oldValue) else newRefs.delete(fiberRef)
}

def resetRef[V <: AnyRef](state: Fiber.Runtime[E, A], fiberRef: FiberRef[V], oldValue: V): Unit =
if (oldValue ne null) state.setFiberRef(fiberRef, oldValue) else state.deleteFiberRef(fiberRef)

asExitOrElse(null) match {
case null =>
ZIO.uninterruptibleMask { restore =>
ZIO.withFiberRuntime[R, E, A] { (state, _) =>
val scope = QueryScope.make()
val oldCache = setRef(state, currentCache, Some(cache))
val oldScope = setRef(state, currentScope, scope)
// NOTE: Running a ZQuery requires up to 3 FiberRefs, which can be expensive to use `locally` with for simple queries.
// Therefore, we handle them all together to avoid the added penalty of running `locally` 3 times
val fid = state.id
val scope = QueryScope.make()
val oldRefs = state.getFiberRefs(false)
val newRefs = {
val refs = oldRefs.updatedAs(fid)(currentCache, Some(cache)).updatedAs(fid)(currentScope, scope)
if (refs.getOrNull(disabledCache) ne null)
refs.delete(disabledCache)
else refs
}
state.setFiberRefs(newRefs)
restore(runToZIO).exitWith { exit =>
resetRef(state, currentCache, oldCache)
resetRef(state, currentScope, oldScope)
val curRefs = state.getFiberRefs(false)
if (curRefs eq newRefs) {
// Cheap and common: FiberRefs were not modified during the execution so we just replace them with the old ones
state.setFiberRefs(oldRefs)
} else {
// FiberRefs were mdified so we need to manually revert each one
var revertedRefs = oldRefs
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentCache)
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentScope)
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(disabledCache)
state.setFiberRefs(revertedRefs)
}
scope.closeAndExitWith(exit)
}
}
Expand Down
10 changes: 10 additions & 0 deletions zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,16 @@ object ZQuerySpec extends ZIOBaseSpec {

q.run.map { case (c1, c2) => assertTrue(c1.isDefined, c1 == c2) }
},
test("disabling caching is reentrant safe") {
val q =
for {
c1 <- ZQuery.fromZIO(ZQuery.currentCache.get)
c2 <- ZQuery.fromZIO(ZQuery.fromZIO(ZQuery.currentCache.get).cached.run).uncached
c3 <- ZQuery.fromZIO(ZQuery.currentCache.get)
} yield (c1, c2, c3)

q.run.map { case (c1, c2, c3) => assertTrue(c1.isDefined, c2.isDefined, c1 == c3, c1 != c2) }
},
test("scope is reentrant safe") {
val q =
for {
Expand Down

0 comments on commit 70ea00d

Please sign in to comment.