Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codegen: better enum handling in query params #4385

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ class EndpointGenerator {
val maybeTargetFileName = if (useHeadTagForObjectNames) m.tags.flatMap(_.headOption) else None
val queryOrPathParamRefs = m.resolvedParameters
.collect { case queryParam: OpenapiParameter if queryParam.in == "query" || queryParam.in == "path" => queryParam.schema }
.collect { case ref: OpenapiSchemaRef if ref.isSchema => ref.stripped }
.collect {
case ref: OpenapiSchemaRef if ref.isSchema => ref.stripped
case OpenapiSchemaArray(ref: OpenapiSchemaRef, _) if ref.isSchema => ref.stripped
}
.toSet
val jsonParamRefs = (m.requestBody.toSeq.flatMap(_.content.map(c => (c.contentType, c.schema))) ++
m.responses.flatMap(_.content.map(c => (c.contentType, c.schema))))
Expand Down Expand Up @@ -264,6 +267,12 @@ class EndpointGenerator {
streamingImplementation: StreamingImplementation,
doc: OpenapiDocument
)(implicit location: Location): (String, Option[String], Seq[String]) = {
def toOutType(baseType: String, isArray: Boolean, noOptionWrapper: Boolean) = (isArray, noOptionWrapper) match {
case (true, true) => s"List[$baseType]"
case (true, false) => s"Option[List[$baseType]]"
case (false, true) => baseType
case (false, false) => s"Option[$baseType]"
}
def getEnumParamDefn(param: OpenapiParameter, e: OpenapiSchemaEnum, isArray: Boolean) = {
val enumName = endpointName.capitalize + strippedToCamelCase(param.name).capitalize
val enumParamRefs = if (param.in == "query" || param.in == "path") Set(enumName) else Set.empty[String]
Expand All @@ -283,12 +292,7 @@ class EndpointGenerator {
// 'exploded' params have no distinction between an empty list and an absent value, so don't wrap in 'Option' for them
val noOptionWrapper = required || (isArray && param.isExploded)
val req = if (noOptionWrapper) tpe else s"Option[$tpe]"
val outType = (isArray, noOptionWrapper) match {
case (true, true) => s"List[$enumName]"
case (true, false) => s"Option[List[$enumName]]"
case (false, true) => enumName
case (false, false) => s"Option[$enumName]"
}
val outType = toOutType(enumName, isArray, noOptionWrapper)

def mapToList =
if (!isArray) "" else if (noOptionWrapper) s".map(_.values)($arrayType(_))" else s".map(_.map(_.values))(_.map($arrayType(_)))"
Expand Down Expand Up @@ -320,7 +324,8 @@ class EndpointGenerator {
def mapToList = if (noOptionWrapper) s".map(_.values)($arrayType(_))" else s".map(_.map(_.values))(_.map($arrayType(_)))"

val desc = param.description.map(d => JavaEscape.escapeString(d)).fold("")(d => s""".description("$d")""")
(s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""", None, req)
val outType = toOutType(t, true, noOptionWrapper)
(s""".in(${param.in}[$req]("${param.name}")$mapToList$desc)""", None, outType)
case e @ OpenapiSchemaEnum(_, _, _) => getEnumParamDefn(param, e, isArray = false)
case OpenapiSchemaArray(e: OpenapiSchemaEnum, _) => getEnumParamDefn(param, e, isArray = true)
case x => bail(s"Can't create non-simple params to input - found $x")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,24 @@ object TapirGeneratedEndpoints {
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
}


case class EnumExtraParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends ExtraParamSupport[T] {
// Case-insensitive mapping
def decode(s: String): sttp.tapir.DecodeResult[T] =
scala.util.Try(T.upperCaseNameValuesToMap(s.toUpperCase))
.fold(
_ =>
sttp.tapir.DecodeResult.Error(
s,
new NoSuchElementException(
s"Could not find value $s for enum ${enumName}, available values: ${T.values.mkString(", ")}"
)
),
sttp.tapir.DecodeResult.Value(_)
)
def encode(t: T): String = t.entryName
}
def extraCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): ExtraParamSupport[T] =
EnumExtraParamSupport(enumName, T)
sealed trait ADTWithoutDiscriminator
sealed trait ADTWithDiscriminator
sealed trait ADTWithDiscriminatorNoMapping
Expand Down Expand Up @@ -81,36 +98,42 @@ object TapirGeneratedEndpoints {
case object Foo extends AnEnum
case object Bar extends AnEnum
case object Baz extends AnEnum
implicit val enumCodecSupportAnEnum: ExtraParamSupport[AnEnum] =
extraCodecSupport[AnEnum]("AnEnum", AnEnum)
}



lazy val putAdtTest =
type PutAdtTestEndpoint = Endpoint[Unit, ADTWithoutDiscriminator, Unit, ADTWithoutDiscriminator, Any]
lazy val putAdtTest: PutAdtTestEndpoint =
endpoint
.put
.in(("adt" / "test"))
.in(jsonBody[ADTWithoutDiscriminator])
.out(jsonBody[ADTWithoutDiscriminator].description("successful operation"))

lazy val postAdtTest =
type PostAdtTestEndpoint = Endpoint[Unit, ADTWithDiscriminatorNoMapping, Unit, ADTWithDiscriminator, Any]
lazy val postAdtTest: PostAdtTestEndpoint =
endpoint
.post
.in(("adt" / "test"))
.in(jsonBody[ADTWithDiscriminatorNoMapping])
.out(jsonBody[ADTWithDiscriminator].description("successful operation"))

lazy val getOneofOptionTest =
type GetOneofOptionTestEndpoint = Endpoint[Unit, Unit, Unit, Option[AnEnum], Any]
lazy val getOneofOptionTest: GetOneofOptionTestEndpoint =
endpoint
.get
.in(("oneof" / "option" / "test"))
.out(oneOf[Option[AnEnum]](
oneOfVariantSingletonMatcher(sttp.model.StatusCode(204), emptyOutput.description("No response"))(None),
oneOfVariantValueMatcher(sttp.model.StatusCode(200), jsonBody[Option[AnEnum]].description("An enum")){ case Some(_: AnEnum) => true }))

lazy val postGenericJson =
type PostGenericJsonEndpoint = Endpoint[Unit, (Option[List[AnEnum]], Option[io.circe.Json]), Unit, io.circe.Json, Any]
lazy val postGenericJson: PostGenericJsonEndpoint =
endpoint
.post
.in(("generic" / "json"))
.in(query[Option[CommaSeparatedValues[AnEnum]]]("aTrickyParam").map(_.map(_.values))(_.map(CommaSeparatedValues(_))).description("A very thorough description"))
.in(jsonBody[Option[io.circe.Json]])
.out(jsonBody[io.circe.Json].description("anything back"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ lazy val root = (project in file("."))
.settings(
scalaVersion := "2.13.16",
version := "0.1",
openapiJsonSerdeLib := "jsoniter"
openapiJsonSerdeLib := "jsoniter",
openapiGenerateEndpointTypes := true
)

libraryDependencies ++= Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,28 @@ paths:
$ref: '#/components/schemas/AnEnum'
'/generic/json':
post:
parameters:
- in: query
name: aTrickyParam
style: form
explode: false
required: false
description: A very thorough description
schema:
type: array
items:
$ref: '#/components/schemas/AnEnum'
requestBody:
description: anything
content:
application/json:
schema: {}
schema: { }
responses:
"200":
description: anything back
content:
application/json:
schema: {}
schema: { }

components:
schemas:
Expand Down
Loading