Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Rate limit protection against rapid resets #4324

Merged
merged 3 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions akka-http-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ akka.http {
# Fail the connection if a sent ping is not acknowledged within this timeout.
# When zero the ping-interval is used, if set the value must be evenly divisible by less than or equal to the ping-interval.
ping-timeout = 0s

# Limit the number of RSTs a client is allowed to do on one connection, per interval
# Protects against rapid reset attacks. If a connection goes over the limit, it is closed with HTTP/2 protocol error ENHANCE_YOUR_CALM
max-resets = 400
max-resets-interval = 10s
}

websocket {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import akka.event.LoggingAdapter
import akka.http.impl.engine.{ HttpConnectionIdleTimeoutBidi, HttpIdleTimeoutException }
import akka.http.impl.engine.http2.FrameEvent._
import akka.http.impl.engine.http2.client.ResponseParsing
import akka.http.impl.engine.http2.framing.{ FrameRenderer, Http2FrameParsing }
import akka.http.impl.engine.http2.framing.{ FrameRenderer, Http2FrameParsing, RSTFrameLimit }
import akka.http.impl.engine.http2.hpack.{ HeaderCompression, HeaderDecompression }
import akka.http.impl.engine.parsing.HttpHeaderParser
import akka.http.impl.engine.rendering.DateHeaderRendering
Expand Down Expand Up @@ -108,7 +108,7 @@ private[http] object Http2Blueprint {
serverDemux(settings.http2Settings, initialDemuxerSettings, upgraded) atop
FrameLogger.logFramesIfEnabled(settings.http2Settings.logFrames) atop // enable for debugging
hpackCoding(masterHttpHeaderParser, settings.parserSettings) atop
framing(log) atop
framing(settings.http2Settings, log) atop
errorHandling(log) atop
idleTimeoutIfConfigured(settings.idleTimeout)
}
Expand Down Expand Up @@ -167,10 +167,12 @@ private[http] object Http2Blueprint {
Flow[ByteString]
)

def framing(log: LoggingAdapter): BidiFlow[FrameEvent, ByteString, ByteString, FrameEvent, NotUsed] =
def framing(http2ServerSettings: Http2ServerSettings, log: LoggingAdapter): BidiFlow[FrameEvent, ByteString, ByteString, FrameEvent, NotUsed] =
BidiFlow.fromFlows(
Flow[FrameEvent].map(FrameRenderer.render),
Flow[ByteString].via(new Http2FrameParsing(shouldReadPreface = true, log)))
Flow[ByteString].via(new Http2FrameParsing(shouldReadPreface = true, log))
.via(new RSTFrameLimit(http2ServerSettings))
)

