Skip to content

Commit

Permalink
KTOR-578 KTOR-800 Fix Netty HTTP/2 (#3152)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsinukov authored Sep 5, 2022
1 parent 08400a5 commit b9d20b4
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 56 deletions.
5 changes: 5 additions & 0 deletions ktor-server/ktor-server-netty/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ kotlin.sourceSets {
}
}
}

val jvmTest: org.jetbrains.kotlin.gradle.targets.jvm.tasks.KotlinJvmTest by tasks
jvmTest.apply {
systemProperty("enable.http2", "true")
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
package io.ktor.server.netty

import io.ktor.server.engine.*
import io.ktor.server.netty.cio.*
import io.ktor.server.netty.http1.*
import io.ktor.server.netty.http2.*
import io.ktor.util.logging.*
import io.netty.channel.*
import io.netty.channel.socket.SocketChannel
import io.netty.handler.codec.http.*
Expand Down Expand Up @@ -103,7 +103,13 @@ public class NettyChannelInitializer(
private fun configurePipeline(pipeline: ChannelPipeline, protocol: String) {
when (protocol) {
ApplicationProtocolNames.HTTP_2 -> {
val handler = NettyHttp2Handler(enginePipeline, environment.application, callEventGroup, userContext)
val handler = NettyHttp2Handler(
enginePipeline,
environment.application,
callEventGroup,
userContext,
runningLimit
)
@Suppress("DEPRECATION")
pipeline.addLast(Http2MultiplexCodecBuilder.forServer(handler).build())
pipeline.channel().closeFuture().addListener {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2014-2022 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.netty

import io.netty.channel.*
import kotlinx.atomicfu.*

internal class NettyHttpHandlerState(private val runningLimit: Int) {

internal val activeRequests: AtomicLong = atomic(0L)
internal val isCurrentRequestFullyRead: AtomicBoolean = atomic(false)
internal val isChannelReadCompleted: AtomicBoolean = atomic(false)
internal val skippedRead: AtomicBoolean = atomic(false)

internal fun onLastResponseMessage(context: ChannelHandlerContext) {
activeRequests.decrementAndGet()

if (skippedRead.compareAndSet(expect = false, update = true) && activeRequests.value < runningLimit) {
context.read()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,10 @@ private const val UNFLUSHED_LIMIT = 65536

/**
* Contains methods for handling http request with Netty
* @param context
* @param coroutineContext
* @param activeRequests
* @param isCurrentRequestFullyRead
* @param isChannelReadCompleted
*/
@OptIn(InternalAPI::class)
internal class NettyHttpResponsePipeline constructor(
private val context: ChannelHandlerContext,
private val httpHandler: NettyHttp1Handler,
private val httpHandlerState: NettyHttpHandlerState,
override val coroutineContext: CoroutineContext
) : CoroutineScope {
/**
Expand All @@ -56,8 +50,8 @@ internal class NettyHttpResponsePipeline constructor(
internal fun flushIfNeeded() {
if (
isDataNotFlushed.value &&
httpHandler.isChannelReadCompleted.value &&
httpHandler.activeRequests.value == 0L
httpHandlerState.isChannelReadCompleted.value &&
httpHandlerState.activeRequests.value == 0L
) {
context.flush()
isDataNotFlushed.compareAndSet(expect = true, update = false)
Expand Down Expand Up @@ -145,7 +139,7 @@ internal class NettyHttpResponsePipeline constructor(
null
}

httpHandler.onLastResponseMessage(context)
httpHandlerState.onLastResponseMessage(context)
call.finishedEvent.setSuccess()

lastMessageFuture?.addListener {
Expand Down Expand Up @@ -232,9 +226,9 @@ internal class NettyHttpResponsePipeline constructor(
* True if client is waiting for response header, false otherwise
*/
private fun isHeaderFlushNeeded(): Boolean {
val activeRequestsValue = httpHandler.activeRequests.value
return httpHandler.isChannelReadCompleted.value &&
!httpHandler.isCurrentRequestFullyRead.value &&
val activeRequestsValue = httpHandlerState.activeRequests.value
return httpHandlerState.isChannelReadCompleted.value &&
!httpHandlerState.isCurrentRequestFullyRead.value &&
activeRequestsValue == 1L
}

Expand Down Expand Up @@ -365,7 +359,6 @@ internal class NettyHttpResponsePipeline constructor(
}
}

@OptIn(InternalAPI::class)
private fun NettyApplicationResponse.isUpgradeResponse() =
status()?.value == HttpStatusCode.SwitchingProtocols.value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,15 @@ internal class NettyHttp1Handler(
override val coroutineContext: CoroutineContext get() = handlerJob

private var skipEmpty = false
private val skippedRead: AtomicBoolean = atomic(false)

private lateinit var responseWriter: NettyHttpResponsePipeline

/**
* Represents current number of processing requests
*/
internal val activeRequests: AtomicLong = atomic(0L)

/**
* True if current request's last http content is read, false otherwise.
*/
internal val isCurrentRequestFullyRead: AtomicBoolean = atomic(false)

/**
* True if [channelReadComplete] was invoked for the current request, false otherwise
*/
internal val isChannelReadCompleted: AtomicBoolean = atomic(false)
private val state = NettyHttpHandlerState(runningLimit)

override fun channelActive(context: ChannelHandlerContext) {
responseWriter = NettyHttpResponsePipeline(
context,
this,
state,
coroutineContext
)

Expand All @@ -68,16 +54,16 @@ internal class NettyHttp1Handler(

override fun channelRead(context: ChannelHandlerContext, message: Any) {
if (message is LastHttpContent) {
isCurrentRequestFullyRead.compareAndSet(expect = false, update = true)
state.isCurrentRequestFullyRead.compareAndSet(expect = false, update = true)
}

when {
message is HttpRequest -> {
if (message !is LastHttpContent) {
isCurrentRequestFullyRead.compareAndSet(expect = true, update = false)
state.isCurrentRequestFullyRead.compareAndSet(expect = true, update = false)
}
isChannelReadCompleted.compareAndSet(expect = true, update = false)
activeRequests.incrementAndGet()
state.isChannelReadCompleted.compareAndSet(expect = true, update = false)
state.activeRequests.incrementAndGet()

handleRequest(context, message)
callReadIfNeeded(context)
Expand Down Expand Up @@ -110,7 +96,7 @@ internal class NettyHttp1Handler(
}

override fun channelReadComplete(context: ChannelHandlerContext?) {
isChannelReadCompleted.compareAndSet(expect = false, update = true)
state.isChannelReadCompleted.compareAndSet(expect = false, update = true)
responseWriter.flushIfNeeded()
super.channelReadComplete(context)
}
Expand Down Expand Up @@ -165,19 +151,11 @@ internal class NettyHttp1Handler(
}

private fun callReadIfNeeded(context: ChannelHandlerContext) {
if (activeRequests.value < runningLimit) {
if (state.activeRequests.value < runningLimit) {
context.read()
skippedRead.value = false
state.skippedRead.value = false
} else {
skippedRead.value = true
}
}

internal fun onLastResponseMessage(context: ChannelHandlerContext) {
activeRequests.decrementAndGet()

if (skippedRead.compareAndSet(expect = false, update = true) && activeRequests.value < runningLimit) {
context.read()
state.skippedRead.value = true
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.ktor.server.netty.cio.*
import io.ktor.server.response.*
import io.ktor.util.*
import io.netty.channel.*
import io.netty.handler.codec.http2.*
import io.netty.util.AttributeKey
import io.netty.util.*
import io.netty.util.concurrent.*
import kotlinx.coroutines.*
import java.lang.reflect.*
Expand All @@ -24,16 +24,22 @@ internal class NettyHttp2Handler(
private val enginePipeline: EnginePipeline,
private val application: Application,
private val callEventGroup: EventExecutorGroup,
private val userCoroutineContext: CoroutineContext
private val userCoroutineContext: CoroutineContext,
runningLimit: Int
) : ChannelInboundHandlerAdapter(), CoroutineScope {
private val handlerJob = SupervisorJob(userCoroutineContext[Job])

private val state = NettyHttpHandlerState(runningLimit)
private lateinit var responseWriter: NettyHttpResponsePipeline

override val coroutineContext: CoroutineContext
get() = handlerJob

override fun channelRead(context: ChannelHandlerContext, message: Any?) {
override fun channelRead(context: ChannelHandlerContext, message: Any) {
when (message) {
is Http2HeadersFrame -> {
state.isChannelReadCompleted.compareAndSet(expect = true, update = false)
state.activeRequests.incrementAndGet()
startHttp2(context, message.headers())
}
is Http2DataFrame -> {
Expand All @@ -42,6 +48,9 @@ internal class NettyHttp2Handler(
contentActor.trySend(message).isSuccess
if (eof) {
contentActor.close()
state.isCurrentRequestFullyRead.compareAndSet(expect = false, update = true)
} else {
state.isCurrentRequestFullyRead.compareAndSet(expect = true, update = false)
}
} ?: message.release()
}
Expand All @@ -55,20 +64,30 @@ internal class NettyHttp2Handler(
}
}

override fun channelRegistered(ctx: ChannelHandlerContext?) {
super.channelRegistered(ctx)
override fun channelActive(context: ChannelHandlerContext) {
responseWriter = NettyHttpResponsePipeline(
context,
state,
coroutineContext
)

ctx?.pipeline()?.apply {
context.pipeline()?.apply {
addLast(callEventGroup, NettyApplicationCallHandler(userCoroutineContext, enginePipeline))
}
context.fireChannelActive()
}

override fun channelReadComplete(context: ChannelHandlerContext) {
state.isChannelReadCompleted.compareAndSet(expect = false, update = true)
responseWriter.flushIfNeeded()
context.fireChannelReadComplete()
}

@Suppress("OverridingDeprecatedMember")
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
ctx.close()
}

@OptIn(InternalAPI::class)
private fun startHttp2(context: ChannelHandlerContext, headers: Http2Headers) {
val call = NettyHttp2ApplicationCall(
application,
Expand All @@ -79,6 +98,9 @@ internal class NettyHttp2Handler(
userCoroutineContext
)
context.applicationCall = call

context.fireChannelRead(call)
responseWriter.processResponse(call)
}

@Suppress("DEPRECATION")
Expand Down

0 comments on commit b9d20b4

Please sign in to comment.