Skip to content

Commit

Permalink
Simplify schema based header codecs (zio#3232)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Feb 1, 2025
1 parent 6d96d16 commit b045873
Show file tree
Hide file tree
Showing 39 changed files with 1,086 additions and 1,411 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ jobs:
with:
apps: sbt

- uses: coursier/setup-action@v1
with:
apps: sbt

- name: Release
env:
PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
Expand Down
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version = 3.8.1
version = 3.8.6
maxColumn = 120

align.preset = more
Expand Down
4 changes: 2 additions & 2 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ object Dependencies {
val `jwt-core` = "com.github.jwt-scala" %% "jwt-core" % JwtCoreVersion
val `scala-compact-collection` = "org.scala-lang.modules" %% "scala-collection-compat" % ScalaCompactCollectionVersion

val scalafmt = "org.scalameta" %% "scalafmt-dynamic" % "3.8.1"
val scalametaParsers = "org.scalameta" %% "parsers" % "4.9.9"
val scalafmt = "org.scalameta" %% "scalafmt-dynamic" % "3.8.6"
val scalametaParsers = "org.scalameta" %% "parsers" % "4.12.7"

val netty =
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,9 @@ private[cli] object CliEndpoint {
}
CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List())

case HttpCodec.Header(headerType, _) =>
CliEndpoint(headers = HttpOptions.Header(headerType.name, TextCodec.string) :: List())
case HttpCodec.HeaderCustom(codec, _) =>
CliEndpoint(headers = HttpOptions.Header(codec.name.get, TextCodec.string) :: List())
case HttpCodec.Method(codec, _) =>
case HttpCodec.Header(headerType, _) =>
CliEndpoint(headers = HttpOptions.Header(headerType.names.head, TextCodec.string) :: List())
case HttpCodec.Method(codec, _) =>
codec.asInstanceOf[SimpleCodec[_, _]] match {
case SimpleCodec.Specified(method: Method) =>
CliEndpoint(methods = method)
Expand All @@ -126,14 +124,9 @@ private[cli] object CliEndpoint {
CliEndpoint(url = HttpOptions.Path(pathCodec) :: List())

case HttpCodec.Query(codec, _) =>
if (codec.isPrimitive)
CliEndpoint(url = HttpOptions.Query(codec) :: List())
else if (codec.isRecord)
CliEndpoint(url = codec.recordFields.map { case (_, codec) =>
HttpOptions.Query(codec)
}.toList)
else
CliEndpoint(url = HttpOptions.Query(codec) :: List())
CliEndpoint(url = codec.recordFields.map { case (f, codec) =>
HttpOptions.Query(codec, f.fieldName)
}.toList)
case HttpCodec.Status(_, _) => CliEndpoint.empty

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import zio.schema._
import zio.schema.annotation.description

import zio.http._
import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.internal.StringSchemaCodec
import zio.http.internal.StringSchemaCodec.PrimitiveCodec

/*
* HttpOptions is a wrapper of a transformation Options[CliRequest] => Options[CliRequest].
Expand Down Expand Up @@ -265,11 +266,10 @@ private[cli] object HttpOptions {

}

final case class Query(codec: SchemaCodec[_], doc: Doc = Doc.empty) extends URLOptions {
final case class Query(codec: PrimitiveCodec[_], name: String, doc: Doc = Doc.empty) extends URLOptions {
self =>
override val name = codec.name.get
override val tag = "?" + name
def options: Options[_] = optionsFromSchema(codec)(name)
def options: Options[_] = optionsFromSchema(codec.schema)(name)

override def ??(doc: Doc): Query = self.copy(doc = self.doc + doc)

Expand All @@ -293,8 +293,8 @@ private[cli] object HttpOptions {

}

private[cli] def optionsFromSchema[A](codec: SchemaCodec[A]): String => Options[A] =
codec.schema match {
private[cli] def optionsFromSchema[A](schema: Schema[A]): String => Options[A] =
schema match {
case Schema.Primitive(standardType, _) =>
standardType match {
case StandardType.UnitType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ object CommandGen {
case _: HttpOptions.Constant => false
case _ => true
}.map {
case HttpOptions.Path(pathCodec, _) =>
case HttpOptions.Path(pathCodec, _) =>
pathCodec.segments.toList.flatMap { segment =>
getSegment(segment) match {
case (_, "") => Nil
case (name, "boolean") => s"[${getName(name, "")}]" :: Nil
case (name, codec) => s"${getName(name, "")} $codec" :: Nil
}
}
case HttpOptions.Query(codec, _) if codec.isPrimitive =>
case HttpOptions.Query(codec, name, _) =>
getType(codec.schema) match {
case "" => s"[${getName(codec.name.get, "")}]" :: Nil
case tpy => s"${getName(codec.name.get, "")} $tpy" :: Nil
case "" => s"[${getName(name, "")}]" :: Nil
case tpy => s"${getName(name, "")} $tpy" :: Nil
}
case _ => Nil
case _ => Nil
}.foldRight(List[String]())(_ ++ _)

val headersOptions = cliEndpoint.headers.filter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import zio.test._

import zio.schema.Schema

import zio.http.Header.HeaderType
import zio.http._
import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.endpoint._
import zio.http.endpoint.cli.AuxGen._
Expand Down Expand Up @@ -103,10 +101,10 @@ object EndpointGen {
lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] =
Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) =>
val schema = schema0.asInstanceOf[Schema[Any]]
val codec = SchemaCodec(Some(name), schema)
val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]]
CliRepr(
HttpCodec.Query(codec),
CliEndpoint(url = HttpOptions.Query(codec) :: Nil),
codec,
CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import zio.test.Gen
import zio.schema.Schema

import zio.http._
import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.endpoint.cli.AuxGen._
import zio.http.endpoint.cli.CliRepr._
import zio.http.internal.StringSchemaCodec.PrimitiveCodec

/**
* Constructs a Gen[Options[CliRequest], CliEndpoint]
Expand All @@ -33,10 +33,10 @@ object OptionsGen {
.optionsFromTextCodec(textCodec)(name)
.map(value => textCodec.encode(value))

def encodeOptions[A](name: String, codec: SchemaCodec[A]): Options[String] =
def encodeOptions[A](name: String, codec: PrimitiveCodec[A], schema: Schema[A]): Options[String] =
HttpOptions
.optionsFromSchema(codec)(name)
.map(value => codec.stringCodec.encode(value))
.optionsFromSchema(schema)(name)
.map(value => codec.encode(value))

lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] =
Gen
Expand Down Expand Up @@ -80,10 +80,10 @@ object OptionsGen {
.alphaNumericStringBounded(1, 30)
.zip(anyStandardType)
.map { case (name, schema) =>
val codec = SchemaCodec(Some(name), schema)
val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]]
CliRepr(
encodeOptions(name, codec),
CliEndpoint(url = HttpOptions.Query(codec) :: Nil),
encodeOptions(name, codec.codec.recordFields.head._2, schema.asInstanceOf[Schema[Any]]),
CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil),
)
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ object WebSocketReconnectingClient extends ZIOAppDefault {
channel.send(ChannelEvent.Read(WebSocketFrame.text("foo")))

// On receiving "foo", we'll reply with another "foo" to keep echo loop going
case Read(WebSocketFrame.Text("foo")) =>
case Read(WebSocketFrame.Text("foo")) =>
ZIO.logInfo("Received foo message.") *>
ZIO.sleep(1.second) *>
channel.send(ChannelEvent.Read(WebSocketFrame.text("foo")))

// Handle exception and convert it to failure to signal the shutdown of the socket connection via the promise
case ExceptionCaught(t) =>
case ExceptionCaught(t) =>
ZIO.fail(t)

case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ object WebSocketServerAdvanced extends ZIOAppDefault {
val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("end")) =>
case Read(WebSocketFrame.Text("end")) =>
channel.shutdown

// Send a "bar" if the client sends a "foo"
case Read(WebSocketFrame.Text("foo")) =>
case Read(WebSocketFrame.Text("foo")) =>
channel.send(Read(WebSocketFrame.text("bar")))

// Send a "foo" if the client sends a "bar"
case Read(WebSocketFrame.Text("bar")) =>
case Read(WebSocketFrame.Text("bar")) =>
channel.send(Read(WebSocketFrame.text("foo")))

// Echo the same message 10 times if it's not "foo" or "bar"
case Read(WebSocketFrame.Text(text)) =>
case Read(WebSocketFrame.Text(text)) =>
channel
.send(Read(WebSocketFrame.text(s"echo $text")))
.repeatN(10)
Expand All @@ -38,11 +38,11 @@ object WebSocketServerAdvanced extends ZIOAppDefault {
channel.send(Read(WebSocketFrame.text("Greetings!")))

// Log when the channel is getting closed
case Read(WebSocketFrame.Close(status, reason)) =>
case Read(WebSocketFrame.Close(status, reason)) =>
Console.printLine("Closing channel with status: " + status + " and reason: " + reason)

// Print the exception if it's not a normal close
case ExceptionCaught(cause) =>
case ExceptionCaught(cause) =>
Console.printLine(s"Channel error!: ${cause.getMessage}")

case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ object WebSocketSimpleClient extends ZIOAppDefault {
channel.send(Read(WebSocketFrame.text("foo")))

// Send a "bar" if the server sends a "foo"
case Read(WebSocketFrame.Text("foo")) =>
case Read(WebSocketFrame.Text("foo")) =>
channel.send(Read(WebSocketFrame.text("bar")))

// Close the connection if the server sends a "bar"
case Read(WebSocketFrame.Text("bar")) =>
case Read(WebSocketFrame.Text("bar")) =>
ZIO.succeed(println("Goodbye!")) *> channel.send(Read(WebSocketFrame.close(1000)))

case _ =>
Expand Down
22 changes: 11 additions & 11 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,9 @@ final case class EndpointGen(config: Config) {
* `transform` that simply `wrap` / `unwrap` the provided value.
*/
case JsonSchema.Boolean => aliasedSchemaToCode(openAPI, name, schema)
case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
case JsonSchema.OneOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
throw new Exception("OneOf schemas with primitive types are not supported")
case JsonSchema.OneOfSchema(schemas) =>
case JsonSchema.OneOfSchema(schemas) =>
val discriminatorInfo =
annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator }
val discriminator: Option[String] = discriminatorInfo.map(_.propertyName)
Expand Down Expand Up @@ -885,7 +885,7 @@ final case class EndpointGen(config: Config) {
),
),
)
case JsonSchema.AllOfSchema(schemas) =>
case JsonSchema.AllOfSchema(schemas) =>
val genericFieldIndex = Iterator.from(0)
val unvalidatedFields = schemas.toList.map(_.withoutAnnotations).flatMap {
case schema @ JsonSchema.Object(_, _, _) =>
Expand Down Expand Up @@ -928,9 +928,9 @@ final case class EndpointGen(config: Config) {
enums = Nil,
),
)
case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
case JsonSchema.AnyOfSchema(schemas) if schemas.exists(_.isPrimitive) =>
throw new Exception("AnyOf schemas with primitive types are not supported")
case JsonSchema.AnyOfSchema(schemas) =>
case JsonSchema.AnyOfSchema(schemas) =>
val discriminatorInfo =
annotations.collectFirst { case JsonSchema.MetaData.Discriminator(discriminator) => discriminator }
val discriminator: Option[String] = discriminatorInfo.map(_.propertyName)
Expand Down Expand Up @@ -987,12 +987,12 @@ final case class EndpointGen(config: Config) {
),
),
)
case JsonSchema.Number(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema)
case JsonSchema.Number(_, _, _, _, _, _) => aliasedSchemaToCode(openAPI, name, schema)
// should we provide support for (Newtype) aliasing arrays of primitives?
case JsonSchema.ArrayType(None, _, _) => None
case JsonSchema.ArrayType(Some(schema), _, _) =>
case JsonSchema.ArrayType(None, _, _) => None
case JsonSchema.ArrayType(Some(schema), _, _) =>
schemaToCode(schema, openAPI, name, annotations)
case obj: JsonSchema.Object if obj.isInvalid =>
case obj: JsonSchema.Object if obj.isInvalid =>
throw new Exception("Object with properties and additionalProperties is not supported")
case obj @ JsonSchema.Object(properties, _, _) if obj.isClosedDictionary =>
val unvalidatedFields = fieldsOfObject(openAPI, annotations)(obj)
Expand Down Expand Up @@ -1052,8 +1052,8 @@ final case class EndpointGen(config: Config) {
),
),
)
case JsonSchema.Null => throw new Exception("Null query parameters are not supported")
case JsonSchema.AnyJson => None
case JsonSchema.Null => throw new Exception("Null query parameters are not supported")
case JsonSchema.AnyJson => None
}
}