def framingClient(log: LoggingAdapter): BidiFlow[FrameEvent, ByteString, ByteString, FrameEvent, NotUsed] =
BidiFlow.fromFlows(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package akka.http.impl.engine.http2.framing

import akka.annotation.InternalApi
import akka.http.impl.engine.http2.{ FrameEvent, Http2Compliance }
import akka.http.impl.engine.http2.FrameEvent.RstStreamFrame
import akka.http.impl.engine.http2.Http2Protocol.ErrorCode
import akka.http.scaladsl.settings.Http2ServerSettings
import akka.stream.{ Attributes, FlowShape, Inlet, Outlet }
import akka.stream.stage.{ GraphStage, GraphStageLogic, InHandler, OutHandler }

import scala.concurrent.duration._

/**
* INTERNAL API
*/
@InternalApi
private[akka] final class RSTFrameLimit(http2ServerSettings: Http2ServerSettings) extends GraphStage[FlowShape[FrameEvent, FrameEvent]] {

private val maxResets = http2ServerSettings.maxResets
private val maxResetsIntervalNanos = http2ServerSettings.maxResetsInterval.toNanos

val in = Inlet[FrameEvent]("in")
val out = Outlet[FrameEvent]("out")
val shape = FlowShape(in, out)

override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new GraphStageLogic(shape) with InHandler with OutHandler {
private var rstCount = 0
private var rstSpanStartNanos = 0L

setHandlers(in, out, this)

override def onPush(): Unit = {
grab(in) match {
case frame: RstStreamFrame =>
rstCount += 1
val now = System.nanoTime()
if (rstSpanStartNanos == 0L) {
johanandren marked this conversation as resolved.
Show resolved Hide resolved
rstSpanStartNanos = now
push(out, frame)
} else if ((now - rstSpanStartNanos) <= maxResetsIntervalNanos) {
if (rstCount > maxResets) {
failStage(new Http2Compliance.Http2ProtocolException(
ErrorCode.ENHANCE_YOUR_CALM,
s"Too many RST frames per second for this connection. (Configured limit ${maxResets}/${maxResetsIntervalNanos.nanos.toSeconds} s)"))
johanandren marked this conversation as resolved.
Show resolved Hide resolved
} else {
push(out, frame)
}
} else {
// outside time window, reset counter
rstCount = 1
rstSpanStartNanos = now
push(out, frame)
}

case frame =>
push(out, frame)
}
}

override def onPull(): Unit = pull(in)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ trait Http2ServerSettings { self: scaladsl.settings.Http2ServerSettings with akk

def getPingTimeout: Duration = Duration.ofMillis(pingTimeout.toMillis)
def withPingTimeout(timeout: Duration): Http2ServerSettings = withPingTimeout(timeout.toMillis.millis)

def maxResets: Int

def withMaxResets(n: Int): Http2ServerSettings = copy(maxResets = n)

def getMaxResetsInterval: Duration = Duration.ofMillis(maxResetsInterval.toMillis)

def withMaxResetsInterval(interval: Duration): Http2ServerSettings = copy(maxResetsInterval = interval.toMillis.millis)

}
object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
def create(config: Config): Http2ServerSettings = scaladsl.settings.Http2ServerSettings(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ trait Http2ServerSettings extends javadsl.settings.Http2ServerSettings with Http
def pingTimeout: FiniteDuration
def withPingTimeout(timeout: FiniteDuration): Http2ServerSettings = copy(pingTimeout = timeout)

def maxResets: Int

override def withMaxResets(n: Int): Http2ServerSettings = copy(maxResets = n)

def maxResetsInterval: FiniteDuration

def withMaxResetsInterval(interval: FiniteDuration): Http2ServerSettings = copy(maxResetsInterval = interval)

@InternalApi
private[http] def internalSettings: Option[Http2InternalServerSettings]
@InternalApi
Expand All @@ -110,7 +118,10 @@ object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
logFrames: Boolean,
pingInterval: FiniteDuration,
pingTimeout: FiniteDuration,
internalSettings: Option[Http2InternalServerSettings])
maxResets: Int,
maxResetsInterval: FiniteDuration,
internalSettings: Option[Http2InternalServerSettings]
)
extends Http2ServerSettings {
require(maxConcurrentStreams >= 0, "max-concurrent-streams must be >= 0")
require(requestEntityChunkSize > 0, "request-entity-chunk-size must be > 0")
Expand All @@ -134,7 +145,9 @@ object Http2ServerSettings extends SettingsCompanion[Http2ServerSettings] {
logFrames = c.getBoolean("log-frames"),
pingInterval = c.getFiniteDuration("ping-interval"),
pingTimeout = c.getFiniteDuration("ping-timeout"),
None // no possibility to configure internal settings with config
maxResets = c.getInt("max-resets"),
maxResetsInterval = c.getFiniteDuration("max-resets-interval"),
internalSettings = None, // no possibility to configure internal settings with config
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import akka.http.impl.engine.http2.Http2Protocol.ErrorCode
import akka.http.impl.engine.http2.Http2Protocol.Flags
import akka.http.impl.engine.http2.Http2Protocol.FrameType
import akka.http.impl.engine.http2.Http2Protocol.SettingIdentifier
import akka.http.impl.engine.http2.framing.FrameRenderer
import akka.http.impl.engine.server.{ HttpAttributes, ServerTerminator }
import akka.http.impl.engine.ws.ByteStringSinkProbe
import akka.http.impl.util.AkkaSpecWithMaterializer
Expand All @@ -22,28 +23,29 @@ import akka.http.scaladsl.model._
import akka.http.scaladsl.model.headers.CacheDirectives
import akka.http.scaladsl.model.headers.RawHeader
import akka.http.scaladsl.settings.ServerSettings
import akka.stream.Attributes
import akka.stream.{ Attributes, DelayOverflowStrategy, OverflowStrategy }
import akka.stream.Attributes.LogLevels
import akka.stream.OverflowStrategy
import akka.stream.scaladsl.{ BidiFlow, Flow, Keep, Sink, Source, SourceQueueWithComplete }
import akka.stream.testkit.TestPublisher.{ ManualProbe, Probe }
import akka.stream.testkit.scaladsl.StreamTestKit
import akka.stream.testkit.TestPublisher
import akka.stream.testkit.TestSubscriber
import akka.testkit._
import akka.util.ByteString
import akka.util.{ ByteString, ByteStringBuilder }

import scala.annotation.nowarn
import javax.net.ssl.SSLContext
import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.PatienceConfiguration.Timeout

import java.nio.ByteOrder
import scala.collection.immutable
import scala.concurrent.duration._
import scala.concurrent.Await
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Success

/**
* This tests the http2 server protocol logic.
Expand Down Expand Up @@ -1686,6 +1688,31 @@ class Http2ServerSpec extends AkkaSpecWithMaterializer("""
terminated.futureValue
}
}

"not allow high a frequency of resets for one connection" in StreamTestKit.assertAllStagesStopped(new TestSetup {

override def settings: ServerSettings = super.settings.withHttp2Settings(super.settings.http2Settings.withMaxResets(100).withMaxResetsInterval(2.seconds))

// covers CVE-2023-44487 with a rapid sequence of RSTs
override def handlerFlow: Flow[HttpRequest, HttpResponse, NotUsed] = Flow[HttpRequest].buffer(1000, OverflowStrategy.backpressure).mapAsync(300) { req =>
// never actually reached since rst is in headers
req.entity.discardBytes()
Future.successful(HttpResponse(entity = "Ok").withAttributes(req.attributes))
}

network.toNet.request(100000L)
val request = HttpRequest(protocol = HttpProtocols.`HTTP/2.0`, uri = "/foo")
val error = intercept[AssertionError] {
for (streamId <- 1 to 300 by 2) {
network.sendBytes(
FrameRenderer.render(HeadersFrame(streamId, true, true, network.encodeRequestHeaders(request), None))
++ FrameRenderer.render(RstStreamFrame(streamId, ErrorCode.CANCEL))
)
}
}
error.getMessage should include("Too many RST frames per second for this connection.")
johanandren marked this conversation as resolved.
Show resolved Hide resolved
network.toNet.cancel()
})
}

implicit class InWithStoppedStages(name: String) {
Expand Down
Loading