Skip to content

Commit

Permalink
feat: add fields to aggregated endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasKellerer authored and fengelniederhammer committed Jul 5, 2023
1 parent 8ba2f84 commit d183a0f
Show file tree
Hide file tree
Showing 13 changed files with 336 additions and 27 deletions.
25 changes: 25 additions & 0 deletions lapis2/src/main/kotlin/org/genspectrum/lapis/OpenApiDocs.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import org.genspectrum.lapis.config.OpennessLevel
import org.genspectrum.lapis.config.SequenceFilterFields
import org.genspectrum.lapis.controller.MIN_PROPORTION_PROPERTY
import org.genspectrum.lapis.controller.REQUEST_SCHEMA
import org.genspectrum.lapis.controller.REQUEST_SCHEMA_WITH_GROUP_BY_FIELDS
import org.genspectrum.lapis.controller.REQUEST_SCHEMA_WITH_MIN_PROPORTION
import org.genspectrum.lapis.controller.RESPONSE_SCHEMA_AGGREGATED

fun buildOpenApiSchema(sequenceFilterFields: SequenceFilterFields, databaseConfig: DatabaseConfig): OpenAPI {
var properties = sequenceFilterFields.fields
Expand All @@ -33,6 +35,29 @@ fun buildOpenApiSchema(sequenceFilterFields: SequenceFilterFields, databaseConfi
.type("object")
.description("valid filters for sequence data")
.properties(properties + Pair(MIN_PROPORTION_PROPERTY, Schema<String>().type("number"))),
).addSchemas(
REQUEST_SCHEMA_WITH_GROUP_BY_FIELDS,
Schema<String>()
.type("object")
.description("valid filters for sequence data")
.properties(
properties + Pair(
"fields",
Schema<String>().type("array").items(Schema<String>().type("string")),
),
),
).addSchemas(
RESPONSE_SCHEMA_AGGREGATED,
Schema<String>()
.type("object")
.description("aggregated sequence data")
.required(listOf("count"))
.properties(
properties +
mapOf(
"count" to Schema<String>().type("number"),
),
),
),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.swagger.v3.oas.annotations.responses.ApiResponse
import org.genspectrum.lapis.auth.ACCESS_KEY_PROPERTY
import org.genspectrum.lapis.logging.RequestContext
import org.genspectrum.lapis.model.SiloQueryModel
import org.genspectrum.lapis.request.AggregationRequest
import org.genspectrum.lapis.response.AggregationData
import org.genspectrum.lapis.response.MutationData
import org.springframework.web.bind.annotation.GetMapping
Expand All @@ -20,44 +21,54 @@ import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController

const val MIN_PROPORTION_PROPERTY = "minProportion"

const val GROUP_BY_FIELDS_PROPERTY = "fields"
const val REQUEST_SCHEMA = "SequenceFilters"
const val REQUEST_SCHEMA_WITH_MIN_PROPORTION = "SequenceFiltersWithMinProportion"
const val REQUEST_SCHEMA_WITH_GROUP_BY_FIELDS = "SequenceFiltersWithGroupByFields"
const val RESPONSE_SCHEMA_AGGREGATED = "AggregatedResponse"

private const val DEFAULT_MIN_PROPORTION = 0.05

@RestController
class LapisController(private val siloQueryModel: SiloQueryModel, private val requestContext: RequestContext) {
companion object {
private val nonSequenceFilterFields = listOf(MIN_PROPORTION_PROPERTY, ACCESS_KEY_PROPERTY)
private val nonSequenceFilterFields =
listOf(MIN_PROPORTION_PROPERTY, ACCESS_KEY_PROPERTY, GROUP_BY_FIELDS_PROPERTY)
}

@GetMapping("/aggregated")
@LapisAggregatedResponse
fun aggregated(
@Parameter(
schema = Schema(ref = "#/components/schemas/$REQUEST_SCHEMA"),
schema = Schema(ref = "#/components/schemas/$REQUEST_SCHEMA_WITH_GROUP_BY_FIELDS"),
explode = Explode.TRUE,
style = ParameterStyle.FORM,
)
@RequestParam
sequenceFilters: Map<String, String>,
@RequestParam(defaultValue = "") fields: List<String>,
): List<AggregationData> {
requestContext.filter = sequenceFilters

return siloQueryModel.aggregate(sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) })
return siloQueryModel.aggregate(
sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) },
fields,
)
}

