Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace SemispaceCache with LRU Cache #1163

Merged
merged 8 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions modules/core/shared/src/main/scala/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ object Session {
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Resource[F, Resource[F, Session[F]]] = {
Expand Down Expand Up @@ -470,9 +470,9 @@ object Session {
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Resource[F, Tracer[F] => Resource[F, Session[F]]] = {
Expand Down Expand Up @@ -508,9 +508,9 @@ object Session {
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Resource[F, Session[F]] =
Expand All @@ -532,9 +532,9 @@ object Session {
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
commandCache: Int = 2048,
queryCache: Int = 2048,
parseCache: Int = 2048,
readTimeout: Duration = Duration.Inf,
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
): Tracer[F] => Resource[F, Session[F]] =
Expand Down
96 changes: 96 additions & 0 deletions modules/core/shared/src/main/scala/data/Cache.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data

/**
 * Immutable cache with least-recently-used eviction.
 *
 * Values live in the `entries` map. Every key in `entries` also appears in
 * `usages`, a bidirectional stamp <-> key mapping ordered by stamp, so the two
 * structures always have the same size (asserted on construction).
 *
 * Each successful `get` and each `put` that stores an entry binds the key to
 * the current `stamp` and hands `stamp + 1` to the resulting cache. Because
 * stamps only ever increase, the head of `usages` is the least recently used key.
 */
sealed abstract case class Cache[K, V] private (
  max: Int,
  entries: Map[K, V]
)(usages: SortedBiMap[Long, K],
  stamp: Long
) {
  assert(entries.size == usages.size)

  /** Number of entries currently cached. */
  def size: Int = entries.size

  /** Whether a value is cached for `k`; does not count as a use of the entry. */
  def contains(k: K): Boolean = entries.contains(k)

  /**
   * Looks up the value cached for `k`.
   *
   * A hit counts as a use, so the result carries a new cache in which `k` is
   * the most recently used key alongside the value. A miss yields `None`.
   */
  def get(k: K): Option[(Cache[K, V], V)] =
    entries.get(k).map { hit =>
      // re-stamping k drops its previous stamp from usages (SortedBiMap is one-to-one)
      val touched = Cache(max, entries, usages + (stamp -> k), stamp + 1)
      (touched, hit)
    }

  /**
   * Stores `v` under `k`, returning the resulting cache together with the
   * pairing that was evicted, if any.
   *
   * A pairing is evicted in these situations:
   *  - `max` is not positive, in which case the new pairing itself is
   *    returned as evicted and the cache is unchanged
   *  - `k` is a new key and the cache has reached its max size, in which
   *    case the least recently used pairing is evicted
   *  - `k` is already cached with a different value, in which case the old
   *    pairing is returned
   *
   * Note: if the cache contains (k, v), calling `put(k, v)` does NOT result
   * in an eviction, but calling `put(k, v2)` where `v != v2` does.
   */
  def put(k: K, v: V): (Cache[K, V], Option[(K, V)]) = {
    val nextStamp = stamp + 1
    if (max <= 0) {
      // nothing can be stored, so the incoming pairing is evicted right away
      (this, Some(k -> v))
    } else {
      entries.get(k) match {
        case None if entries.size >= max =>
          // full and k is new: make room by dropping the least recently used key
          val (oldestStamp, oldestKey) = usages.head
          val grown = Cache(
            max,
            entries - oldestKey + (k -> v),
            usages - oldestStamp + (stamp -> k),
            nextStamp
          )
          (grown, Some(oldestKey -> entries(oldestKey)))
        case previous =>
          // size stays within max; only a replaced, different value is reported
          val updated = Cache(max, entries + (k -> v), usages + (stamp -> k), nextStamp)
          (updated, previous.filter(_ != v).map(k -> _))
      }
    }
  }

  /** The cached values, in the iteration order of the underlying `entries` map. */
  def values: Iterable[V] = entries.values

  /** Renders the pairings from least to most recently used. */
  override def toString: String =
    usages.entries.valuesIterator
      .map(key => s"$key -> ${entries(key)}")
      .mkString("Cache(", ", ", ")")
}

object Cache {
  // The only way to instantiate the sealed abstract case class; kept private so
  // that callers can obtain a cache only through `empty`.
  private def apply[K, V](max: Int, entries: Map[K, V], usages: SortedBiMap[Long, K], stamp: Long): Cache[K, V] =
    new Cache[K, V](max, entries)(usages, stamp) {}

  /**
   * Creates a cache with no entries and the stamp counter at zero.
   *
   * @param max maximum number of entries; a negative value is clamped to 0
   */
  def empty[K, V](max: Int): Cache[K, V] = {
    val capacity = math.max(max, 0)
    apply(capacity, Map.empty[K, V], SortedBiMap.empty[Long, K], 0L)
  }
}


83 changes: 0 additions & 83 deletions modules/core/shared/src/main/scala/data/SemispaceCache.scala

This file was deleted.

48 changes: 48 additions & 0 deletions modules/core/shared/src/main/scala/data/SortedBiMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data

import scala.collection.immutable.SortedMap
import scala.math.Ordering

/**
 * Immutable one-to-one mapping between keys and values, ordered by key.
 *
 * `entries` holds the sorted key -> value direction and `inverse` holds the
 * matching value -> key direction; both always contain the same number of
 * pairings (asserted on construction).
 */
sealed abstract case class SortedBiMap[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]) {
  assert(entries.size == inverse.size)

  /** Number of pairings held. */
  def size: Int = entries.size

  /** The pairing with the smallest key; delegates to `entries.head`, so the map must be non-empty. */
  def head: (K, V) = entries.head

  /** The value paired with `k`, if any. */
  def get(k: K): Option[V] = entries.get(k)

  /**
   * Pairs `k` with `v`, first discarding any pairing that already involves
   * either of them so the mapping stays one-to-one:
   *  - if `v` was paired with another key, that key is removed
   *  - if `k` was paired with another value, that value is removed
   */
  def put(k: K, v: V): SortedBiMap[K, V] = {
    val forward = inverse.get(v) match {
      case Some(staleKey) => entries - staleKey
      case None           => entries
    }
    val backward = entries.get(k) match {
      case Some(staleValue) => inverse - staleValue
      case None             => inverse
    }
    SortedBiMap(forward + (k -> v), backward + (v -> k))
  }

  /** Operator alias for [[put]]. */
  def +(kv: (K, V)): SortedBiMap[K, V] = {
    val (key, value) = kv
    put(key, value)
  }

  /** Removes the pairing for `k`; returns this same map when `k` is absent. */
  def -(k: K): SortedBiMap[K, V] =
    get(k).fold[SortedBiMap[K, V]](this)(paired => SortedBiMap(entries - k, inverse - paired))

  /** The key paired with `v`, if any. */
  def inverseGet(v: V): Option[K] = inverse.get(v)

  /** Renders the pairings in ascending key order. */
  override def toString: String =
    entries.iterator
      .map(pair => s"${pair._1} <-> ${pair._2}")
      .mkString("SortedBiMap(", ", ", ")")
}

object SortedBiMap {
  // Sole construction path for the sealed abstract case class; private so that
  // outside code can only build maps via `empty` followed by `put` / `+` / `-`.
  private def apply[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]): SortedBiMap[K, V] =
    new SortedBiMap[K, V](entries, inverse) {}

  /** A map with no pairings. */
  def empty[K: Ordering, V]: SortedBiMap[K, V] =
    apply(SortedMap.empty[K, V], Map.empty[V, K])
}

35 changes: 23 additions & 12 deletions modules/core/shared/src/main/scala/util/StatementCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import cats.{ Functor, ~> }
import cats.syntax.all._
import skunk.Statement
import cats.effect.kernel.Ref
import skunk.data.SemispaceCache
import skunk.data.Cache

/** An LRU (by access) cache, keyed by statement `CacheKey`. */
sealed trait StatementCache[F[_], V] { outer =>
Expand All @@ -35,31 +35,42 @@ sealed trait StatementCache[F[_], V] { outer =>
object StatementCache {

def empty[F[_]: Functor: Ref.Make, V](max: Int, trackEviction: Boolean): F[StatementCache[F, V]] =
Ref[F].of(SemispaceCache.empty[Statement.CacheKey, V](max, trackEviction)).map { ref =>
// State is the cache and a set of evicted values; the evicted set only grows when trackEviction is true
Ref[F].of((Cache.empty[Statement.CacheKey, V](max), Set.empty[V])).map { ref =>
new StatementCache[F, V] {

def get(k: Statement[_]): F[Option[V]] =
ref.modify { c =>
c.lookup(k.cacheKey) match {
case Some((cʹ, v)) => (cʹ, Some(v))
case None => (c, None)
ref.modify { case (c, evicted) =>
c.get(k.cacheKey) match {
case Some((cʹ, v)) => (cʹ -> evicted, Some(v))
case None => (c -> evicted, None)
}
}

def put(k: Statement[_], v: V): F[Unit] =
ref.update(_.insert(k.cacheKey, v))
ref.update { case (c, evicted) =>
val (c2, e) = c.put(k.cacheKey, v)
// Remove the value we just inserted from the evicted set and add the newly evicted value, if any
val evicted2 = e.filter(_ => trackEviction).fold(evicted - v) { case (_, ev) => evicted - v + ev }
(c2, evicted2)
}

def containsKey(k: Statement[_]): F[Boolean] =
ref.get.map(_.containsKey(k.cacheKey))
ref.get.map(_._1.contains(k.cacheKey))

def clear: F[Unit] =
ref.update(_.evictAll)
ref.update { case (c, evicted) =>
val evicted2 = if (trackEviction) evicted ++ c.values else evicted
(Cache.empty[Statement.CacheKey, V](max), evicted2)
}

def values: F[List[V]] =
ref.get.map(_.values)
ref.get.map(_._1.values.toList)

def clearEvicted: F[List[V]] =
ref.modify(_.clearEvicted)
def clearEvicted: F[List[V]] =
ref.modify { case (c, evicted) =>
(c, Set.empty[V]) -> evicted.toList
}
}
}
}
26 changes: 14 additions & 12 deletions modules/tests/shared/src/test/scala/PrepareCacheTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import skunk.implicits._
import skunk.codec.numeric.int8
import skunk.codec.text
import skunk.codec.boolean
import cats.syntax.all.*
import cats.syntax.all._

class PrepareCacheTest extends SkunkTest {

Expand All @@ -17,16 +17,8 @@ class PrepareCacheTest extends SkunkTest {
private val pgStatementsCountByStatement = sql"select count(*) from pg_prepared_statements where statement = ${text.text}".query(int8)
private val pgStatementsCount = sql"select count(*) from pg_prepared_statements".query(int8)
private val pgStatements = sql"select statement from pg_prepared_statements order by prepare_time".query(text.text)

pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p =>
List.fill(4)(
p.use { s =>
s.execute(pgStatementsByName)("foo").void >> s.execute(pgStatementsByStatement)("bar").void >> s.execute(pgStatementsCountByStatement)("baz").void
}
).sequence
}

pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 1) { p =>

pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p =>
p.use { s =>
s.execute(pgStatementsByName)("foo").void >>
s.execute(pgStatementsByStatement)("bar").void >>
Expand All @@ -49,7 +41,7 @@ class PrepareCacheTest extends SkunkTest {
}
}

pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 1) { p =>
pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 2) { p =>
p.use { s =>
// creates entry in cache
s.prepare(pgStatementsByName)
Expand Down Expand Up @@ -97,4 +89,14 @@ class PrepareCacheTest extends SkunkTest {
}
}
}

pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 4) { p =>
List.fill(8)(
p.use { s =>
s.execute(pgStatementsByName)("foo").void >>
s.execute(pgStatementsByStatement)("bar").void >>
s.execute(pgStatementsCountByStatement)("baz").void
}
).sequence
}
}
Loading
Loading