Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CORS check #131

Merged
merged 2 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ kotlin {
commonMain {
dependencies {
api("io.ktor:ktor-server-core:$ktorVersion")
api("io.ktor:ktor-server-cors:$ktorVersion")
api("org.jetbrains.kotlinx:kotlinx-datetime:0.3.2")
}
}
Expand All @@ -68,6 +69,7 @@ kotlin {

implementation("io.ktor:ktor-server-test-host:$ktorVersion")
implementation("io.ktor:ktor-server-cio:$ktorVersion")
implementation("io.ktor:ktor-server-auth:$ktorVersion")
implementation("io.ktor:ktor-client-cio:$ktorVersion")

implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.1")
Expand Down
8 changes: 6 additions & 2 deletions src/commonMain/kotlin/app/softwork/ratelimit/Configuration.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ public class Configuration(public val storage: Storage) {
*/
public var sendRetryAfterHeader: Boolean = true

public companion object { }
/**
* When true, this plugin does not check, if CORS is appplied before this plugin to prevent limiting CORS requests.
*/
public var ignoreCORSInstallationCheck: Boolean = false

/**
* Build a non mutating copy
Expand All @@ -81,6 +84,7 @@ public class Configuration(public val storage: Storage) {
limit = limit,
timeout = timeout,
skip = skip,
sendRetryAfterHeader = sendRetryAfterHeader
sendRetryAfterHeader = sendRetryAfterHeader,
ignoreCORSCheck = ignoreCORSInstallationCheck
)
}
13 changes: 11 additions & 2 deletions src/commonMain/kotlin/app/softwork/ratelimit/RateLimit.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package app.softwork.ratelimit
import app.softwork.ratelimit.SkipResult.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.plugins.cors.*
import io.ktor.server.response.*
import io.ktor.util.*
import kotlin.time.*
Expand All @@ -19,6 +20,13 @@ public fun RateLimit(storage: Storage): RouteScopedPlugin<Configuration> = creat
) {
val rateLimit = pluginConfig.build()

if (!rateLimit.ignoreCORSCheck) {
checkNotNull(application.pluginOrNull(CORS)) {
"Please install CORS before this plugin to prevent limiting CORS request " +
"or suppress this check with ignoreCORSInstallationCheck = true."
}
}

onCall { call ->
val host = rateLimit.host(call)
if (rateLimit.skip(call) == SkipRateLimit) {
Expand All @@ -41,15 +49,16 @@ public fun RateLimit(storage: Storage): RouteScopedPlugin<Configuration> = creat
/**
* Non mutating config
*/
internal class RateLimit(
internal data class RateLimit(
val storage: Storage,
val host: (ApplicationCall) -> String,
val alwaysAllow: (String) -> Boolean,
val alwaysBlock: (String) -> Boolean,
val limit: Int,
val timeout: Duration,
val skip: (ApplicationCall) -> SkipResult,
val sendRetryAfterHeader: Boolean
val sendRetryAfterHeader: Boolean,
val ignoreCORSCheck: Boolean
) {
/**
* Check if a [host] is allowed to request the requested resource.
Expand Down
69 changes: 69 additions & 0 deletions src/jvmTest/kotlin/app/softwork/ratelimit/CompatibilityTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package app.softwork.ratelimit

import app.softwork.ratelimit.MockStorage.Companion.toClock
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.plugins.cors.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.testing.*
import kotlin.test.*
import kotlin.time.*

@ExperimentalTime
class CompatibilityTest {
@Test
fun authentication() = testApplication {
val clock = TestTimeSource().toClock()
install(CORS)
install(Authentication) {
basic {
validate {
UserIdPrincipal(it.name)
}
}
}

routing {
route("/foo") {
install(RateLimit(MockStorage(clock))) {
limit = 5
}
authenticate {
get {
val user = call.principal<UserIdPrincipal>()!!
call.respondText { "Hello ${user.name}" }
}
}
}
}
repeat(5) {
assertEquals(actual = client.get("/foo") {
basicAuth("foo", "bar")
}.status, expected = HttpStatusCode.OK)
}
assertEquals(actual = client.get("/foo") {
basicAuth("foo", "bar")
}.status, expected = HttpStatusCode.TooManyRequests)
}

@Test
fun cors() = testApplication {
install(CORS)

routing {
install(RateLimit(MockStorage()))
}
}

@Test
fun corsMissing() {
assertFailsWith<IllegalStateException> {
testApplication {
install(RateLimit(MockStorage()))
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.cio.*
import io.ktor.server.engine.*
import io.ktor.server.plugins.cors.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import kotlinx.coroutines.*
Expand Down Expand Up @@ -76,10 +77,10 @@ class DatabaseStorageTest {
fun dbBasedRateLimit() = dbTest(setup = {
SchemaUtils.create(Locks)
}) { db ->
val rateLimit = Configuration(storage = DBStorage(db = db, testTimeSource.toClock())) {
val rateLimit = Configuration(storage = DBStorage(db = db, testTimeSource.toClock())).apply {
limit = 3
timeout = 3.seconds
}
}.build()

rateLimit.test(limit = 3, timeout = 3.seconds)
}
Expand All @@ -89,6 +90,7 @@ class DatabaseStorageTest {
SchemaUtils.create(Locks)
}) {db ->
val server = embeddedServer(CIO, port = 0) {
install(CORS)
routing {
install(RateLimit(storage = DBStorage(db = db, Clock.System))) {
limit = 10
Expand Down
3 changes: 2 additions & 1 deletion src/jvmTest/kotlin/app/softwork/ratelimit/MockStorage.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package app.softwork.ratelimit
import kotlinx.datetime.*
import kotlin.time.*

@ExperimentalTime
class MockStorage(
override val clock: Clock
override val clock: Clock = TestTimeSource().toClock()
) : Storage {
data class Requested(override val trial: Int, override val lastRequest: Instant): Storage.Requested

Expand Down
4 changes: 2 additions & 2 deletions src/jvmTest/kotlin/app/softwork/ratelimit/MockStorageTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class MockStorageTest {
fun testMock() = runTest {
val rateLimit: RateLimit = Configuration(
storage = MockStorage(testTimeSource.toClock()),
) {
).apply {
limit = 3
timeout = 3.seconds
}
}.build()
rateLimit.test(limit = 3, timeout = 3.seconds)
}
}
34 changes: 27 additions & 7 deletions src/jvmTest/kotlin/app/softwork/ratelimit/RateLimitTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import kotlin.time.Duration.Companion.seconds
class RateLimitTest {
@Test
fun installTest() = testApplication {
install(RateLimit(MockStorage(TestTimeSource().toClock()))) {
testRateLimit {
limit = 10
}
routing {
Expand All @@ -36,7 +36,7 @@ class RateLimitTest {

@Test
fun noHeader() = testApplication {
install(RateLimit(MockStorage(TestTimeSource().toClock()))) {
testRateLimit {
limit = 10
sendRetryAfterHeader = false
}
Expand All @@ -55,7 +55,7 @@ class RateLimitTest {

@Test
fun rateLimitOnlyLoginEndpoint() = testApplication {
install(RateLimit(MockStorage(TestTimeSource().toClock()))) {
testRateLimit {
limit = 3
skip { call ->
if (call.request.local.uri == "/login") {
Expand Down Expand Up @@ -87,8 +87,9 @@ class RateLimitTest {
call.respondText { "42" }
}
route("/login") {
install(RateLimit(MockStorage(TestTimeSource().toClock()))) {
install(RateLimit(MockStorage())) {
limit = 3
ignoreCORSInstallationCheck = true

this@route.get {
call.respondText { "/login called" }
Expand All @@ -105,7 +106,7 @@ class RateLimitTest {

@Test
fun blockAllowTest() = testApplication {
install(RateLimit(MockStorage(TestTimeSource().toClock()))) {
testRateLimit {
limit = 3
alwaysBlock { host ->
host == "blockedHost"
Expand All @@ -117,6 +118,7 @@ class RateLimitTest {
call.request.local.host
}
}

routing {
get {
call.respondText { "Hello" }
Expand Down Expand Up @@ -164,5 +166,23 @@ internal suspend fun RateLimit.test(limit: Int, timeout: Duration) {
}
}

internal operator fun Configuration.Companion.invoke(storage: Storage, block: Configuration.() -> Unit): RateLimit =
Configuration(storage).apply(block).build()
@ExperimentalTime
fun TestApplicationBuilder.testRateLimit(
storage: Storage = MockStorage(clock = TestTimeSource().toClock()),
block: Configuration.() -> Unit
) {
application {
testRateLimit(storage, block)
}
}

@ExperimentalTime
fun Application.testRateLimit(
storage: Storage = MockStorage(clock = TestTimeSource().toClock()),
block: Configuration.() -> Unit
) {
install(RateLimit(storage)) {
ignoreCORSInstallationCheck = true
block()
}
}