Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gen] newtype aliases #3002

Merged
merged 8 commits into from
Aug 15, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[gen] newtype aliases: basic implementation
hochgi committed Aug 15, 2024
commit f11b05d5d71c03fb0675ae8447690d2d0aff0aa3
Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@ object EndpointGen {
}
val obj = Code.Object(
name = name.capitalize,
extensions = Nil,
schema = false,
endpoints = endpoints.toMap,
objects = Nil,
115 changes: 86 additions & 29 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@ import zio.Chunk
import zio.http.Method
import zio.http.endpoint.openapi.OpenAPI.ReferenceOr
import zio.http.endpoint.openapi.{JsonSchema, OpenAPI}
import zio.http.gen.scala.Code.Collection
import zio.http.gen.scala.Code.{CodecType, Collection, PathSegmentCode}
import zio.http.gen.scala.{Code, CodeGen}

object EndpointGen {
@@ -79,6 +79,7 @@ final case class EndpointGen(config: Config) {
var subtypeToTraits = Map.empty[String, Set[String]]
var caseClassToSharedFields = Map.empty[String, Set[String]]
var nameToSchemaAndAnnotations = Map.empty[String, (JsonSchema, Chunk[JsonSchema.MetaData])]
var aliasedPrimitives = Set.empty[String]

openAPI.components.toList.foreach { components =>
components.schemas.foreach { case (OpenAPI.Key(name), refOrSchema) =>
@@ -112,7 +113,11 @@ final case class EndpointGen(config: Config) {
case None => Some(refNames.toSet)
}
}
case _ => // do nothing
case _ =>
// primitives that are aliased should be registered for a special Newtype treatment
if (schema.isPrimitive) {
aliasedPrimitives += name
}
}

nameToSchemaAndAnnotations = nameToSchemaAndAnnotations.updated(name, schema -> annotations)
@@ -176,9 +181,9 @@ final case class EndpointGen(config: Config) {
// The `mapType` function is used to alter any relevant part of each code file
noDuplicateFiles.map { cf =>
cf.copy(
objects = cf.objects.map(mapType(_.name, subtypeToTraits)),
caseClasses = cf.caseClasses.map(mapType(_.name, subtypeToTraits)),
enums = cf.enums.map(mapType(_.name, subtypeToTraits)),
objects = cf.objects.map(mapType(_.name, subtypeToTraits, aliasedPrimitives)),
caseClasses = cf.caseClasses.map(mapType(_.name, subtypeToTraits, aliasedPrimitives)),
enums = cf.enums.map(mapType(_.name, subtypeToTraits, aliasedPrimitives)),
)
}
}
@@ -192,6 +197,9 @@ final case class EndpointGen(config: Config) {
* function that alters a case class, and lifts it such that we can apply it
* to any structure, and it'll take care to recurse when needed.
*
* Another issue we fix here, is when we have aliased primitives. In that case
* we need to append `.Type` to the aliased primitive type.
*
* @param getEncapsulatingName
* used to get the name of the code structure we operate on
* (Object/CaseClass/Enum)
@@ -203,7 +211,11 @@ final case class EndpointGen(config: Config) {
* @return
* the modified structure
*/
def mapType[T <: Code.ScalaType](getEncapsulatingName: T => String, subtypeToTraits: Map[String, Set[String]])(
def mapType[T <: Code.ScalaType](
getEncapsulatingName: T => String,
subtypeToTraits: Map[String, Set[String]],
aliasedPrimitives: Set[String],
)(
codeStructureToAlter: T,
): T =
mapCaseClasses { cc =>
@@ -212,13 +224,18 @@ final case class EndpointGen(config: Config) {
// We use the subtypeToTraits map to check if the type is a concrete subtype of a sealed trait.
// As of the time of writing this code, there should be only a single trait.
// In case future code generalizes to allow multiple mixins, this code should be updated.
subtypeToTraits.get(tName).fold(originalType) { set =>
// If the type parameter has exactly 1 super type trait,
// and that trait's name is different from our enclosing object's name,
// then we should alter the type to include the object's name.
if (set.size != 1 || set.head == getEncapsulatingName(codeStructureToAlter)) originalType
else Code.TypeRef(set.head + "." + tName)
}
subtypeToTraits
.get(tName)
.fold {
if (aliasedPrimitives(tName)) Code.TypeRef(tName + ".Type")
else originalType
} { set =>
// If the type parameter has exactly 1 super type trait,
// and that trait's name is different from our enclosing object's name,
// then we should alter the type to include the object's name.
if (set.size != 1 || set.head == getEncapsulatingName(codeStructureToAlter)) originalType
else Code.TypeRef(set.head + "." + tName)
}
}) :: tail
})
}(codeStructureToAlter)
@@ -329,7 +346,8 @@ final case class EndpointGen(config: Config) {
imports = (Code.Import.FromBase("component._") :: imports.flatten).distinct,
objects = List(
Code.Object(
className,
name = className,
extensions = Nil,
schema = false,
endpoints = endpoints.toMap,
objects = anonymousTypes.values.toList,
@@ -391,7 +409,8 @@ final case class EndpointGen(config: Config) {
)
anonymousTypes += method.toString ->
Code.Object(
method.toString,
name = method.toString,
extensions = Nil,
schema = false,
endpoints = Map.empty,
objects = code.objects,
@@ -429,7 +448,8 @@ final case class EndpointGen(config: Config) {
throw new Exception(s"Could not generate code for request body $schema"),
)
val obj = Code.Object(
method.toString,
name = method.toString,
extensions = Nil,
schema = false,
endpoints = Map.empty,
objects = code.objects,
@@ -476,7 +496,8 @@ final case class EndpointGen(config: Config) {
throw new Exception(s"Could not generate code for request body $schema"),
)
val obj = Code.Object(
method.toString,
name = method.toString,
extensions = Nil,
schema = false,
endpoints = Map.empty,
objects = code.objects,
@@ -524,7 +545,15 @@ final case class EndpointGen(config: Config) {
case Some(OpenAPI.ReferenceOr.Or(schema: JsonSchema)) =>
schemaToPathCodec(schema, openAPI, param.name)
case Some(OpenAPI.ReferenceOr.Reference(ref, _, _)) =>
schemaToPathCodec(resolveSchemaRef(openAPI, ref), openAPI, param.name)
val (baref, mutateCodec): (String, PathSegmentCode => PathSegmentCode) =
if (ref.startsWith("#/components/schemas/")) {
val baref = ref.replaceFirst("^#/components/schemas/", "")
baref -> ((psc: PathSegmentCode) => psc.copy(segmentType = CodecType.Aliased(psc.segmentType, baref)))
} else {
ref -> identity[PathSegmentCode]
}
val r = schemaToPathCodec(resolveSchemaRef(openAPI, baref), openAPI, param.name)
mutateCodec(r)
case None =>
// Not sure if open api allows path parameters without schema.
// But string seems a good default
@@ -596,6 +625,12 @@ final case class EndpointGen(config: Config) {
throw new Exception(s"Found reference to schema $key, but no components section found.")
}

def schemaToType(openAPI: OpenAPI, name: String, schema: JsonSchema): Option[(List[Code.Import], String)] =
Option.when(schema.isPrimitive) {
val field = schemaToField(schema, openAPI, name, Chunk.empty).get // .get is safe, always defined for primitives
CodeGen.render("")(field.fieldType)
}

@tailrec
private def schemaToPathCodec(schema: JsonSchema, openAPI: OpenAPI, name: String): Code.PathSegmentCode = {
schema match {
@@ -674,6 +709,28 @@ final case class EndpointGen(config: Config) {
if (obj.required.contains(name)) field else field.copy(fieldType = field.fieldType.opt)
}.toList

def aliasedSchemaToCode(openAPI: OpenAPI, name: String, wrapped: JsonSchema): Option[Code.File] =
schemaToType(openAPI, name, wrapped).map { case (imports, wrappedType) =>
Code.File(
List("component", name.capitalize + ".scala"),
pkgPath = List("component"),
imports = Code.Import("zio.prelude.Newtype") :: imports,
objects = List(
Code.Object(
name = name,
extensions = List(s"Newtype[${wrappedType}]"),
schema = Some("derive"),
endpoints = Map.empty,
objects = Nil,
caseClasses = Nil,
enums = Nil,
),
),
caseClasses = Nil,
enums = Nil,
)
}

def schemaToCode(
schema: JsonSchema,
openAPI: OpenAPI,
@@ -720,12 +777,12 @@ final case class EndpointGen(config: Config) {
schemaToCode(schema, openAPI, schemaName, annotations)

case JsonSchema.RefSchema(ref) => throw new Exception(s"Unexpected reference schema: $ref")
case JsonSchema.Integer(_, _, _, _, _, _) => None
case JsonSchema.String(_, _, _, _) => None
case JsonSchema.Boolean => None
case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
case JsonSchema.Integer(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema)
case JsonSchema.String(_, _, _, _) => aliasedSchemaToCode(openAPI, name, schema)
case JsonSchema.Boolean => aliasedSchemaToCode(openAPI, name, schema)
case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
throw new Exception("OneOf schemas with primitive types are not supported")
case JsonSchema.OneOfSchema(schemas) =>
case JsonSchema.OneOfSchema(schemas) =>
val discriminatorInfo =
annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator }
val discriminator: Option[String] = discriminatorInfo.map(_.propertyName)
@@ -785,7 +842,7 @@ final case class EndpointGen(config: Config) {
),
),
)
case JsonSchema.AllOfSchema(schemas) =>
case JsonSchema.AllOfSchema(schemas) =>
val genericFieldIndex = Iterator.from(0)
val unvalidatedFields = schemas.map(_.withoutAnnotations).flatMap {
case schema @ JsonSchema.Object(_, _, _) =>
@@ -830,9 +887,9 @@ final case class EndpointGen(config: Config) {
enums = Nil,
),
)
case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
throw new Exception("AnyOf schemas with primitive types are not supported")
case JsonSchema.AnyOfSchema(schemas) =>
case JsonSchema.AnyOfSchema(schemas) =>
val discriminatorInfo =
annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator }
val discriminator: Option[String] = discriminatorInfo.map(_.propertyName)
@@ -889,9 +946,9 @@ final case class EndpointGen(config: Config) {
),
),
)
case JsonSchema.Number(_, _, _, _, _, _) => None
case JsonSchema.ArrayType(None, _, _) => None
case JsonSchema.ArrayType(Some(schema), _, _) =>
case JsonSchema.Number(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema)
case JsonSchema.ArrayType(None, _, _) => None
case JsonSchema.ArrayType(Some(schema), _, _) =>
schemaToCode(schema, openAPI, name, annotations)
case JsonSchema.Object(properties, additionalProperties, _)
if properties.nonEmpty && additionalProperties.isRight =>
52 changes: 43 additions & 9 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala
Original file line number Diff line number Diff line change
@@ -48,20 +48,53 @@ object Code {
final case class FromBase(path: String) extends Import
}

/**
* @param schema
* \- value = "derive with" syntax, e.g. "DeriveSchema.gen" or just "derive"
*/
final case class Object(
name: String,
schema: Boolean,
extensions: List[String],
schema: Option[String],
endpoints: Map[Field, EndpointCode],
objects: List[Object],
caseClasses: List[CaseClass],
enums: List[Enum],
) extends ScalaType

object Object {
def schemaCompanion(str: String): Object = Object(str, schema = true, Map.empty, Nil, Nil, Nil)

def apply(
name: String,
extensions: List[String],
schema: Boolean,
endpoints: Map[Field, EndpointCode],
objects: List[Object],
caseClasses: List[CaseClass],
enums: List[Enum],
): Object =
Object(name, extensions, if (schema) Some("DeriveSchema.gen") else None, endpoints, objects, caseClasses, enums)

def schemaCompanion(str: String): Object = Object(
name = str,
extensions = Nil,
schema = true,
endpoints = Map.empty,
objects = Nil,
caseClasses = Nil,
enums = Nil,
)

def apply(name: String, endpoints: Map[Field, EndpointCode]): Object =
Object(name, schema = false, endpoints, Nil, Nil, Nil)
Object(
name = name,
extensions = Nil,
schema = false,
endpoints = endpoints,
objects = Nil,
caseClasses = Nil,
enums = Nil,
)
}

final case class CaseClass(name: String, fields: List[Field], companionObject: Option[Object], mixins: List[String])
@@ -151,12 +184,13 @@ object Code {
}
sealed trait CodecType
object CodecType {
case object Boolean extends CodecType
case object Int extends CodecType
case object Literal extends CodecType
case object Long extends CodecType
case object String extends CodecType
case object UUID extends CodecType
case object Boolean extends CodecType
case object Int extends CodecType
case object Literal extends CodecType
case object Long extends CodecType
case object String extends CodecType
case object UUID extends CodecType
case class Aliased(underlying: CodecType, newtypeName: String) extends CodecType
}
final case class QueryParamCode(name: String, queryType: CodecType)
final case class HeadersCode(headers: List[HeaderCode])
Loading