Skip to content

Commit

Permalink
feat: get access key from request and read valid access keys from file
Browse files Browse the repository at this point in the history
issue: #218
  • Loading branch information
fengelniederhammer committed May 2, 2023
1 parent aff116c commit 4e07a6b
Show file tree
Hide file tree
Showing 14 changed files with 395 additions and 48 deletions.
18 changes: 10 additions & 8 deletions lapis2/src/main/kotlin/org/genspectrum/lapis/LapisSpringConfig.kt
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package org.genspectrum.lapis

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
import com.fasterxml.jackson.module.kotlin.readValue
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
import mu.KotlinLogging
import org.genspectrum.lapis.auth.DataOpennessAuthorizationFilter
import org.genspectrum.lapis.auth.DataOpennessAuthorizationFilterFactory
import org.genspectrum.lapis.config.DatabaseConfig
import org.genspectrum.lapis.config.SequenceFilterFields
import org.genspectrum.lapis.logging.RequestContext
import org.genspectrum.lapis.logging.RequestContextLogger
import org.genspectrum.lapis.logging.StatisticsLogObjectMapper
import org.genspectrum.lapis.util.TimeFactory
import org.genspectrum.lapis.util.YamlObjectMapper
import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
Expand All @@ -24,8 +22,11 @@ class LapisSpringConfig {
fun openAPI(sequenceFilterFields: SequenceFilterFields) = buildOpenApiSchema(sequenceFilterFields)

@Bean
fun databaseConfig(@Value("\${lapis.databaseConfig.path}") configPath: String): DatabaseConfig {
return ObjectMapper(YAMLFactory()).registerKotlinModule().readValue(File(configPath))
fun databaseConfig(
@Value("\${lapis.databaseConfig.path}") configPath: String,
yamlObjectMapper: YamlObjectMapper,
): DatabaseConfig {
return yamlObjectMapper.objectMapper.readValue(File(configPath))
}

@Bean
Expand Down Expand Up @@ -55,6 +56,7 @@ class LapisSpringConfig {
)

@Bean
fun dataOpennessAuthorizationFilter(databaseConfig: DatabaseConfig, objectMapper: ObjectMapper) =
DataOpennessAuthorizationFilter.createFromConfig(databaseConfig, objectMapper)
fun dataOpennessAuthorizationFilter(
dataOpennessAuthorizationFilterFactory: DataOpennessAuthorizationFilterFactory,
) = dataOpennessAuthorizationFilterFactory.create()
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,52 @@
package org.genspectrum.lapis.auth

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import mu.KotlinLogging
import org.genspectrum.lapis.config.AccessKeys
import org.genspectrum.lapis.config.AccessKeysReader
import org.genspectrum.lapis.config.DatabaseConfig
import org.genspectrum.lapis.config.OpennessLevel
import org.genspectrum.lapis.controller.LapisHttpErrorResponse
import org.genspectrum.lapis.util.CachedBodyHttpServletRequest
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.stereotype.Component
import org.springframework.web.filter.OncePerRequestFilter

abstract class DataOpennessAuthorizationFilter(val objectMapper: ObjectMapper) : OncePerRequestFilter() {
const val ACCESS_KEY_PROPERTY = "accessKey"

private val log = KotlinLogging.logger {}

@Component
class DataOpennessAuthorizationFilterFactory(
private val databaseConfig: DatabaseConfig,
private val objectMapper: ObjectMapper,
private val accessKeysReader: AccessKeysReader,
) {
fun create() = when (databaseConfig.schema.opennessLevel) {
OpennessLevel.OPEN -> AlwaysAuthorizedAuthorizationFilter(objectMapper)
OpennessLevel.GISAID -> ProtectedGisaidDataAuthorizationFilter(
objectMapper,
accessKeysReader.read(),
databaseConfig.schema.metadata.filter { it.unique }.map { it.name },
)
}
}

abstract class DataOpennessAuthorizationFilter(protected val objectMapper: ObjectMapper) : OncePerRequestFilter() {
override fun doFilterInternal(
request: HttpServletRequest,
response: HttpServletResponse,
filterChain: FilterChain,
) {
when (val result = isAuthorizedForEndpoint(request)) {
AuthorizationResult.Success -> filterChain.doFilter(request, response)
val reReadableRequest = CachedBodyHttpServletRequest(request)

when (val result = isAuthorizedForEndpoint(reReadableRequest)) {
AuthorizationResult.Success -> filterChain.doFilter(reReadableRequest, response)
is AuthorizationResult.Failure -> {
response.status = HttpStatus.FORBIDDEN.value()
response.contentType = MediaType.APPLICATION_JSON_VALUE
Expand All @@ -34,15 +62,7 @@ abstract class DataOpennessAuthorizationFilter(val objectMapper: ObjectMapper) :
}
}

abstract fun isAuthorizedForEndpoint(request: HttpServletRequest): AuthorizationResult

companion object {
fun createFromConfig(databaseConfig: DatabaseConfig, objectMapper: ObjectMapper) =
when (databaseConfig.schema.opennessLevel) {
OpennessLevel.OPEN -> NoOpAuthorizationFilter(objectMapper)
OpennessLevel.GISAID -> ProtectedGisaidDataAuthorizationFilter(objectMapper)
}
}
abstract fun isAuthorizedForEndpoint(request: CachedBodyHttpServletRequest): AuthorizationResult
}

sealed interface AuthorizationResult {
Expand All @@ -52,24 +72,64 @@ sealed interface AuthorizationResult {
fun failure(message: String): AuthorizationResult = Failure(message)
}

fun isSuccessful(): Boolean

object Success : AuthorizationResult {
override fun isSuccessful() = true
}
object Success : AuthorizationResult

class Failure(val message: String) : AuthorizationResult {
override fun isSuccessful() = false
}
class Failure(val message: String) : AuthorizationResult
}

private class NoOpAuthorizationFilter(objectMapper: ObjectMapper) : DataOpennessAuthorizationFilter(objectMapper) {
override fun isAuthorizedForEndpoint(request: HttpServletRequest) = AuthorizationResult.success()
private class AlwaysAuthorizedAuthorizationFilter(objectMapper: ObjectMapper) :
DataOpennessAuthorizationFilter(objectMapper) {

override fun isAuthorizedForEndpoint(request: CachedBodyHttpServletRequest) = AuthorizationResult.success()
}

private class ProtectedGisaidDataAuthorizationFilter(objectMapper: ObjectMapper) :
private class ProtectedGisaidDataAuthorizationFilter(
objectMapper: ObjectMapper,
private val accessKeys: AccessKeys,
private val fieldsThatServeNonAggregatedData: List<String>,
) :
DataOpennessAuthorizationFilter(objectMapper) {

override fun isAuthorizedForEndpoint(request: HttpServletRequest) =
AuthorizationResult.failure("An access key is required to access this endpoint.")
companion object {
private val ENDPOINTS_THAT_SERVE_AGGREGATED_DATA = listOf("/aggregated", "/nucleotideMutations")
}

override fun isAuthorizedForEndpoint(request: CachedBodyHttpServletRequest): AuthorizationResult {
val requestFields = getRequestFields(request)

val accessKey = requestFields[ACCESS_KEY_PROPERTY]
?: return AuthorizationResult.failure("An access key is required to access this endpoint.")

if (accessKeys.fullAccessKey == accessKey) {
return AuthorizationResult.success()
}

val endpointServesAggregatedData = ENDPOINTS_THAT_SERVE_AGGREGATED_DATA.contains(request.requestURI) &&
fieldsThatServeNonAggregatedData.intersect(requestFields.keys).isEmpty()

if (endpointServesAggregatedData && accessKeys.aggregatedDataAccessKey == accessKey) {
return AuthorizationResult.success()
}

return AuthorizationResult.failure("You are not authorized to access this endpoint.")
}

private fun getRequestFields(request: CachedBodyHttpServletRequest): Map<String, String> {
if (request.parameterNames.hasMoreElements()) {
return request.parameterMap.mapValues { (_, value) -> value.joinToString() }
}

if (request.contentLength == 0) {
log.warn { "Could not read access key from body, because content length is 0." }
return emptyMap()
}

return try {
objectMapper.readValue(request.inputStream)
} catch (exception: Exception) {
log.error { "Failed to read access key from request body: ${exception.message}" }
log.debug { exception.stackTraceToString() }
emptyMap()
}
}
}
23 changes: 23 additions & 0 deletions lapis2/src/main/kotlin/org/genspectrum/lapis/config/AccessKeys.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.genspectrum.lapis.config

import com.fasterxml.jackson.module.kotlin.readValue
import org.genspectrum.lapis.util.YamlObjectMapper
import org.springframework.beans.factory.annotation.Value
import org.springframework.stereotype.Component
import java.io.File

@Component
class AccessKeysReader(
@Value("\${lapis.accessKeys.path:#{null}}") private val accessKeysFile: String?,
private val yamlObjectMapper: YamlObjectMapper,
) {
fun read(): AccessKeys {
if (accessKeysFile == null) {
throw IllegalArgumentException("Cannot read LAPIS access keys, lapis.accessKeys.path was not set.")
}

return yamlObjectMapper.objectMapper.readValue(File(accessKeysFile))
}
}

data class AccessKeys(val fullAccessKey: String, val aggregatedDataAccessKey: String)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ data class DatabaseSchema(
val features: List<DatabaseFeature> = emptyList(),
)

data class DatabaseMetadata(val name: String, val type: String)
data class DatabaseMetadata(val name: String, val type: String, val unique: Boolean = false)

data class DatabaseFeature(val name: String)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.swagger.v3.oas.annotations.media.ArraySchema
import io.swagger.v3.oas.annotations.media.Content
import io.swagger.v3.oas.annotations.media.Schema
import io.swagger.v3.oas.annotations.responses.ApiResponse
import org.genspectrum.lapis.auth.ACCESS_KEY_PROPERTY
import org.genspectrum.lapis.logging.RequestContext
import org.genspectrum.lapis.model.SiloQueryModel
import org.genspectrum.lapis.response.AggregatedResponse
Expand All @@ -27,6 +28,9 @@ private const val DEFAULT_MIN_PROPORTION = 0.05

@RestController
class LapisController(private val siloQueryModel: SiloQueryModel, private val requestContext: RequestContext) {
companion object {
private val nonSequenceFilterFields = listOf(MIN_PROPORTION_PROPERTY, ACCESS_KEY_PROPERTY)
}

@GetMapping("/aggregated")
@LapisAggregatedResponse
Expand All @@ -41,7 +45,7 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
): AggregatedResponse {
requestContext.filter = sequenceFilters

return siloQueryModel.aggregate(sequenceFilters)
return siloQueryModel.aggregate(sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) })
}

@PostMapping("/aggregated")
Expand All @@ -53,7 +57,7 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
): AggregatedResponse {
requestContext.filter = sequenceFilters

return siloQueryModel.aggregate(sequenceFilters)
return siloQueryModel.aggregate(sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) })
}

@GetMapping("/nucleotideMutations")
Expand All @@ -72,7 +76,7 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re

return siloQueryModel.computeMutationProportions(
minProportion,
sequenceFilters.filterKeys { it != MIN_PROPORTION_PROPERTY },
sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) },
)
}

