Skip to content

Commit

Permalink
feat: throw more specialized exception to handle bad requests
Browse files Browse the repository at this point in the history
We don't know where IllegalArgumentExceptions are thrown in library code and whether this is a "Bad Request" for us.
  • Loading branch information
fengelniederhammer committed Oct 17, 2023
1 parent 435a140 commit b8b86c2
Show file tree
Hide file tree
Showing 18 changed files with 82 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class ExceptionHandler : ResponseEntityExceptionHandler() {
return responseEntity(HttpStatus.INTERNAL_SERVER_ERROR, e.message)
}

@ExceptionHandler(IllegalArgumentException::class)
@ExceptionHandler(BadRequestException::class)
@ResponseStatus(HttpStatus.BAD_REQUEST)
fun handleIllegalArgumentException(e: IllegalArgumentException): ErrorResponse {
log.warn(e) { "Caught IllegalArgumentException: ${e.message}" }
fun handleBadRequestException(e: BadRequestException): ErrorResponse {
log.warn(e) { "Caught BadRequestException: ${e.message}" }

return responseEntity(HttpStatus.BAD_REQUEST, e.message)
}
Expand Down Expand Up @@ -83,3 +83,5 @@ class ExceptionHandler : ResponseEntityExceptionHandler() {

/** This is not yet actually thrown, but makes "403 Forbidden" appear in OpenAPI docs. */
class AddForbiddenToOpenApiDocsHelper(message: String) : Exception(message)

class BadRequestException(message: String, cause: Throwable? = null) : Exception(message, cause)
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.genspectrum.lapis.model

import org.genspectrum.lapis.config.SequenceFilterFieldType
import org.genspectrum.lapis.config.SequenceFilterFields
import org.genspectrum.lapis.controller.BadRequestException
import org.genspectrum.lapis.request.AminoAcidInsertion
import org.genspectrum.lapis.request.AminoAcidMutation
import org.genspectrum.lapis.request.CommonSequenceFilters
Expand Down Expand Up @@ -103,7 +104,7 @@ class SiloFilterExpressionMapper(
is SequenceFilterFieldType.FloatFrom -> Pair(type.associatedField, Filter.FloatBetween)
is SequenceFilterFieldType.FloatTo -> Pair(type.associatedField, Filter.FloatBetween)

null -> throw IllegalArgumentException(
null -> throw BadRequestException(
"'$key' is not a valid sequence filter key. Valid keys are: " +
allowedSequenceFilterFields.fields.keys.joinToString(),
)
Expand All @@ -122,7 +123,7 @@ class SiloFilterExpressionMapper(
aaMutations.isNotEmpty()

if (containsAdvancedVariantQuery && containsSimpleVariantQuery) {
throw IllegalArgumentException(
throw BadRequestException(
"variantQuery filter cannot be used with other variant filters such as: " +
variantQueryTypes.joinToString(", "),
)
Expand All @@ -135,7 +136,7 @@ class SiloFilterExpressionMapper(
]

if (intBetweenFilterForSameColumn != null) {
throw IllegalArgumentException(
throw BadRequestException(
"Cannot filter by exact int field '$intEqualsColumnName' " +
"and by int range field '${intBetweenFilterForSameColumn[0].originalKey}'.",
)
Expand All @@ -149,7 +150,7 @@ class SiloFilterExpressionMapper(
]

if (floatBetweenFilterForSameColumn != null) {
throw IllegalArgumentException(
throw BadRequestException(
"Cannot filter by exact float field '$floatEqualsColumnName' " +
"and by float range field '${floatBetweenFilterForSameColumn[0].originalKey}'.",
)
Expand All @@ -159,9 +160,7 @@ class SiloFilterExpressionMapper(

private fun mapToVariantQueryFilter(variantQuery: String): SiloFilterExpression {
if (variantQuery.isBlank()) {
throw IllegalArgumentException(
"variantQuery must not be empty",
)
throw BadRequestException("variantQuery must not be empty")
}

return variantQueryFacade.map(variantQuery)
Expand All @@ -176,7 +175,7 @@ class SiloFilterExpressionMapper(
}

if (exactDateFilters.isNotEmpty() && dateRangeFilters.isNotEmpty()) {
throw IllegalArgumentException(
throw BadRequestException(
"Cannot filter by exact date field '${exactDateFilters[0].originalKey}' " +
"and by date range field '${dateRangeFilters[0].originalKey}'.",
)
Expand Down Expand Up @@ -211,14 +210,14 @@ class SiloFilterExpressionMapper(
try {
return LocalDate.parse(value)
} catch (exception: DateTimeParseException) {
throw IllegalArgumentException("$originalKey '$value' is not a valid date: ${exception.message}", exception)
throw BadRequestException("$originalKey '$value' is not a valid date: ${exception.message}", exception)
}
}

private fun mapToPangoLineageFilter(column: String, value: String) = when {
value.endsWith(".*") -> PangoLineageEquals(column, value.substringBeforeLast(".*"), includeSublineages = true)
value.endsWith('*') -> PangoLineageEquals(column, value.substringBeforeLast('*'), includeSublineages = true)
value.endsWith('.') -> throw IllegalArgumentException(
value.endsWith('.') -> throw BadRequestException(
"Invalid pango lineage: $value must not end with a dot. Did you mean '$value*'?",
)

Expand All @@ -233,7 +232,7 @@ class SiloFilterExpressionMapper(
try {
return IntEquals(siloColumnName, value.toInt())
} catch (exception: NumberFormatException) {
throw IllegalArgumentException(
throw BadRequestException(
"$siloColumnName '$value' is not a valid integer: ${exception.message}",
exception,
)
Expand All @@ -248,7 +247,7 @@ class SiloFilterExpressionMapper(
try {
return FloatEquals(siloColumnName, value.toDouble())
} catch (exception: NumberFormatException) {
throw IllegalArgumentException(
throw BadRequestException(
"$siloColumnName '$value' is not a valid float: ${exception.message}",
exception,
)
Expand All @@ -274,7 +273,7 @@ class SiloFilterExpressionMapper(
try {
return value.toInt()
} catch (exception: NumberFormatException) {
throw IllegalArgumentException(
throw BadRequestException(
"$originalKey '$value' is not a valid integer: ${exception.message}",
exception,
)
Expand All @@ -300,7 +299,7 @@ class SiloFilterExpressionMapper(
try {
return value.toDouble()
} catch (exception: NumberFormatException) {
throw IllegalArgumentException(
throw BadRequestException(
"$originalKey '$value' is not a valid float: ${exception.message}",
exception,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.genspectrum.lapis.request
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonDeserializer
import org.genspectrum.lapis.controller.BadRequestException
import org.springframework.boot.jackson.JsonComponent
import org.springframework.core.convert.converter.Converter
import org.springframework.stereotype.Component
Expand All @@ -11,22 +12,22 @@ data class AminoAcidInsertion(val position: Int, val gene: String, val insertion
companion object {
fun fromString(aminoAcidInsertion: String): AminoAcidInsertion {
val match = AMINO_ACID_INSERTION_REGEX.find(aminoAcidInsertion)
?: throw IllegalArgumentException("Invalid nucleotide mutation: $aminoAcidInsertion")
?: throw BadRequestException("Invalid nucleotide mutation: $aminoAcidInsertion")

val matchGroups = match.groups

val position = matchGroups["position"]?.value?.toInt()
?: throw IllegalArgumentException(
?: throw BadRequestException(
"Invalid amino acid insertion: $aminoAcidInsertion: Did not find position",
)

val gene = matchGroups["gene"]?.value
?: throw IllegalArgumentException(
?: throw BadRequestException(
"Invalid amino acid insertion: $aminoAcidInsertion: Did not find gene",
)

val insertions = matchGroups["insertions"]?.value?.replace("?", ".*")
?: throw IllegalArgumentException(
?: throw BadRequestException(
"Invalid amino acid insertion: $aminoAcidInsertion: Did not find insertions",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.genspectrum.lapis.request
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonDeserializer
import org.genspectrum.lapis.controller.BadRequestException
import org.springframework.boot.jackson.JsonComponent
import org.springframework.core.convert.converter.Converter
import org.springframework.stereotype.Component
Expand All @@ -11,14 +12,14 @@ data class AminoAcidMutation(val gene: String, val position: Int, val symbol: St
companion object {
fun fromString(aminoAcidMutation: String): AminoAcidMutation {
val match = AMINO_ACID_MUTATION_REGEX.find(aminoAcidMutation)
?: throw IllegalArgumentException("Invalid amino acid mutation: $aminoAcidMutation")
?: throw BadRequestException("Invalid amino acid mutation: $aminoAcidMutation")

val matchGroups = match.groups

val gene = matchGroups["gene"]?.value
?: throw IllegalArgumentException("Invalid amino acid mutation: $aminoAcidMutation: Did not find gene")
?: throw BadRequestException("Invalid amino acid mutation: $aminoAcidMutation: Did not find gene")
val position = matchGroups["position"]?.value?.toInt()
?: throw IllegalArgumentException(
?: throw BadRequestException(
"Invalid amino acid mutation: $aminoAcidMutation: Did not find position",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.node.ArrayNode
import com.fasterxml.jackson.databind.node.JsonNodeType
import org.genspectrum.lapis.controller.AMINO_ACID_INSERTIONS_PROPERTY
import org.genspectrum.lapis.controller.AMINO_ACID_MUTATIONS_PROPERTY
import org.genspectrum.lapis.controller.BadRequestException
import org.genspectrum.lapis.controller.LIMIT_PROPERTY
import org.genspectrum.lapis.controller.NUCLEOTIDE_INSERTIONS_PROPERTY
import org.genspectrum.lapis.controller.NUCLEOTIDE_MUTATIONS_PROPERTY
Expand All @@ -32,39 +33,39 @@ fun parseCommonFields(node: JsonNode, codec: ObjectCodec): ParsedCommonFields {
val nucleotideMutations = when (val nucleotideMutationsNode = node.get(NUCLEOTIDE_MUTATIONS_PROPERTY)) {
null -> emptyList()
is ArrayNode -> nucleotideMutationsNode.map { codec.treeToValue(it, NucleotideMutation::class.java) }
else -> throw IllegalArgumentException(
else -> throw BadRequestException(
"nucleotideMutations must be an array or null",
)
}

val aminoAcidMutations = when (val aminoAcidMutationsNode = node.get(AMINO_ACID_MUTATIONS_PROPERTY)) {
null -> emptyList()
is ArrayNode -> aminoAcidMutationsNode.map { codec.treeToValue(it, AminoAcidMutation::class.java) }
else -> throw IllegalArgumentException(
else -> throw BadRequestException(
"aminoAcidMutations must be an array or null",
)
}

val nucleotideInsertions = when (val nucleotideInsertionsNode = node.get(NUCLEOTIDE_INSERTIONS_PROPERTY)) {
null -> emptyList()
is ArrayNode -> nucleotideInsertionsNode.map { codec.treeToValue(it, NucleotideInsertion::class.java) }
else -> throw IllegalArgumentException(
else -> throw BadRequestException(
"nucleotideInsertions must be an array or null",
)
}

val aminoAcidInsertions = when (val aminoAcidInsertionsNode = node.get(AMINO_ACID_INSERTIONS_PROPERTY)) {
null -> emptyList()
is ArrayNode -> aminoAcidInsertionsNode.map { codec.treeToValue(it, AminoAcidInsertion::class.java) }
else -> throw IllegalArgumentException(
else -> throw BadRequestException(
"aminoAcidInsertions must be an array or null",
)
}

val orderByFields = when (val orderByNode = node.get(ORDER_BY_PROPERTY)) {
null -> emptyList()
is ArrayNode -> orderByNode.map { codec.treeToValue(it, OrderByField::class.java) }
else -> throw IllegalArgumentException(
else -> throw BadRequestException(
"orderBy must be an array or null",
)
}
Expand All @@ -73,14 +74,14 @@ fun parseCommonFields(node: JsonNode, codec: ObjectCodec): ParsedCommonFields {
val limit = when (limitNode?.nodeType) {
null -> null
JsonNodeType.NULL, JsonNodeType.NUMBER -> limitNode.asInt()
else -> throw IllegalArgumentException("limit must be a number or null")
else -> throw BadRequestException("limit must be a number or null")
}

val offsetNode = node.get(OFFSET_PROPERTY)
val offset = when (offsetNode?.nodeType) {
null -> null
JsonNodeType.NULL, JsonNodeType.NUMBER -> offsetNode.asInt()
else -> throw IllegalArgumentException("offset must be a number or null")
else -> throw BadRequestException("offset must be a number or null")
}

val sequenceFilters = node.fields().asSequence().filter { isStringOrNumber(it.value) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.JsonDeserializer
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.NullNode
import com.fasterxml.jackson.databind.node.NumericNode
import org.genspectrum.lapis.controller.BadRequestException
import org.genspectrum.lapis.controller.MIN_PROPORTION_PROPERTY
import org.springframework.boot.jackson.JsonComponent

Expand All @@ -31,7 +32,7 @@ class MutationProportionsRequestDeserializer : JsonDeserializer<MutationProporti
null, is NullNode -> null
is NumericNode -> minProportionNode.doubleValue()

else -> throw IllegalArgumentException("minProportion must be a number")
else -> throw BadRequestException("minProportion must be a number")
}

val parsedCommonFields = parseCommonFields(node, codec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.genspectrum.lapis.request
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonDeserializer
import org.genspectrum.lapis.controller.BadRequestException
import org.springframework.boot.jackson.JsonComponent
import org.springframework.core.convert.converter.Converter
import org.springframework.stereotype.Component
Expand All @@ -11,17 +12,17 @@ data class NucleotideInsertion(val position: Int, val insertions: String, val se
companion object {
fun fromString(nucleotideInsertion: String): NucleotideInsertion {
val match = NUCLEOTIDE_INSERTION_REGEX.find(nucleotideInsertion)
?: throw IllegalArgumentException("Invalid nucleotide mutation: $nucleotideInsertion")
?: throw BadRequestException("Invalid nucleotide mutation: $nucleotideInsertion")

val matchGroups = match.groups

val position = matchGroups["position"]?.value?.toInt()
?: throw IllegalArgumentException(
?: throw BadRequestException(
"Invalid nucleotide insertion: $nucleotideInsertion: Did not find position",
)

val insertions = matchGroups["insertions"]?.value?.replace("?", ".*")
?: throw IllegalArgumentException(
?: throw BadRequestException(
"Invalid nucleotide insertion: $nucleotideInsertion: Did not find insertions",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.genspectrum.lapis.request
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonDeserializer
import org.genspectrum.lapis.controller.BadRequestException
import org.springframework.boot.jackson.JsonComponent
import org.springframework.core.convert.converter.Converter
import org.springframework.stereotype.Component
Expand All @@ -11,12 +12,12 @@ data class NucleotideMutation(val sequenceName: String?, val position: Int, val
companion object {
fun fromString(nucleotideMutation: String): NucleotideMutation {
val match = NUCLEOTIDE_MUTATION_REGEX.find(nucleotideMutation)
?: throw IllegalArgumentException("Invalid nucleotide mutation: $nucleotideMutation")
?: throw BadRequestException("Invalid nucleotide mutation: $nucleotideMutation")

val matchGroups = match.groups

val position = matchGroups["position"]?.value?.toInt()
?: throw IllegalArgumentException(
?: throw BadRequestException(
"Invalid nucleotide mutation: $nucleotideMutation: Did not find position",
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.JsonDeserializer
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.ObjectNode
import com.fasterxml.jackson.databind.node.TextNode
import org.genspectrum.lapis.controller.BadRequestException
import org.springframework.boot.jackson.JsonComponent
import org.springframework.core.convert.converter.Converter
import org.springframework.stereotype.Component
Expand All @@ -32,20 +33,20 @@ class OrderByFieldDeserializer : JsonDeserializer<OrderByField>() {
return when (val value = jsonParser.readValueAsTree<JsonNode>()) {
is TextNode -> OrderByField(value.asText(), Order.ASCENDING)
is ObjectNode -> deserializeOrderByField(value)
else -> throw IllegalArgumentException("orderByField must be a string or an object")
else -> throw BadRequestException("orderByField must be a string or an object")
}
}

private fun deserializeOrderByField(value: ObjectNode): OrderByField {
val fieldNode = value.get("field")
if (fieldNode == null || fieldNode !is TextNode) {
throw IllegalArgumentException("orderByField must have a string property \"field\"")
throw BadRequestException("orderByField must have a string property \"field\"")
}

val ascending = when (value.get("type")?.asText()) {
"ascending", null -> Order.ASCENDING
"descending" -> Order.DESCENDING
else -> throw IllegalArgumentException("orderByField type must be \"ascending\" or \"descending\"")
else -> throw BadRequestException("orderByField type must be \"ascending\" or \"descending\"")
}

return OrderByField(fieldNode.asText(), ascending)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.DeserializationContext
import com.fasterxml.jackson.databind.JsonDeserializer
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.node.ArrayNode
import org.genspectrum.lapis.controller.BadRequestException
import org.genspectrum.lapis.controller.FIELDS_PROPERTY
import org.springframework.boot.jackson.JsonComponent

Expand All @@ -29,7 +30,7 @@ class SequenceFiltersRequestWithFieldsDeserializer : JsonDeserializer<SequenceFi
val fields = when (val fields = node.get(FIELDS_PROPERTY)) {
null -> emptyList()
is ArrayNode -> fields.asSequence().map { it.asText() }.toList()
else -> throw IllegalArgumentException(
else -> throw BadRequestException(
"fields must be an array or null",
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class ExceptionHandlerTest(@Autowired val mockMvc: MockMvc) {

@Test
fun `throw BAD_REQUEST(400) with additional info for bad requests`() {
every { validControllerCall() } throws IllegalArgumentException("SomeMessage")
every { validControllerCall() } throws BadRequestException("SomeMessage")

mockMvc.perform(get(validRoute))
.andExpect(status().isBadRequest)
Expand Down
Loading

0 comments on commit b8b86c2

Please sign in to comment.