Skip to content

Commit

Permalink
Merge pull request #728 from vbergeron/parse-cache
Browse files Browse the repository at this point in the history
Implement per-session parsed statement cache
  • Loading branch information
mpilquist authored Dec 17, 2022
2 parents c463914 + fcbc8fd commit fd3b400
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 42 deletions.
55 changes: 45 additions & 10 deletions modules/core/shared/src/main/scala/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import skunk.data.TransactionIsolationLevel
import skunk.data.TransactionAccessMode
import skunk.net.protocol.Describe
import scala.concurrent.duration.Duration
import skunk.net.protocol.Parse

/**
* Represents a live connection to a Postgres database. Operations provided here are safe to use
Expand Down Expand Up @@ -132,14 +133,30 @@ trait Session[F[_]] {
* times with different arguments.
* @group Queries
*/
def prepare[A, B](query: Query[A, B]): Resource[F, PreparedQuery[F, A, B]]
def prepare[A, B](query: Query[A, B]): Resource[F, PreparedQuery[F, A, B]] =
Resource.eval(prepareAndCache(query))

/**
* Prepare an `INSERT`, `UPDATE`, or `DELETE` command that returns no rows. The resulting
* `PreparedCommand` can be executed multiple times with different arguments.
* @group Commands
*/
def prepare[A](command: Command[A]): Resource[F, PreparedCommand[F, A]]
def prepare[A](command: Command[A]): Resource[F, PreparedCommand[F, A]] =
Resource.eval(prepareAndCache(command))

/**
* Prepares then caches a query, yielding a `PreparedQuery` which can be executed multiple
* times with different arguments.
* @group Queries
*/
def prepareAndCache[A, B](query: Query[A, B]): F[PreparedQuery[F, A, B]]

/**
* Prepares then caches an `INSERT`, `UPDATE`, or `DELETE` command that returns no rows. The resulting
* `PreparedCommand` can be executed multiple times with different arguments.
* @group Commands
*/
def prepareAndCache[A](command: Command[A]): F[PreparedCommand[F, A]]

/**
* Transform a `Command` into a `Pipe` from inputs to `Completion`s.
Expand Down Expand Up @@ -187,6 +204,13 @@ trait Session[F[_]] {
* the cache through this accessor.
*/
def describeCache: Describe.Cache[F]

/**
* Each session has access to a cache of all statements that have been parsed by the
* `Parse` protocol, which allows us to skip a network round-trip. Users can inspect and clear
* the cache through this accessor.
*/
def parseCache: Parse.Cache[F]

}

