Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add read timeout #683

Merged
merged 3 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ version: '3'
services:
# main instance for testing
postgres:
image: postgres:11
# image: postgres:11
build:
context: world
command: -c ssl=on -c ssl_cert_file=/var/lib/postgresql/server.crt -c ssl_key_file=/var/lib/postgresql/server.key
volumes:
- ./world/world.sql:/docker-entrypoint-initdb.d/world.sql
- ./world/server.crt:/var/lib/postgresql/server.crt
- ./world/server.key:/var/lib/postgresql/server.key
# volumes:
# - ./world/world.sql:/docker-entrypoint-initdb.d/world.sql
# - ./world/server.crt:/var/lib/postgresql/server.crt:ro
# - ./world/server.key:/var/lib/postgresql/server.key:ro
svalaskevicius marked this conversation as resolved.
Show resolved Hide resolved
ports:
- 5432:5432
environment:
POSTGRES_USER: jimmy
POSTGRES_PASSWORD: banana
POSTGRES_DB: world
# environment:
# POSTGRES_USER: jimmy
# POSTGRES_PASSWORD: banana
# POSTGRES_DB: world
# for testing password-free login
trust:
image: postgres:11
Expand Down
15 changes: 10 additions & 5 deletions modules/core/shared/src/main/scala/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import skunk.net.SSLNegotiation
import skunk.data.TransactionIsolationLevel
import skunk.data.TransactionAccessMode
import skunk.net.protocol.Describe
import scala.concurrent.duration.Duration