@PostMapping("/aggregated")
@LapisAggregatedResponse
fun postAggregated(
@Parameter(schema = Schema(ref = "#/components/schemas/$REQUEST_SCHEMA"))
@RequestBody
sequenceFilters: Map<String, String>,
@Parameter(schema = Schema(ref = "#/components/schemas/$REQUEST_SCHEMA_WITH_GROUP_BY_FIELDS"))
@RequestBody()
request: AggregationRequest,
): List<AggregationData> {
requestContext.filter = sequenceFilters
requestContext.filter = request.sequenceFilters

return siloQueryModel.aggregate(sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) })
return siloQueryModel.aggregate(
request.sequenceFilters.filterKeys { !nonSequenceFilterFields.contains(it) },
request.fields,
)
}

@GetMapping("/nucleotideMutations")
Expand Down Expand Up @@ -118,7 +129,13 @@ class LapisController(private val siloQueryModel: SiloQueryModel, private val re
responseCode = "200",
description = "OK",
content = [
Content(array = ArraySchema(schema = Schema(implementation = AggregationData::class))),
Content(
array = ArraySchema(
schema = Schema(
ref = "#/components/schemas/$RESPONSE_SCHEMA_AGGREGATED",
),
),
),
],
),
ApiResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ class SiloQueryModel(
private val siloFilterExpressionMapper: SiloFilterExpressionMapper,
) {

fun aggregate(sequenceFilters: Map<SequenceFilterFieldName, String>) = siloClient.sendQuery(
fun aggregate(
sequenceFilters: Map<SequenceFilterFieldName, String>,
groupByFields: List<SequenceFilterFieldName> = emptyList(),
) = siloClient.sendQuery(
SiloQuery(
SiloAction.aggregated(),
SiloAction.aggregated(groupByFields),
siloFilterExpressionMapper.map(sequenceFilters),
),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.genspectrum.lapis.request

import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonDeserializer
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.ArrayNode
import org.springframework.boot.jackson.JsonComponent

data class AggregationRequest(
val sequenceFilters: Map<String, String>,
val fields: List<String>,
)

@JsonComponent
class AggregationRequestDeserializer : JsonDeserializer<AggregationRequest>() {
override fun deserialize(p: JsonParser, ctxt: DeserializationContext): AggregationRequest {
val node = p.readValueAsTree<JsonNode>()

val fields = when (node.get("fields")) {
null -> emptyList()
is ArrayNode -> node.get("fields").asSequence().map { it.asText() }.toList()
else -> throw IllegalArgumentException("Fields in AggregationRequest must be an array or null")
}

val sequenceFilters =
node.fields().asSequence().filter { it.key != "fields" }.associate { it.key to it.value.asText() }

return AggregationRequest(sequenceFilters, fields)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ import com.fasterxml.jackson.databind.SerializerProvider
import io.swagger.v3.oas.annotations.media.Schema
import org.springframework.boot.jackson.JsonComponent

@Schema(
description = "This type will have additional fields for every value in the \"fields\" parameter of the request " +
"with its respective value. The \"count\" field is always present.",
additionalProperties = Schema.AdditionalPropertiesValue.TRUE,
)
data class AggregationData(val count: Int, @Schema(hidden = true) val fields: Map<String, JsonNode>)

@JsonComponent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ data class SiloQuery<ResponseType>(val action: SiloAction<ResponseType>, val fil

sealed class SiloAction<ResponseType>(@JsonIgnore val typeReference: TypeReference<SiloQueryResponse<ResponseType>>) {
companion object {
fun aggregated(): SiloAction<List<AggregationData>> = AggregatedAction("Aggregated")
fun aggregated(groupByFields: List<String> = emptyList()): SiloAction<List<AggregationData>> =
AggregatedAction("Aggregated", groupByFields)

fun mutations(minProportion: Double? = null): SiloAction<List<MutationData>> =
MutationsAction("Mutations", minProportion)
}

private data class AggregatedAction(val type: String) :
@JsonInclude(JsonInclude.Include.NON_EMPTY)
private data class AggregatedAction(val type: String, val groupByFields: List<String>) :
SiloAction<List<AggregationData>>(object : TypeReference<SiloQueryResponse<List<AggregationData>>>() {})

@JsonInclude(JsonInclude.Include.NON_NULL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ExceptionHandlerTest(@Autowired val mockMvc: MockMvc) {
}

private val validRoute = "/aggregated"
private fun MockKMatcherScope.validControllerCall() = lapisController.aggregated(any())
private fun MockKMatcherScope.validControllerCall() = lapisController.aggregated(any(), any())
private val validResponse = emptyList<AggregationData>()

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,51 @@ class LapisControllerTest(@Autowired val mockMvc: MockMvc) {
.andExpect(jsonPath("\$[0].count").value(0))
}

@Test
fun `GET aggregated with fields`() {
every {
siloQueryModelMock.aggregate(
mapOf("country" to "Switzerland"),
listOf("country", "age"),
)
} returns listOf(
AggregationData(
0,
mapOf("country" to TextNode("Switzerland"), "age" to IntNode(42)),
),
)

mockMvc.perform(get("/aggregated?country=Switzerland&fields=country,age"))
.andExpect(status().isOk)
.andExpect(jsonPath("\$[0].count").value(0))
.andExpect(jsonPath("\$[0].country").value("Switzerland"))
.andExpect(jsonPath("\$[0].age").value(42))
}

@Test
fun `POST aggregated with fields`() {
every {
siloQueryModelMock.aggregate(
mapOf("country" to "Switzerland"),
listOf("country", "age"),
)
} returns listOf(
AggregationData(
0,
mapOf("country" to TextNode("Switzerland"), "age" to IntNode(42)),
),
)
val request = post("/aggregated")
.content("""{"country": "Switzerland", "fields": ["country","age"]}""")
.contentType(MediaType.APPLICATION_JSON)

mockMvc.perform(request)
.andExpect(status().isOk)
.andExpect(jsonPath("\$[0].count").value(0))
.andExpect(jsonPath("\$[0].country").value("Switzerland"))
.andExpect(jsonPath("\$[0].age").value(42))
}

@Test
fun `GET nucleotideMutations without explicit minProportion defaults to 5 percent`() {
every {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ class SiloQueryModelTest {
every { siloClientMock.sendQuery(any<SiloQuery<List<AggregationData>>>()) } returns emptyList()
every { siloFilterExpressionMapperMock.map(any<Map<String, String>>()) } returns True

underTest.aggregate(emptyMap())
underTest.aggregate(emptyMap(), emptyList())

verify {
siloClientMock.sendQuery(
SiloQuery(SiloAction.aggregated(), True),
SiloQuery(SiloAction.aggregated(emptyList()), True),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.genspectrum.lapis.request

import com.fasterxml.jackson.databind.ObjectMapper
import org.hamcrest.MatcherAssert
import org.hamcrest.Matchers
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.MethodSource
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.context.SpringBootTest

@SpringBootTest
class AggregationRequestDeserializerTest {
@Autowired
private lateinit var objectMapper: ObjectMapper

@ParameterizedTest(name = "Test AggregationRequestDeserializer {1}")
@MethodSource("getTestAggregationRequest")
fun `AggregationRequest is correctly deserialized from JSON`(underTest: String, expected: AggregationRequest) {
val result = objectMapper.readValue(underTest, AggregationRequest::class.java)

MatcherAssert.assertThat(result, Matchers.equalTo(expected))
}

companion object {
@JvmStatic
fun getTestAggregationRequest() = listOf(
Arguments.of(
"""
{
"country": "Switzerland",
"fields": ["division", "country"]
}
""",
AggregationRequest(
mapOf("country" to "Switzerland"),
listOf("division", "country"),
),
),
Arguments.of(
"""
{
"country": "Switzerland"
}
""",
AggregationRequest(
mapOf("country" to "Switzerland"),
emptyList(),
),
),

)
}

@Test
fun `Given an AggregationRequest with fields not null or ArrayList it should return an error`() {
val underTest = """
{
"country": "Switzerland",
"fields": "notAnArrayNode"
}
"""

assertThrows(IllegalArgumentException::class.java) {
objectMapper.readValue(underTest, AggregationRequest::class.java)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ class SiloQueryTest {
}
""",
),
Arguments.of(
SiloAction.aggregated(listOf("field1", "field2")),
"""
{
"type": "Aggregated",
"groupByFields": ["field1", "field2"]
}
""",
),
Arguments.of(
SiloAction.mutations(),
"""
Expand Down
Loading

0 comments on commit d183a0f

Please sign in to comment.