Skip to content

Commit

Permalink
fix: return valid gzip when LAPIS returns an error response #656
Browse files Browse the repository at this point in the history
Apparently the following happened:
* The gzip stream eagerly writes the gzip header to the underlying stream (in its constructor)
* an error occurs
* Spring's error handling resets the response buffer (erasing the gzip header)
* The exception handler writes a messages, which gets gzip compressed.

Fix by: Delay initializing the gzip output stream until something is written to it.
  • Loading branch information
fengelniederhammer committed Feb 21, 2024
1 parent dbd48e4 commit db1c6cb
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import java.util.zip.GZIPOutputStream
private val log = KotlinLogging.logger {}

enum class Compression(val value: String, val compressionOutputStreamFactory: (OutputStream) -> OutputStream) {
GZIP("gzip", ::GZIPOutputStream),
GZIP("gzip", ::LazyGzipOutputStream),
ZSTD("zstd", { ZstdOutputStream(it).apply { commitUnderlyingResponseToPreventContentLengthFromBeingSet() } }),
;

Expand All @@ -52,6 +52,24 @@ enum class Compression(val value: String, val compressionOutputStreamFactory: (O
}
}

class LazyGzipOutputStream(outputStream: OutputStream) : OutputStream() {
private val gzipOutputStream by lazy { GZIPOutputStream(outputStream) }

override fun write(byte: Int) = gzipOutputStream.write(byte)

override fun write(bytes: ByteArray) = gzipOutputStream.write(bytes)

override fun write(
bytes: ByteArray,
offset: Int,
length: Int,
) = gzipOutputStream.write(bytes, offset, length)

override fun flush() = gzipOutputStream.flush()

override fun close() = gzipOutputStream.close()
}

// https://github.com/apache/tomcat/blob/10e3731f344cd0d018d4be2ee767c105d2832283/java/org/apache/catalina/connector/OutputBuffer.java#L223-L229
fun ZstdOutputStream.commitUnderlyingResponseToPreventContentLengthFromBeingSet() {
val nothing = ByteArray(0)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.genspectrum.lapis.controller

import com.github.luben.zstd.ZstdInputStream
import com.jayway.jsonpath.JsonPath
import com.ninjasquad.springmockk.MockkBean
import io.mockk.every
import org.genspectrum.lapis.controller.SampleRoute.AGGREGATED
Expand All @@ -9,7 +10,9 @@ import org.genspectrum.lapis.controller.SampleRoute.ALIGNED_NUCLEOTIDE_SEQUENCES
import org.genspectrum.lapis.controller.SampleRoute.UNALIGNED_NUCLEOTIDE_SEQUENCES
import org.genspectrum.lapis.model.SiloQueryModel
import org.genspectrum.lapis.request.LapisInfo
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.Matchers.containsString
import org.hamcrest.Matchers.`is`
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
Expand All @@ -23,6 +26,7 @@ import org.springframework.http.HttpHeaders.CONTENT_LENGTH
import org.springframework.http.MediaType.APPLICATION_JSON
import org.springframework.http.MediaType.APPLICATION_JSON_VALUE
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.MvcResult
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.content
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.header
Expand Down Expand Up @@ -66,23 +70,52 @@ class LapisControllerCompressionTest(
fun `endpoints return compressed data`(requestsScenario: RequestScenario) {
requestsScenario.mockData.mockWithData(siloQueryModelMock)

val content = mockMvc.perform(requestsScenario.request)
val response = mockMvc.perform(requestsScenario.request)
.andExpect(status().isOk)
.andExpect(content().contentType(requestsScenario.expectedContentType))
.andExpect(header().doesNotExist(CONTENT_LENGTH))
.andExpect(header().string(CONTENT_ENCODING, requestsScenario.compressionFormat))
.andReturn()
.response
.contentAsByteArray

val decompressedStream = when (requestsScenario.compressionFormat) {
val compressionFormat = requestsScenario.compressionFormat

val decompressedContent = decompressContent(response, compressionFormat)

requestsScenario.mockData.assertDataMatches(decompressedContent)
}

@ParameterizedTest
@MethodSource("getCompressionFormats")
fun `GIVEN model throws bad request WHEN requesting compressed data THEN it should return compressed error`(
compressionFormat: String,
) {
val errorMessage = "test message"
every { siloQueryModelMock.getAggregated(any()) } throws BadRequestException(errorMessage)

val response = mockMvc.perform(getSample("${AGGREGATED.pathSegment}?compression=$compressionFormat"))
.andExpect(status().isBadRequest)
.andExpect(header().string(CONTENT_ENCODING, compressionFormat))
.andReturn()

val decompressedContent = decompressContent(response, compressionFormat)

val errorDetail = JsonPath.read<String>(decompressedContent, "$.error.detail")
assertThat(errorDetail, `is`(errorMessage))
}

private fun decompressContent(
response: MvcResult,
compressionFormat: String,
): String {
val content = response.response.contentAsByteArray

val decompressedStream = when (compressionFormat) {
Compression.GZIP.value -> GZIPInputStream(content.inputStream())
Compression.ZSTD.value -> ZstdInputStream(content.inputStream())
else -> throw Exception("Test issue: unknown compression format ${requestsScenario.compressionFormat}")
else -> throw Exception("Test issue: unknown compression format $compressionFormat")
}
val decompressedContent = decompressedStream.readAllBytes().decodeToString()

requestsScenario.mockData.assertDataMatches(decompressedContent)
return decompressedStream.readAllBytes().decodeToString()
}

private companion object {
Expand Down Expand Up @@ -132,6 +165,9 @@ class LapisControllerCompressionTest(
"${ALIGNED_AMINO_ACID_SEQUENCES.pathSegment}/gene1",
)
.flatMap { getFastaRequests(it, "gzip") + getFastaRequests(it, "zstd") }

@JvmStatic
val compressionFormats = listOf("gzip", "zstd")
}
}

Expand Down

0 comments on commit db1c6cb

Please sign in to comment.