/**
* Represents a live connection to a Postgres database. Operations provided here are safe to use
Expand Down Expand Up @@ -262,7 +263,7 @@ object Session {
* @param queryCache Size of the cache for query checking
* @group Constructors
*/
def pooled[F[_]: Concurrent: Trace: Network: Console](
def pooled[F[_]: Temporal: Trace: Network: Console](
host: String,
port: Int = 5432,
user: String,
Expand All @@ -276,10 +277,11 @@ object Session {
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
commandCache: Int = 1024,
queryCache: Int = 1024,
readTimeout: Duration = Duration.Inf,
): Resource[F, Resource[F, Session[F]]] = {

def session(socketGroup: SocketGroup[F], sslOp: Option[SSLNegotiation.Options[F]], cache: Describe.Cache[F]): Resource[F, Session[F]] =
fromSocketGroup[F](socketGroup, host, port, user, database, password, debug, strategy, socketOptions, sslOp, parameters, cache)
fromSocketGroup[F](socketGroup, host, port, user, database, password, debug, strategy, socketOptions, sslOp, parameters, cache, readTimeout)

val logger: String => F[Unit] = s => Console[F].println(s"TLS: $s")

Expand All @@ -297,7 +299,7 @@ object Session {
* single-session pool. This method is shorthand for `Session.pooled(..., max = 1, ...).flatten`.
* @see pooled
*/
def single[F[_]: Concurrent: Trace: Network: Console](
def single[F[_]: Temporal: Trace: Network: Console](
host: String,
port: Int = 5432,
user: String,
Expand All @@ -309,6 +311,7 @@ object Session {
parameters: Map[String, String] = Session.DefaultConnectionParameters,
commandCache: Int = 1024,
queryCache: Int = 1024,
readTimeout: Duration = Duration.Inf,
): Resource[F, Session[F]] =
pooled(
host = host,
Expand All @@ -323,9 +326,10 @@ object Session {
parameters = parameters,
commandCache = commandCache,
queryCache = queryCache,
readTimeout = readTimeout
).flatten

def fromSocketGroup[F[_]: Concurrent: Trace: Console](
def fromSocketGroup[F[_]: Temporal: Trace: Console](
socketGroup: SocketGroup[F],
host: String,
port: Int = 5432,
Expand All @@ -338,10 +342,11 @@ object Session {
sslOptions: Option[SSLNegotiation.Options[F]],
parameters: Map[String, String],
describeCache: Describe.Cache[F],
readTimeout: Duration = Duration.Inf,
): Resource[F, Session[F]] =
for {
namer <- Resource.eval(Namer[F])
proto <- Protocol[F](host, port, debug, namer, socketGroup, socketOptions, sslOptions, describeCache)
proto <- Protocol[F](host, port, debug, namer, socketGroup, socketOptions, sslOptions, describeCache, readTimeout)
_ <- Resource.eval(proto.startup(user, database, password, parameters))
sess <- Resource.eval(fromProtocol(proto, namer, strategy))
} yield sess
Expand Down
23 changes: 16 additions & 7 deletions modules/core/shared/src/main/scala/net/BitVectorSocket.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

package skunk.net

import cats._
import cats.effect._
import cats.syntax.all._
import cats.effect._
import cats.effect.syntax.temporal._
import fs2.Chunk
import scodec.bits.BitVector
import fs2.io.net.{Socket, SocketGroup, SocketOption}
import com.comcast.ip4s._
import skunk.exception.{EofException, SkunkException}
import scala.concurrent.duration.Duration
import scala.concurrent.duration.FiniteDuration

/** A higher-level `Socket` interface defined in terms of `BitVector`. */
trait BitVectorSocket[F[_]] {
Expand All @@ -34,14 +36,20 @@ object BitVectorSocket {
* @group Constructors
*/
def fromSocket[F[_]](
socket: Socket[F]
socket: Socket[F],
readTimeout: Duration
)(
implicit ev: MonadError[F, Throwable]
implicit ev: Temporal[F]
): BitVectorSocket[F] =
new BitVectorSocket[F] {

val withTimeout: F[Chunk[Byte]] => F[Chunk[Byte]] = readTimeout match {
case _: Duration.Infinite => identity
case finite: FiniteDuration => _.timeout(finite)
}

def readBytes(n: Int): F[Array[Byte]] =
socket.readN(n).flatMap { c =>
withTimeout(socket.readN(n)).flatMap { c =>
if (c.size == n) c.toArray.pure[F]
else ev.raiseError(EofException(n, c.size))
}
Expand All @@ -66,7 +74,8 @@ object BitVectorSocket {
sg: SocketGroup[F],
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
)(implicit ev: MonadError[F, Throwable]): Resource[F, BitVectorSocket[F]] = {
readTimeout: Duration
)(implicit ev: Temporal[F]): Resource[F, BitVectorSocket[F]] = {

def fail[A](msg: String): Resource[F, A] =
Resource.eval(ev.raiseError(new SkunkException(message = msg, sql = None)))
Expand All @@ -82,7 +91,7 @@ object BitVectorSocket {
for {
sock <- sock
sockʹ <- sslOptions.fold(sock.pure[Resource[F, *]])(SSLNegotiation.negotiateSSL(sock, _))
} yield fromSocket(sockʹ)
} yield fromSocket(sockʹ, readTimeout)

}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import fs2.Stream
import skunk.data._
import skunk.net.message._
import fs2.io.net.{SocketGroup, SocketOption}
import scala.concurrent.duration.Duration

/**
* A `MessageSocket` that buffers incoming messages, removing and handling asynchronous back-end
Expand Down Expand Up @@ -74,17 +75,18 @@ trait BufferedMessageSocket[F[_]] extends MessageSocket[F] {

object BufferedMessageSocket {

def apply[F[_]: Concurrent: Console](
def apply[F[_]: Temporal: Console](
host: String,
port: Int,
queueSize: Int,
debug: Boolean,
sg: SocketGroup[F],
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
readTimeout: Duration
): Resource[F, BufferedMessageSocket[F]] =
for {
ms <- MessageSocket(host, port, debug, sg, socketOptions, sslOptions)
ms <- MessageSocket(host, port, debug, sg, socketOptions, sslOptions, readTimeout)
ams <- Resource.make(BufferedMessageSocket.fromMessageSocket[F](ms, queueSize))(_.terminate)
} yield ams

Expand Down
6 changes: 4 additions & 2 deletions modules/core/shared/src/main/scala/net/MessageSocket.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import scodec.interop.cats._
import skunk.net.message.{ Sync => _, _ }
import skunk.util.Origin
import fs2.io.net.{ SocketGroup, SocketOption }
import scala.concurrent.duration.Duration

/** A higher-level `BitVectorSocket` that speaks in terms of `Message`. */
trait MessageSocket[F[_]] {
Expand Down Expand Up @@ -86,16 +87,17 @@ object MessageSocket {
}
}

def apply[F[_]: Concurrent: Console](
def apply[F[_]: Console: Temporal](
host: String,
port: Int,
debug: Boolean,
sg: SocketGroup[F],
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
readTimeout: Duration
): Resource[F, MessageSocket[F]] =
for {
bvs <- BitVectorSocket(host, port, sg, socketOptions, sslOptions)
bvs <- BitVectorSocket(host, port, sg, socketOptions, sslOptions, readTimeout)
ms <- Resource.eval(fromBitVectorSocket(bvs, debug))
} yield ms

Expand Down
8 changes: 5 additions & 3 deletions modules/core/shared/src/main/scala/net/Protocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
package skunk.net

import cats.syntax.all._
import cats.effect.{ Concurrent, Resource }
import cats.effect.{ Concurrent, Temporal, Resource }
import cats.effect.std.Console
import fs2.concurrent.Signal
import fs2.Stream
Expand All @@ -17,6 +17,7 @@ import natchez.Trace
import fs2.io.net.{ SocketGroup, SocketOption }
import skunk.net.protocol.Exchange
import skunk.net.protocol.Describe
import scala.concurrent.duration.Duration

/**
* Interface for a Postgres database, expressed through high-level operations that rely on exchange
Expand Down Expand Up @@ -187,7 +188,7 @@ object Protocol {
* @param host Postgres server host
* @param port Postgres port, default 5432
*/
def apply[F[_]: Concurrent: Trace: Console](
def apply[F[_]: Temporal: Trace: Console](
host: String,
port: Int,
debug: Boolean,
Expand All @@ -196,9 +197,10 @@ object Protocol {
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
describeCache: Describe.Cache[F],
readTimeout: Duration
): Resource[F, Protocol[F]] =
for {
bms <- BufferedMessageSocket[F](host, port, 256, debug, sg, socketOptions, sslOptions) // TODO: should we expose the queue size?
bms <- BufferedMessageSocket[F](host, port, 256, debug, sg, socketOptions, sslOptions, readTimeout) // TODO: should we expose the queue size?
p <- Resource.eval(fromMessageSocket(bms, nam, describeCache))
} yield p

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import java.time._
import skunk.codec.temporal._
import cats.effect.{IO, Resource}
import skunk._, skunk.implicits._
import scala.concurrent.duration.{ Duration => SDuration }

class TemporalCodecTest extends CodecTest {

Expand All @@ -23,8 +24,8 @@ class TemporalCodecTest extends CodecTest {

// Also, run these tests with the session set to a timezone other than UTC. Our test instance is
// set to UTC, which masks the error reported at https://github.com/tpolecat/skunk/issues/313.
override def session: Resource[IO,Session[IO]] =
super.session.evalTap(s => s.execute(sql"SET TIME ZONE +3".command))
override def session(readTimeout: SDuration): Resource[IO,Session[IO]] =
super.session(readTimeout).evalTap(s => s.execute(sql"SET TIME ZONE +3".command))

// Date
val dates: List[LocalDate] =
Expand Down Expand Up @@ -124,4 +125,4 @@ class TemporalCodecTest extends CodecTest {
roundtripTest(interval)(intervals: _*)
decodeFailureTest(interval, List("x"))

}
}
5 changes: 3 additions & 2 deletions modules/tests/shared/src/test/scala/BitVectorSocketTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.comcast.ip4s.{Host, IpAddress, Port, SocketAddress}
import fs2.io.net.{Socket, SocketGroup, SocketOption}
import skunk.exception.SkunkException
import skunk.net.BitVectorSocket
import scala.concurrent.duration.Duration

class BitVectorSocketTest extends ffstest.FTest {

Expand All @@ -23,11 +24,11 @@ class BitVectorSocketTest extends ffstest.FTest {
private val socketOptions = List(SocketOption.noDelay(true))

test("Invalid host") {
BitVectorSocket("", 1, dummySg, socketOptions, None).use(_ => IO.unit).assertFailsWith[SkunkException]
BitVectorSocket("", 1, dummySg, socketOptions, None, Duration.Inf).use(_ => IO.unit).assertFailsWith[SkunkException]
.flatMap(e => assertEqual("message", e.message, """Hostname: "" is not syntactically valid."""))
}
test("Invalid port") {
BitVectorSocket("localhost", -1, dummySg, socketOptions, None).use(_ => IO.unit).assertFailsWith[SkunkException]
BitVectorSocket("localhost", -1, dummySg, socketOptions, None, Duration.Inf).use(_ => IO.unit).assertFailsWith[SkunkException]
.flatMap(e => assertEqual("message", e.message, "Port: -1 falls out of the allowed range."))
}

Expand Down
34 changes: 34 additions & 0 deletions modules/tests/shared/src/test/scala/QueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import skunk.codec.all._
import skunk.implicits._
import tests.SkunkTest
import cats.Eq
import scala.concurrent.duration._
import skunk.Decoder
import skunk.data.Type

class QueryTest extends SkunkTest{

Expand Down Expand Up @@ -54,5 +57,36 @@ class QueryTest extends SkunkTest{
}
}

val void: Decoder[skunk.Void] = new Decoder[skunk.Void] {
def types: List[Type] = List(Type.void)
def decode(offset: Int, ss: List[Option[String]]): Either[Decoder.Error, skunk.Void] = Right(skunk.Void)
}


pooledTest("timeout", 2.seconds) { getS =>
val f = sql"select pg_sleep($int4)"
def getErr[X]: Either[Throwable, X] => Option[String] = _.swap.toOption.collect {
case e: java.util.concurrent.TimeoutException => e.getMessage()
}
for {
sessionBroken <- getS.use { s =>
s.prepare(f.query(void)).use { ps =>
for {
ret <- ps.unique(8).attempt
_ <- assertEqual("timeout error check", getErr(ret), Option("2 seconds"))
} yield "ok"
}
}.attempt
_ <- assertEqual("timeout error check", getErr(sessionBroken), Option("2 seconds"))
_ <- getS.use { s =>
s.prepare(f.query(void)).use { ps =>
for {
ret <- ps.unique(1).attempt
_ <- assertEqual("timeout error ok", ret.isRight, true)
} yield "ok"
}
}
} yield "ok"
}

}
Loading