diff --git a/lapis2/src/main/kotlin/org/genspectrum/lapis/silo/SiloClient.kt b/lapis2/src/main/kotlin/org/genspectrum/lapis/silo/SiloClient.kt index 1366baaa..070aab85 100644 --- a/lapis2/src/main/kotlin/org/genspectrum/lapis/silo/SiloClient.kt +++ b/lapis2/src/main/kotlin/org/genspectrum/lapis/silo/SiloClient.kt @@ -3,6 +3,8 @@ package org.genspectrum.lapis.silo import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.kotlin.readValue import mu.KotlinLogging +import org.genspectrum.lapis.logging.RequestIdContext +import org.genspectrum.lapis.openApi.REQUEST_ID_HEADER import org.genspectrum.lapis.response.InfoData import org.springframework.beans.factory.annotation.Value import org.springframework.http.HttpHeaders @@ -24,19 +26,19 @@ class SiloClient( @Value("\${silo.url}") private val siloUrl: String, private val objectMapper: ObjectMapper, private val dataVersion: DataVersion, + private val requestIdContext: RequestIdContext, ) { + private val httpClient = HttpClient.newHttpClient() + fun sendQuery(query: SiloQuery): ResponseType { val queryJson = objectMapper.writeValueAsString(query) log.info { "Calling SILO: $queryJson" } - val client = HttpClient.newHttpClient() - val request = HttpRequest.newBuilder(URI("$siloUrl/query")) - .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) - .POST(HttpRequest.BodyPublishers.ofString(queryJson)) - .build() - - val response = send(client, request) + val response = send(URI("$siloUrl/query")) { + it.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .POST(HttpRequest.BodyPublishers.ofString(queryJson)) + } try { return objectMapper.readValue(response.body(), query.action.typeReference).queryResult @@ -49,20 +51,26 @@ class SiloClient( fun callInfo(): InfoData { log.info { "Calling SILO info" } - val client = HttpClient.newHttpClient() - val request = HttpRequest.newBuilder(URI("$siloUrl/info")).GET().build() - - val response = send(client, request) + val response = send(URI("$siloUrl/info")) { it.GET() } return InfoData(getDataVersion(response)) } private fun send( - client: HttpClient, - request: HttpRequest?, + uri: URI, + buildRequest: (HttpRequest.Builder) -> Unit, ): HttpResponse { + val request = HttpRequest.newBuilder(uri) + .apply(buildRequest) + .apply { + if (requestIdContext.requestId != null) { + header(REQUEST_ID_HEADER, requestIdContext.requestId) + } + } + .build() + val response = try { - client.send(request, BodyHandlers.ofString()) + httpClient.send(request, BodyHandlers.ofString()) } catch (exception: Exception) { val message = "Could not connect to silo: " + exception::class.toString() + " " + exception.message throw RuntimeException(message, exception) diff --git a/lapis2/src/test/kotlin/org/genspectrum/lapis/silo/SiloClientTest.kt b/lapis2/src/test/kotlin/org/genspectrum/lapis/silo/SiloClientTest.kt index 91d3c04c..995aaf3f 100644 --- a/lapis2/src/test/kotlin/org/genspectrum/lapis/silo/SiloClientTest.kt +++ b/lapis2/src/test/kotlin/org/genspectrum/lapis/silo/SiloClientTest.kt @@ -3,6 +3,7 @@ package org.genspectrum.lapis.silo import com.fasterxml.jackson.databind.node.DoubleNode import com.fasterxml.jackson.databind.node.IntNode import com.fasterxml.jackson.databind.node.TextNode +import org.genspectrum.lapis.logging.RequestIdContext import org.genspectrum.lapis.response.AggregationData import org.genspectrum.lapis.response.DetailsData import org.genspectrum.lapis.response.MutationData @@ -29,18 +30,21 @@ import org.springframework.boot.test.context.SpringBootTest private const val MOCK_SERVER_PORT = 1080 +private const val REQUEST_ID_VALUE = "someRequestId" + @SpringBootTest(properties = ["silo.url=http://localhost:$MOCK_SERVER_PORT"]) -class SiloClientTest { +class SiloClientTest( + @Autowired private val underTest: SiloClient, + @Autowired private val requestIdContext: RequestIdContext, +) { private lateinit var mockServer: ClientAndServer - @Autowired - private lateinit var underTest: SiloClient - private val someQuery = SiloQuery(SiloAction.aggregated(), StringEquals("theColumn", "theValue")) @BeforeEach fun setupMockServer() { mockServer = ClientAndServer.startClientAndServer(MOCK_SERVER_PORT) + requestIdContext.requestId = REQUEST_ID_VALUE } @AfterEach @@ -337,7 +341,8 @@ class SiloClientTest { request() .withMethod("POST") .withPath("/query") - .withContentType(MediaType.APPLICATION_JSON), + .withContentType(MediaType.APPLICATION_JSON) + .withHeader("X-Request-Id", REQUEST_ID_VALUE), ) .respond(httpResponse) }