diff --git a/build.sbt b/build.sbt index 6a70b9f40a..56f3d731cd 100644 --- a/build.sbt +++ b/build.sbt @@ -248,7 +248,10 @@ lazy val mimaSettings = Seq( ), ProblemFilters.exclude[DirectMissingMethodProblem]( "fs2.io.tls.TLSSocket.fs2$io$tls$TLSSocket$$binding$default$3" - ) + ), + // InputOutputBuffer is private[tls] + ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.io.tls.InputOutputBuffer.output"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.io.tls.InputOutputBuffer.output") ) ) diff --git a/io/src/main/scala/fs2/io/tls/InputOutputBuffer.scala b/io/src/main/scala/fs2/io/tls/InputOutputBuffer.scala index b192b3b7b3..0d5b85e94b 100644 --- a/io/src/main/scala/fs2/io/tls/InputOutputBuffer.scala +++ b/io/src/main/scala/fs2/io/tls/InputOutputBuffer.scala @@ -21,8 +21,8 @@ private[tls] trait InputOutputBuffer[F[_]] { /** Adds the specified chunk to the input buffer. */ def input(data: Chunk[Byte]): F[Unit] - /** Removes all available data from the output buffer. */ - def output: F[Chunk[Byte]] + /** Removes available data from the output buffer. */ + def output(maxBytes: Int): F[Chunk[Byte]] /** * Performs an operation that may read from the input buffer and write to the output buffer. @@ -75,16 +75,17 @@ private[tls] object InputOutputBuffer { } } - def output: F[Chunk[Byte]] = + def output(maxBytes: Int): F[Chunk[Byte]] = outBuff.get.flatMap { out => if (out.position() == 0) Applicative[F].pure(Chunk.empty) else Sync[F].delay { (out: Buffer).flip() val cap = out.limit() - val dest = new Array[Byte](cap) + val sz = cap.min(maxBytes) + val dest = new Array[Byte](sz) out.get(dest) - (out: Buffer).clear() + out.compact() Chunk.bytes(dest) } } diff --git a/io/src/main/scala/fs2/io/tls/TLSEngine.scala b/io/src/main/scala/fs2/io/tls/TLSEngine.scala index 6a818bc0e5..f917152d0e 100644 --- a/io/src/main/scala/fs2/io/tls/TLSEngine.scala +++ b/io/src/main/scala/fs2/io/tls/TLSEngine.scala @@ -93,7 +93,7 @@ private[tls] object TLSEngine { } private def doWrite(timeout: Option[FiniteDuration]): F[Unit] = - wrapBuffer.output.flatMap { out => + wrapBuffer.output(Int.MaxValue).flatMap { out => if (out.isEmpty) Applicative[F].unit else binding.write(out, timeout) } @@ -101,24 +101,30 @@ private[tls] object TLSEngine { def read(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] = readSemaphore.withPermit(read0(maxBytes, timeout)) + private def initialHandshakeDone: F[Boolean] = + Sync[F].delay(engine.getSession.getCipherSuite != "SSL_NULL_WITH_NULL_NULL") + private def read0(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] = - // Check if a session has been established -- if so, read; otherwise, handshake and then read - blocker - .delay(engine.getSession.isValid) - .ifM( - binding.read(maxBytes, timeout).flatMap { - case Some(c) => - unwrapBuffer.input(c) >> unwrap(timeout).flatMap { - case s @ Some(_) => Applicative[F].pure(s) - case None => read0(maxBytes, timeout) - } - case None => Applicative[F].pure(None) - }, - write(Chunk.empty, None) >> read0(maxBytes, timeout) - ) + // Check if the initial handshake has finished -- if so, read; otherwise, handshake and then read + initialHandshakeDone.ifM( + dequeueUnwrap(maxBytes).flatMap { out => + if (out.isEmpty) read1(maxBytes, timeout) else Applicative[F].pure(out) + }, + write(Chunk.empty, None) >> read1(maxBytes, timeout) + ) + + private def read1(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] = + binding.read(maxBytes, timeout).flatMap { + case Some(c) => + unwrapBuffer.input(c) >> unwrap(maxBytes, timeout).flatMap { + case s @ Some(_) => Applicative[F].pure(s) + case None => read1(maxBytes, timeout) + } + case None => Applicative[F].pure(None) + } /** Performs an unwrap operation on the underlying engine. */ - private def unwrap(timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] = + private def unwrap(maxBytes: Int, timeout: Option[FiniteDuration]): F[Option[Chunk[Byte]]] = unwrapBuffer .perform(engine.unwrap(_, _)) .flatTap(result => log(s"unwrap result: $result")) @@ -129,24 +135,27 @@ private[tls] object TLSEngine { case SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING => unwrapBuffer.inputRemains .map(_ > 0 && result.bytesConsumed > 0) - .ifM(unwrap(timeout), dequeueUnwrap) + .ifM(unwrap(maxBytes, timeout), dequeueUnwrap(maxBytes)) case SSLEngineResult.HandshakeStatus.FINISHED => - unwrap(timeout) + unwrap(maxBytes, timeout) case _ => handshakeSemaphore - .withPermit(stepHandshake(result, false, timeout)) >> unwrap(timeout) + .withPermit(stepHandshake(result, false, timeout)) >> unwrap( + maxBytes, + timeout + ) } case SSLEngineResult.Status.BUFFER_UNDERFLOW => - dequeueUnwrap + dequeueUnwrap(maxBytes) case SSLEngineResult.Status.BUFFER_OVERFLOW => - unwrapBuffer.expandOutput >> unwrap(timeout) + unwrapBuffer.expandOutput >> unwrap(maxBytes, timeout) case SSLEngineResult.Status.CLOSED => - stopWrap >> stopUnwrap >> dequeueUnwrap + stopWrap >> stopUnwrap >> dequeueUnwrap(maxBytes) } } - private def dequeueUnwrap: F[Option[Chunk[Byte]]] = - unwrapBuffer.output.map(out => if (out.isEmpty) None else Some(out)) + private def dequeueUnwrap(maxBytes: Int): F[Option[Chunk[Byte]]] = + unwrapBuffer.output(maxBytes).map(out => if (out.isEmpty) None else Some(out)) /** * Determines what to do next given the result of a handshake operation.