Expand Down
14 changes: 8 additions & 6 deletions zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ import zio.http.gen.model._
import zio.http.gen.openapi.Config.NormalizeFields
import zio.http.gen.openapi.{Config, EndpointGen}

// format: off
@nowarn("msg=missing interpolator")
object CodeGenSpec extends ZIOSpecDefault {

case class ValidatedData(
@validate(Validation.maxLength(10))
name: String,
@validate(Validation.greaterThan(0) && Validation.lessThan(100))
age: Int,
@validate(Validation.maxLength(10)) name: String,
@validate(Validation.greaterThan(0) && Validation.lessThan(100)) age: Int,
)
implicit val validatedDataSchema: Schema[ValidatedData] = DeriveSchema.gen[ValidatedData]
implicit val validatedDataSchema: Schema[ValidatedData] =
DeriveSchema.gen[ValidatedData]

private def fileShouldBe(dir: java.nio.file.Path, subPath: String, expectedFile: String): TestResult = {
val filePath = dir.resolve(Paths.get(subPath))
Expand Down Expand Up @@ -156,7 +156,8 @@ object CodeGenSpec extends ZIOSpecDefault {
.header(HeaderCodec.accept)
.header(HeaderCodec.contentType)
.header(HeaderCodec.headerAs[String]("Token"))
val openAPI = OpenAPIGen.fromEndpoints(endpoint)

val openAPI = OpenAPIGen.fromEndpoints(endpoint)

codeGenFromOpenAPI(openAPI) { testDir =>
fileShouldBe(testDir, "api/v1/Users.scala", "/EndpointWithHeaders.scala")
Expand Down Expand Up @@ -605,6 +606,7 @@ object CodeGenSpec extends ZIOSpecDefault {
}
}
} @@ TestAspect.exceptScala3, // for some reason, the temp dir is empty in Scala 3
//format: on
test("Endpoint with array field in input") {
val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User]
val openAPI = OpenAPIGen.fromEndpoints("", "", endpoint)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package zio.http.netty.model

import scala.collection.AbstractIterator

import zio.Chunk

import zio.http.Server.Config.CompressionOptions
import zio.http._

Expand Down Expand Up @@ -58,10 +60,10 @@ private[netty] object Conversions {

def headersToNetty(headers: Headers): HttpHeaders =
headers match {
case Headers.FromIterable(_) => encodeHeaderListToNetty(headers)
case Headers.Native(value, _, _, _) => value.asInstanceOf[HttpHeaders]
case Headers.Concat(_, _) => encodeHeaderListToNetty(headers)
case Headers.Empty => new DefaultHttpHeaders()
case Headers.FromIterable(_) => encodeHeaderListToNetty(headers)
case Headers.Native(value, _, _, _, _) => value.asInstanceOf[HttpHeaders]
case Headers.Concat(_, _) => encodeHeaderListToNetty(headers)
case Headers.Empty => new DefaultHttpHeaders()
}

def urlToNetty(url: URL): String = {
Expand Down Expand Up @@ -89,6 +91,7 @@ private[netty] object Conversions {
(headers: HttpHeaders) => nettyHeadersIterator(headers),
// NOTE: Netty's headers.get is case-insensitive
(headers: HttpHeaders, key: CharSequence) => headers.get(key),
(headers: HttpHeaders, key: CharSequence) => Chunk.fromJavaIterable(headers.getAll(key)),
(headers: HttpHeaders, key: CharSequence) => headers.contains(key),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ object LogAnnotationMiddlewareSpec extends ZIOSpecDefault {
handler(ZIO.logWarning("Oh!") *> ZIO.succeed(Response.text("Hey logging!"))),
)
.@@(
Middleware.logAnnotate(req =>
Set(LogAnnotation("method", req.method.name), LogAnnotation("path", req.path.encode)),
),
Middleware
.logAnnotate(req => Set(LogAnnotation("method", req.method.name), LogAnnotation("path", req.path.encode))),
)
.runZIO(Request.get("/"))

Expand Down
Loading

0 comments on commit b045873

Please sign in to comment.