Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLS handshaking issues with Postgres #1897

Merged
merged 2 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ lazy val mimaSettings = Seq(
),
ProblemFilters.exclude[DirectMissingMethodProblem](
"fs2.io.tls.TLSSocket.fs2$io$tls$TLSSocket$$binding$default$3"
)
),
// InputOutputBuffer is private[tls]
ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.io.tls.InputOutputBuffer.output"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.io.tls.InputOutputBuffer.output")
)
)

Expand Down
11 changes: 6 additions & 5 deletions io/src/main/scala/fs2/io/tls/InputOutputBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ private[tls] trait InputOutputBuffer[F[_]] {
/** Adds the specified chunk to the input buffer. */
def input(data: Chunk[Byte]): F[Unit]

/** Removes all available data from the output buffer. */
def output: F[Chunk[Byte]]
/** Removes available data from the output buffer. */
def output(maxBytes: Int): F[Chunk[Byte]]

/**
* Performs an operation that may read from the input buffer and write to the output buffer.
Expand Down Expand Up @@ -75,16 +75,17 @@ private[tls] object InputOutputBuffer {
}
}

def output: F[Chunk[Byte]] =
def output(maxBytes: Int): F[Chunk[Byte]] =
outBuff.get.flatMap { out =>
if (out.position() == 0) Applicative[F].pure(Chunk.empty)
else
Sync[F].delay {
(out: Buffer).flip()
val cap = out.limit()
val dest = new Array[Byte](cap)
val sz = cap.min(maxBytes)
val dest = new Array[Byte](sz)
out.get(dest)
(out: Buffer).clear()
out.compact()
Chunk.bytes(dest)
}
}
Expand Down
57 changes: 33 additions & 24 deletions io/src/main/scala/fs2/io/tls/TLSEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,32 +93,38 @@ private[tls] object TLSEngine {
}

private def doWrite(timeout: Option[FiniteDuration]): F[Unit] =
wrapBuffer.output.flatMap { out =>
wrapBuffer.output(Int.MaxValue).flatMap { out =>
if (out.isEmpty) Applicative[F].unit
else binding.write(out, timeout)
}

def read(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] =
readSemaphore.withPermit(read0(maxBytes, timeout))

private def initialHandshakeDone: F[Boolean] =
Sync[F].delay(engine.getSession.getCipherSuite != "SSL_NULL_WITH_NULL_NULL")

private def read0(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] =
// Check if a session has been established -- if so, read; otherwise, handshake and then read
blocker
.delay(engine.getSession.isValid)
.ifM(
binding.read(maxBytes, timeout).flatMap {
case Some(c) =>
unwrapBuffer.input(c) >> unwrap(timeout).flatMap {
case s @ Some(_) => Applicative[F].pure(s)
case None => read0(maxBytes, timeout)
}
case None => Applicative[F].pure(None)
},
write(Chunk.empty, None) >> read0(maxBytes, timeout)
)
// Check if the initial handshake has finished -- if so, read; otherwise, handshake and then read
initialHandshakeDone.ifM(
dequeueUnwrap(maxBytes).flatMap { out =>
if (out.isEmpty) read1(maxBytes, timeout) else Applicative[F].pure(out)
},
write(Chunk.empty, None) >> read1(maxBytes, timeout)
)

private def read1(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] =
binding.read(maxBytes, timeout).flatMap {
case Some(c) =>
unwrapBuffer.input(c) >> unwrap(maxBytes, timeout).flatMap {
case s @ Some(_) => Applicative[F].pure(s)
case None => read1(maxBytes, timeout)
}
case None => Applicative[F].pure(None)
}

/** Performs an unwrap operation on the underlying engine. */
private def unwrap(timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] =
private def unwrap(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] =
unwrapBuffer
.perform(engine.unwrap(_, _))
.flatTap(result => log(s"unwrap result: $result"))
Expand All @@ -129,24 +135,27 @@ private[tls] object TLSEngine {
case SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING =>
unwrapBuffer.inputRemains
.map(_ > 0 && result.bytesConsumed > 0)
.ifM(unwrap(timeout), dequeueUnwrap)
.ifM(unwrap(maxBytes, timeout), dequeueUnwrap(maxBytes))
case SSLEngineResult.HandshakeStatus.FINISHED =>
unwrap(timeout)
unwrap(maxBytes, timeout)
case _ =>
handshakeSemaphore
.withPermit(stepHandshake(result, false, timeout)) >> unwrap(timeout)
.withPermit(stepHandshake(result, false, timeout)) >> unwrap(
maxBytes,
timeout
)
}
case SSLEngineResult.Status.BUFFER_UNDERFLOW =>
dequeueUnwrap
dequeueUnwrap(maxBytes)
case SSLEngineResult.Status.BUFFER_OVERFLOW =>
unwrapBuffer.expandOutput >> unwrap(timeout)
unwrapBuffer.expandOutput >> unwrap(maxBytes, timeout)
case SSLEngineResult.Status.CLOSED =>
stopWrap >> stopUnwrap >> dequeueUnwrap
stopWrap >> stopUnwrap >> dequeueUnwrap(maxBytes)
}
}

private def dequeueUnwrap: F[Option[Chunk[Byte]]] =
unwrapBuffer.output.map(out => if (out.isEmpty) None else Some(out))
private def dequeueUnwrap(maxBytes: Int): F[Option[Chunk[Byte]]] =
unwrapBuffer.output(maxBytes).map(out => if (out.isEmpty) None else Some(out))

/**
* Determines what to do next given the result of a handshake operation.
Expand Down