Skip to content

Commit

Permalink
Simplify schema based header codecs (zio#3232)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jan 31, 2025
1 parent 6d96d16 commit be67f1f
Show file tree
Hide file tree
Showing 24 changed files with 1,019 additions and 1,323 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ jobs:
with:
apps: sbt

- uses: coursier/setup-action@v1
with:
apps: sbt

- name: Release
env:
PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,9 @@ private[cli] object CliEndpoint {
}
CliEndpoint(body = HttpOptions.Body(name, codec.defaultMediaType, codec.defaultSchema) :: List())

case HttpCodec.Header(headerType, _) =>
CliEndpoint(headers = HttpOptions.Header(headerType.name, TextCodec.string) :: List())
case HttpCodec.HeaderCustom(codec, _) =>
CliEndpoint(headers = HttpOptions.Header(codec.name.get, TextCodec.string) :: List())
case HttpCodec.Method(codec, _) =>
case HttpCodec.Header(headerType, _) =>
CliEndpoint(headers = HttpOptions.Header(headerType.names.head, TextCodec.string) :: List())
case HttpCodec.Method(codec, _) =>
codec.asInstanceOf[SimpleCodec[_, _]] match {
case SimpleCodec.Specified(method: Method) =>
CliEndpoint(methods = method)
Expand All @@ -126,14 +124,9 @@ private[cli] object CliEndpoint {
CliEndpoint(url = HttpOptions.Path(pathCodec) :: List())

case HttpCodec.Query(codec, _) =>
if (codec.isPrimitive)
CliEndpoint(url = HttpOptions.Query(codec) :: List())
else if (codec.isRecord)
CliEndpoint(url = codec.recordFields.map { case (_, codec) =>
HttpOptions.Query(codec)
}.toList)
else
CliEndpoint(url = HttpOptions.Query(codec) :: List())
CliEndpoint(url = codec.recordFields.map { case (f, codec) =>
HttpOptions.Query(codec, f.fieldName)
}.toList)
case HttpCodec.Status(_, _) => CliEndpoint.empty

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import zio.schema._
import zio.schema.annotation.description

import zio.http._
import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.internal.StringSchemaCodec
import zio.http.internal.StringSchemaCodec.PrimitiveCodec

/*
* HttpOptions is a wrapper of a transformation Options[CliRequest] => Options[CliRequest].
Expand Down Expand Up @@ -265,11 +266,10 @@ private[cli] object HttpOptions {

}

final case class Query(codec: SchemaCodec[_], doc: Doc = Doc.empty) extends URLOptions {
final case class Query(codec: PrimitiveCodec[_], name: String, doc: Doc = Doc.empty) extends URLOptions {
self =>
override val name = codec.name.get
override val tag = "?" + name
def options: Options[_] = optionsFromSchema(codec)(name)
def options: Options[_] = optionsFromSchema(codec.schema)(name)

override def ??(doc: Doc): Query = self.copy(doc = self.doc + doc)

Expand All @@ -293,8 +293,8 @@ private[cli] object HttpOptions {

}

private[cli] def optionsFromSchema[A](codec: SchemaCodec[A]): String => Options[A] =
codec.schema match {
private[cli] def optionsFromSchema[A](schema: Schema[A]): String => Options[A] =
schema match {
case Schema.Primitive(standardType, _) =>
standardType match {
case StandardType.UnitType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ object CommandGen {
case _: HttpOptions.Constant => false
case _ => true
}.map {
case HttpOptions.Path(pathCodec, _) =>
case HttpOptions.Path(pathCodec, _) =>
pathCodec.segments.toList.flatMap { segment =>
getSegment(segment) match {
case (_, "") => Nil
case (name, "boolean") => s"[${getName(name, "")}]" :: Nil
case (name, codec) => s"${getName(name, "")} $codec" :: Nil
}
}
case HttpOptions.Query(codec, _) if codec.isPrimitive =>
case HttpOptions.Query(codec, name, _) =>
getType(codec.schema) match {
case "" => s"[${getName(codec.name.get, "")}]" :: Nil
case tpy => s"${getName(codec.name.get, "")} $tpy" :: Nil
case "" => s"[${getName(name, "")}]" :: Nil
case tpy => s"${getName(name, "")} $tpy" :: Nil
}
case _ => Nil
case _ => Nil
}.foldRight(List[String]())(_ ++ _)

val headersOptions = cliEndpoint.headers.filter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ import zio.test._

import zio.schema.Schema

import zio.http.Header.HeaderType
import zio.http._
import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.endpoint._
import zio.http.endpoint.cli.AuxGen._
Expand Down Expand Up @@ -103,10 +101,10 @@ object EndpointGen {
lazy val anyQuery: Gen[Any, CliReprOf[Codec[_]]] =
Gen.alphaNumericStringBounded(1, 30).zip(anyStandardType).map { case (name, schema0) =>
val schema = schema0.asInstanceOf[Schema[Any]]
val codec = SchemaCodec(Some(name), schema)
val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]]
CliRepr(
HttpCodec.Query(codec),
CliEndpoint(url = HttpOptions.Query(codec) :: Nil),
codec,
CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import zio.test.Gen
import zio.schema.Schema

import zio.http._
import zio.http.codec.HttpCodec.SchemaCodec
import zio.http.codec._
import zio.http.endpoint.cli.AuxGen._
import zio.http.endpoint.cli.CliRepr._
import zio.http.internal.StringSchemaCodec.PrimitiveCodec

/**
* Constructs a Gen[Options[CliRequest], CliEndpoint]
Expand All @@ -33,10 +33,10 @@ object OptionsGen {
.optionsFromTextCodec(textCodec)(name)
.map(value => textCodec.encode(value))

def encodeOptions[A](name: String, codec: SchemaCodec[A]): Options[String] =
def encodeOptions[A](name: String, codec: PrimitiveCodec[A], schema: Schema[A]): Options[String] =
HttpOptions
.optionsFromSchema(codec)(name)
.map(value => codec.stringCodec.encode(value))
.optionsFromSchema(schema)(name)
.map(value => codec.encode(value))

lazy val anyBodyOption: Gen[Any, CliReprOf[Options[Retriever]]] =
Gen
Expand Down Expand Up @@ -80,10 +80,10 @@ object OptionsGen {
.alphaNumericStringBounded(1, 30)
.zip(anyStandardType)
.map { case (name, schema) =>
val codec = SchemaCodec(Some(name), schema)
val codec = QueryCodec.query(name)(schema).asInstanceOf[HttpCodec.Query[Any]]
CliRepr(
encodeOptions(name, codec),
CliEndpoint(url = HttpOptions.Query(codec) :: Nil),
encodeOptions(name, codec.codec.recordFields.head._2, schema.asInstanceOf[Schema[Any]]),
CliEndpoint(url = HttpOptions.Query(codec.codec.recordFields.head._2, name) :: Nil),
)
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package zio.http.netty.model

import scala.collection.AbstractIterator

import zio.Chunk

import zio.http.Server.Config.CompressionOptions
import zio.http._

Expand Down Expand Up @@ -58,10 +60,10 @@ private[netty] object Conversions {

def headersToNetty(headers: Headers): HttpHeaders =
headers match {
case Headers.FromIterable(_) => encodeHeaderListToNetty(headers)
case Headers.Native(value, _, _, _) => value.asInstanceOf[HttpHeaders]
case Headers.Concat(_, _) => encodeHeaderListToNetty(headers)
case Headers.Empty => new DefaultHttpHeaders()
case Headers.FromIterable(_) => encodeHeaderListToNetty(headers)
case Headers.Native(value, _, _, _, _) => value.asInstanceOf[HttpHeaders]
case Headers.Concat(_, _) => encodeHeaderListToNetty(headers)
case Headers.Empty => new DefaultHttpHeaders()
}

def urlToNetty(url: URL): String = {
Expand Down Expand Up @@ -89,6 +91,7 @@ private[netty] object Conversions {
(headers: HttpHeaders) => nettyHeadersIterator(headers),
// NOTE: Netty's headers.get is case-insensitive
(headers: HttpHeaders, key: CharSequence) => headers.get(key),
(headers: HttpHeaders, key: CharSequence) => Chunk.fromJavaIterable(headers.getAll(key)),
(headers: HttpHeaders, key: CharSequence) => headers.contains(key),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package zio.http.endpoint

import java.time.Instant

import scala.math.BigDecimal.javaBigDecimal2bigDecimal

import zio._
import zio.test._

Expand Down
145 changes: 141 additions & 4 deletions zio-http/shared/src/main/scala/zio/http/Header.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ import scala.util.{Either, Failure, Success, Try}
import zio.Config.Secret
import zio._

import zio.http.codec.RichTextCodec
import zio.http.internal.DateEncoding
import zio.schema.Schema
import zio.schema.codec.DecodeError.ReadError

import zio.http.Header.HeaderTypeBase.Typed
import zio.http.codec.{HttpCodecError, RichTextCodec}
import zio.http.internal.{DateEncoding, ErrorConstructor, StringSchemaCodec}

sealed trait Header {
type Self <: Header
Expand All @@ -50,21 +54,154 @@ sealed trait Header {

object Header {

sealed trait HeaderType {
sealed trait HeaderTypeBase {
type HeaderValue

def names: Chunk[String]

def fromHeaders(headers: Headers): Either[String, HeaderValue]

private[http] def fromHeadersUnsafe(headers: Headers): HeaderValue

def toHeaders(value: HeaderValue): Headers =
value match {
case h: Header => Headers.fromIterable(h :: Nil)
case _ => Headers.empty
}
}

object HeaderTypeBase {
type Typed[HV] = HeaderTypeBase { type HeaderValue = HV }
}

sealed trait SchemaHeaderType extends HeaderTypeBase {
def schema: Schema[HeaderValue]

def optional: HeaderTypeBase.Typed[Option[HeaderValue]]
}

object SchemaHeaderType {
type Typed[H] = SchemaHeaderType { type HeaderValue = H }

def apply[H](implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = {
new SchemaHeaderType {
type HeaderValue = H
val schema: Schema[H] = schema0
val codec: StringSchemaCodec[H, Headers] = StringSchemaCodec.fromSchema(
schema,
(h: Headers, k: String, v: String) => h.addHeader(k, v),
(h: Headers, kvs: Iterable[(String, String)]) => h.addHeaders(kvs),
(h: Headers, k: String) => h.contains(k),
(h: Headers, k: String) => h.getUnsafe(k),
(h: Headers, k: String) => h.rawHeaders(k),
(h: Headers, k: String) => h.rawHeaders(k).size,
ErrorConstructor(
param => HttpCodecError.MissingHeader(param),
params => HttpCodecError.MissingHeaders(params),
validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors),
(param, value) => HttpCodecError.DecodingErrorHeader(param, value),
(param, expected, actual) => HttpCodecError.InvalidHeaderCount(param, expected, actual),
),
isKebabCase = true,
null,
)

override def names: Chunk[String] =
codec.recordFields.map(_._1.fieldName)

override def optional: SchemaHeaderType.Typed[Option[H]] =
apply(schema.optional)

override def fromHeaders(headers: Headers): Either[String, H] =
try Right(codec.decode(headers))
catch {
case NonFatal(e) => Left(e.getMessage)
}

private[http] override def fromHeadersUnsafe(headers: Headers): H =
codec.decode(headers)

override def toHeaders(value: H): Headers =
codec.encode(value, Headers.empty)
}
}

def apply[H](name: String)(implicit schema0: Schema[H]): SchemaHeaderType.Typed[H] = {
new SchemaHeaderType {
type HeaderValue = H
val schema: Schema[H] = schema0
val codec: StringSchemaCodec[H, Headers] = StringSchemaCodec.fromSchema(
schema,
(h: Headers, k: String, v: String) => h.addHeader(k, v),
(h: Headers, kvs: Iterable[(String, String)]) => h.addHeaders(kvs),
(h: Headers, k: String) => h.contains(k),
(h: Headers, k: String) => h.getUnsafe(k),
(h: Headers, k: String) => h.rawHeaders(k),
(h: Headers, k: String) => h.rawHeaders(k).size,
ErrorConstructor(
header => HttpCodecError.MissingHeader(header),
headers => HttpCodecError.MissingHeaders(headers),
validationErrors => HttpCodecError.InvalidEntity.wrap(validationErrors),
(header, value) => HttpCodecError.DecodingErrorHeader(header, value),
(header, expected, actual) => HttpCodecError.InvalidHeaderCount(header, expected, actual),
),
isKebabCase = true,
name,
)

override def names: Chunk[String] =
codec.recordFields.map(_._1.fieldName)

override def optional: SchemaHeaderType.Typed[Option[H]] =
apply(name)(schema.optional)

override def fromHeaders(headers: Headers): Either[String, H] =
try Right(codec.decode(headers))
catch {
case NonFatal(e) => Left(e.getMessage)
}

private[http] override def fromHeadersUnsafe(headers: Headers): H =
codec.decode(headers)

override def toHeaders(value: H): Headers =
codec.encode(value, Headers.empty)
}
}
}

sealed trait HeaderType extends HeaderTypeBase {
type HeaderValue <: Header

def names: Chunk[String] = Chunk.single(name)

def name: String

def parse(value: String): Either[String, HeaderValue]

def render(value: HeaderValue): String

def fromHeaders(headers: Headers): Either[String, HeaderValue] =
headers.getUnsafe(name) match {
case null => Left(s"Header $name not found")
case value => parse(value)
}

def fromHeadersUnsafe(headers: Headers): HeaderValue =
fromHeaders(headers).fold(
e => throw HttpCodecError.DecodingErrorHeader(name, ReadError(Cause.empty, e)),
identity,
)

override def toHeaders(value: HeaderValue): Headers =
Headers.FromIterable(Iterable(value))

}

object HeaderType {
type Typed[HV] = HeaderType { type HeaderValue = HV }
}

// @deprecated("Use Schema based header codecs instead", "3.1.0")
final case class Custom(customName: CharSequence, value: CharSequence) extends Header {
override type Self = Custom
override def self: Self = this
Expand Down
Loading

0 comments on commit be67f1f

Please sign in to comment.