Expand Down Expand Up @@ -277,11 +301,14 @@ object Session {
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
readTimeout: Duration = Duration.Inf,
): Resource[F, Resource[F, Session[F]]] = {

def session(socketGroup: SocketGroup[F], sslOp: Option[SSLNegotiation.Options[F]], cache: Describe.Cache[F]): Resource[F, Session[F]] =
fromSocketGroup[F](socketGroup, host, port, user, database, password, debug, strategy, socketOptions, sslOp, parameters, cache, readTimeout)
for {
pc <- Resource.eval(Parse.Cache.empty[F](parseCache))
s <- fromSocketGroup[F](socketGroup, host, port, user, database, password, debug, strategy, socketOptions, sslOp, parameters, cache, pc, readTimeout)
} yield s

val logger: String => F[Unit] = s => Console[F].println(s"TLS: $s")

Expand Down Expand Up @@ -311,6 +338,7 @@ object Session {
parameters: Map[String, String] = Session.DefaultConnectionParameters,
commandCache: Int = 1024,
queryCache: Int = 1024,
parseCache: Int = 1024,
readTimeout: Duration = Duration.Inf,
): Resource[F, Session[F]] =
pooled(
Expand All @@ -326,6 +354,7 @@ object Session {
parameters = parameters,
commandCache = commandCache,
queryCache = queryCache,
parseCache = parseCache,
readTimeout = readTimeout
).flatten

Expand All @@ -342,13 +371,14 @@ object Session {
sslOptions: Option[SSLNegotiation.Options[F]],
parameters: Map[String, String],
describeCache: Describe.Cache[F],
parseCache: Parse.Cache[F],
readTimeout: Duration = Duration.Inf,
): Resource[F, Session[F]] =
for {
namer <- Resource.eval(Namer[F])
proto <- Protocol[F](host, port, debug, namer, socketGroup, socketOptions, sslOptions, describeCache, readTimeout)
proto <- Protocol[F](host, port, debug, namer, socketGroup, socketOptions, sslOptions, describeCache, parseCache, readTimeout)
_ <- Resource.eval(proto.startup(user, database, password, parameters))
sess <- Resource.eval(fromProtocol(proto, namer, strategy))
sess <- Resource.make(fromProtocol(proto, namer, strategy))(_ => proto.cleanup)
} yield sess

/**
Expand Down Expand Up @@ -408,10 +438,10 @@ object Session {
case _ => ev.raiseError(new RuntimeException("Expected at most one row, more returned."))
}

override def prepare[A, B](query: Query[A, B]): Resource[F, PreparedQuery[F, A, B]] =
override def prepareAndCache[A, B](query: Query[A, B]): F[PreparedQuery[F, A, B]] =
proto.prepare(query, typer).map(PreparedQuery.fromProto(_))

override def prepare[A](command: Command[A]): Resource[F, PreparedCommand[F, A]] =
override def prepareAndCache[A](command: Command[A]): F[PreparedCommand[F, A]] =
proto.prepare(command, typer).map(PreparedCommand.fromProto(_))

override def transaction[A]: Resource[F, Transaction[F]] =
Expand All @@ -423,6 +453,9 @@ object Session {
override def describeCache: Describe.Cache[F] =
proto.describeCache

override def parseCache: Parse.Cache[F] =
proto.parseCache

}
}
}
Expand Down Expand Up @@ -465,9 +498,9 @@ object Session {

override def parameters: Signal[G,Map[String,String]] = outer.parameters.mapK(fk)

override def prepare[A, B](query: Query[A,B]): Resource[G,PreparedQuery[G,A,B]] = outer.prepare(query).mapK(fk).map(_.mapK(fk))
override def prepareAndCache[A, B](query: Query[A,B]): G[PreparedQuery[G,A,B]] = fk(outer.prepareAndCache(query)).map(_.mapK(fk))

override def prepare[A](command: Command[A]): Resource[G,PreparedCommand[G,A]] = outer.prepare(command).mapK(fk).map(_.mapK(fk))
override def prepareAndCache[A](command: Command[A]): G[PreparedCommand[G,A]] = fk(outer.prepareAndCache(command)).map(_.mapK(fk))

override def transaction[A]: Resource[G,Transaction[G]] = outer.transaction[A].mapK(fk).map(_.mapK(fk))

Expand All @@ -480,6 +513,8 @@ object Session {

override def describeCache: Describe.Cache[G] = outer.describeCache.mapK(fk)

override def parseCache: Parse.Cache[G] = outer.parseCache.mapK(fk)

}
}

Expand Down
4 changes: 3 additions & 1 deletion modules/core/shared/src/main/scala/data/SemispaceCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ sealed abstract case class SemispaceCache[K, V](gen0: Map[K, V], gen1: Map[K, V]
def containsKey(k: K): Boolean =
gen0.contains(k) || gen1.contains(k)

def values: List[V] =
(gen0.values.toSet | gen1.values.toSet).toList
}

object SemispaceCache {
Expand All @@ -38,4 +40,4 @@ object SemispaceCache {
def empty[K, V](max: Int): SemispaceCache[K, V] =
SemispaceCache[K, V](Map.empty, Map.empty, max max 0)

}
}
36 changes: 26 additions & 10 deletions modules/core/shared/src/main/scala/net/Protocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ import skunk.util.{ Namer, Origin }
import skunk.util.Typer
import natchez.Trace
import fs2.io.net.{ SocketGroup, SocketOption }
import skunk.net.protocol.Exchange
import skunk.net.protocol.Describe
import scala.concurrent.duration.Duration
import skunk.net.protocol.Exchange
import skunk.net.protocol.Parse

/**
* Interface for a Postgres database, expressed through high-level operations that rely on exchange
Expand Down Expand Up @@ -62,15 +63,15 @@ trait Protocol[F[_]] {

/**
* Prepare a command (a statement that produces no rows), yielding a `Protocol.PreparedCommand`
* which will be closed after use.
* which will be cached per session and closed on session close.
*/
def prepare[A](command: Command[A], ty: Typer): Resource[F, Protocol.PreparedCommand[F, A]]
def prepare[A](command: Command[A], ty: Typer): F[Protocol.PreparedCommand[F, A]]

/**
* Prepare a query (a statement that produces rows), yielding a `Protocol.PreparedCommand` which
* which will be closed after use.
* which will be cached per session and closed on session close.
*/
def prepare[A, B](query: Query[A, B], ty: Typer): Resource[F, Protocol.PreparedQuery[F, A, B]]
def prepare[A, B](query: Query[A, B], ty: Typer): F[Protocol.PreparedQuery[F, A, B]]

/**
* Execute a non-parameterized command (a statement that produces no rows), yielding a
Expand All @@ -92,6 +93,11 @@ trait Protocol[F[_]] {
*/
def startup(user: String, database: String, password: Option[String], parameters: Map[String, String]): F[Unit]

/**
* Cleanup the session. This will close any cached prepared statements.
*/
def cleanup: F[Unit]

/**
* Signal representing the current transaction status as reported by `ReadyForQuery`. It's not
* clear that this is a useful thing to expose.
Expand All @@ -101,6 +107,8 @@ trait Protocol[F[_]] {
/** Cache for the `Describe` protocol. */
def describeCache: Describe.Cache[F]

/** Cache for the `Parse` protocol. */
def parseCache: Parse.Cache[F]
}

object Protocol {
Expand Down Expand Up @@ -197,17 +205,19 @@ object Protocol {
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
describeCache: Describe.Cache[F],
parseCache: Parse.Cache[F],
readTimeout: Duration
): Resource[F, Protocol[F]] =
for {
bms <- BufferedMessageSocket[F](host, port, 256, debug, sg, socketOptions, sslOptions, readTimeout) // TODO: should we expose the queue size?
p <- Resource.eval(fromMessageSocket(bms, nam, describeCache))
p <- Resource.eval(fromMessageSocket(bms, nam, describeCache, parseCache))
} yield p

def fromMessageSocket[F[_]: Concurrent: Trace](
bms: BufferedMessageSocket[F],
nam: Namer[F],
dc: Describe.Cache[F],
pc: Parse.Cache[F]
): F[Protocol[F]] =
Exchange[F].map { ex =>
new Protocol[F] {
Expand All @@ -224,11 +234,11 @@ object Protocol {
override def parameters: Signal[F, Map[String, String]] =
bms.parameters

override def prepare[A](command: Command[A], ty: Typer): Resource[F, PreparedCommand[F, A]] =
protocol.Prepare[F](describeCache).apply(command, ty)
override def prepare[A](command: Command[A], ty: Typer): F[PreparedCommand[F, A]] =
protocol.Prepare[F](describeCache, parseCache).apply(command, ty)

override def prepare[A, B](query: Query[A, B], ty: Typer): Resource[F, PreparedQuery[F, A, B]] =
protocol.Prepare[F](describeCache).apply(query, ty)
override def prepare[A, B](query: Query[A, B], ty: Typer): F[PreparedQuery[F, A, B]] =
protocol.Prepare[F](describeCache, parseCache).apply(query, ty)

override def execute(command: Command[Void]): F[Completion] =
protocol.Query[F].apply(command)
Expand All @@ -239,12 +249,18 @@ object Protocol {
override def startup(user: String, database: String, password: Option[String], parameters: Map[String, String]): F[Unit] =
protocol.Startup[F].apply(user, database, password, parameters)

override def cleanup: F[Unit] =
parseCache.value.values.flatMap(xs => xs.traverse_(protocol.Close[F].apply))

override def transactionStatus: Signal[F, TransactionStatus] =
bms.transactionStatus

override val describeCache: Describe.Cache[F] =
dc

override val parseCache: Parse.Cache[F] =
pc

}
}

Expand Down
32 changes: 22 additions & 10 deletions modules/core/shared/src/main/scala/net/protocol/Parse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

package skunk.net.protocol

import cats.effect.Resource
import cats._
import cats.effect.Ref
import cats.syntax.all._
import cats.MonadError
import skunk.util.StatementCache
import skunk.exception.PostgresErrorException
import skunk.net.message.{ Parse => ParseMessage, Close => _, _ }
import skunk.net.MessageSocket
Expand All @@ -16,26 +17,26 @@ import skunk.util.Namer
import skunk.util.Typer
import skunk.exception.{ UnknownTypeException, TooManyParametersException }
import natchez.Trace
import cats.data.OptionT

trait Parse[F[_]] {
def apply[A](statement: Statement[A], ty: Typer): Resource[F, StatementId]
def apply[A](statement: Statement[A], ty: Typer): F[StatementId]
}

object Parse {

def apply[F[_]: Exchange: MessageSocket: Namer: Trace](
def apply[F[_]: Exchange: MessageSocket: Namer: Trace](cache: Cache[F])(
implicit ev: MonadError[F, Throwable]
): Parse[F] =
new Parse[F] {

override def apply[A](statement: Statement[A], ty: Typer): Resource[F, StatementId] =
override def apply[A](statement: Statement[A], ty: Typer): F[StatementId] =
statement.encoder.oids(ty) match {

case Right(os) if os.length > Short.MaxValue =>
Resource.eval(TooManyParametersException(statement).raiseError[F, StatementId])
TooManyParametersException(statement).raiseError[F, StatementId]

case Right(os) =>
Resource.make {
OptionT(cache.value.get(statement)).getOrElseF {
exchange("parse") {
for {
id <- nextName("statement").map(StatementId(_))
Expand All @@ -52,10 +53,10 @@ object Parse {
}
} yield id
}
} { Close[F].apply }
}

case Left(err) =>
Resource.eval(UnknownTypeException(statement, err, ty.strategy).raiseError[F, StatementId])
UnknownTypeException(statement, err, ty.strategy).raiseError[F, StatementId]

}

Expand All @@ -74,4 +75,15 @@ object Parse {

}

/** A cache for the `Parse` protocol. */
final case class Cache[F[_]](value: StatementCache[F, StatementId]) {
def mapK[G[_]](fk: F ~> G): Cache[G] =
Cache(value.mapK(fk))
}

object Cache {
def empty[F[_]: Functor: Ref.Make](capacity: Int): F[Cache[F]] =
StatementCache.empty[F, StatementId](capacity).map(Parse.Cache(_))
}

}
Loading

0 comments on commit fd3b400

Please sign in to comment.