Expand All @@ -85,9 +89,11 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
): List<MutationData> {
requestContext.filter = requestBody

val (minProportions, sequenceFilters) = requestBody.entries.partition { it.key == MIN_PROPORTION_PROPERTY }
val (nonSequenceFilters, sequenceFilters) = requestBody.entries.partition {
nonSequenceFilterFields.contains(it.key)
}

val maybeMinProportion = minProportions.getOrNull(0)?.value
val maybeMinProportion = nonSequenceFilters.find { it.key == MIN_PROPORTION_PROPERTY }?.value
val minProportion = try {
maybeMinProportion?.toDouble() ?: DEFAULT_MIN_PROPORTION
} catch (exception: IllegalArgumentException) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.genspectrum.lapis.util

import jakarta.servlet.ReadListener
import jakarta.servlet.ServletInputStream
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletRequestWrapper
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream

class CachedBodyHttpServletRequest(request: HttpServletRequest) : HttpServletRequestWrapper(request) {
private val cachedBody: ByteArray by lazy {
val inputStream: InputStream = request.inputStream
val byteArrayOutputStream = ByteArrayOutputStream()

inputStream.copyTo(byteArrayOutputStream)
byteArrayOutputStream.toByteArray()
}

@Throws(IOException::class)
override fun getInputStream(): ServletInputStream {
return CachedBodyServletInputStream(ByteArrayInputStream(cachedBody))
}

private inner class CachedBodyServletInputStream(private val cachedInputStream: ByteArrayInputStream) :
ServletInputStream() {

override fun isFinished(): Boolean {
return cachedInputStream.available() == 0
}

override fun isReady(): Boolean {
return true
}

override fun setReadListener(listener: ReadListener) {
throw UnsupportedOperationException("setReadListener is not supported")
}

@Throws(IOException::class)
override fun read(): Int {
return cachedInputStream.read()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.genspectrum.lapis.util

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
import org.springframework.stereotype.Component

@Component
object YamlObjectMapper {
val objectMapper: ObjectMapper = ObjectMapper(YAMLFactory()).registerKotlinModule()
}
Loading

0 comments on commit 4e07a6b

Please sign in to comment.