diff --git a/.github/workflows/site.yml b/.github/workflows/site.yml index 65b15c12c..027da47ab 100644 --- a/.github/workflows/site.yml +++ b/.github/workflows/site.yml @@ -57,40 +57,64 @@ jobs: run: sbt docs/publishToNpm env: NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} - generate-readme: - name: Generate README + update-readme: + name: Update README runs-on: ubuntu-latest - if: ${{ (github.event_name == 'push') || ((github.event_name == 'release') && (github.event.action == 'published')) }} + continue-on-error: false + if: ${{ github.event_name == 'push' }} steps: - name: Git Checkout - uses: actions/checkout@v3.3.0 + uses: actions/checkout@v4 with: - ref: ${{ github.head_ref }} fetch-depth: '0' + - name: Install libuv + run: sudo apt-get update && sudo apt-get install -y libuv1-dev - name: Setup Scala - uses: actions/setup-java@v3.9.0 + uses: actions/setup-java@v4 with: - distribution: temurin - java-version: 17 + distribution: corretto + java-version: '17' check-latest: true + - name: Cache Dependencies + uses: coursier/cache-action@v6 - name: Generate Readme - run: sbt docs/generateReadme + run: sbt docs/generateReadme - name: Commit Changes run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" + git config --local user.email "zio-assistant[bot]@users.noreply.github.com" + git config --local user.name "ZIO Assistant" git add README.md git commit -m "Update README.md" || echo "No changes to commit" + - name: Generate Token + id: generate-token + uses: zio/generate-github-app-token@v1.0.0 + with: + app_id: ${{ secrets.APP_ID }} + app_private_key: ${{ secrets.APP_PRIVATE_KEY }} - name: Create Pull Request - uses: peter-evans/create-pull-request@v4.2.3 + id: cpr + uses: peter-evans/create-pull-request@v6 with: body: |- Autogenerated changes after running the `sbt docs/generateReadme` command of the [zio-sbt-website](https://zio.dev/zio-sbt) plugin. - I will automatically update the README.md file whenever there is new change for README.md, e.g. + I will automatically update the README.md file whenever there is a new change for README.md, e.g. - After each release, I will update the version in the installation section. - After any changes to the "docs/index.md" file, I will update the README.md file accordingly. branch: zio-sbt-website/update-readme commit-message: Update README.md + token: ${{ steps.generate-token.outputs.token }} delete-branch: true title: Update README.md + - name: Approve PR + if: ${{ steps.cpr.outputs.pull-request-number }} + run: gh pr review "$PR_URL" --approve + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_URL: ${{ steps.cpr.outputs.pull-request-url }} + - name: Enable Auto-Merge + if: ${{ steps.cpr.outputs.pull-request-number }} + run: gh pr merge --auto --squash "$PR_URL" || gh pr merge --squash "$PR_URL" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_URL: ${{ steps.cpr.outputs.pull-request-url }} diff --git a/README.md b/README.md index 9d5fc1649..27e69d518 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Some of the key features of ZIO HTTP are: Setup via `build.sbt`: ```scala -libraryDependencies += "dev.zio" %% "zio-http" % "3.0.0-RC7" +libraryDependencies += "dev.zio" %% "zio-http" % "3.0.0-RC9" ``` **NOTES ON VERSIONING:** diff --git a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala index 15d94f264..3be96b515 100644 --- a/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala +++ b/zio-http-cli/src/main/scala/zio/http/endpoint/cli/HttpOptions.scala @@ -312,7 +312,7 @@ private[cli] object HttpOptions { private[cli] def optionsFromSegment(segment: SegmentCodec[_]): Options[String] = { def fromSegment[A](segment: SegmentCodec[A]): Options[String] = segment match { - case SegmentCodec.UUID(name) => + case SegmentCodec.UUID(name) => Options .text(name) .mapOrFail(str => @@ -324,13 +324,14 @@ private[cli] object HttpOptions { }, ) .map(_.toString) - case SegmentCodec.Text(name) => Options.text(name) - case SegmentCodec.IntSeg(name) => Options.integer(name).map(_.toInt).map(_.toString) - case SegmentCodec.LongSeg(name) => Options.integer(name).map(_.toInt).map(_.toString) - case SegmentCodec.BoolSeg(name) => Options.boolean(name).map(_.toString) - case SegmentCodec.Literal(value) => Options.Empty.map(_ => value) - case SegmentCodec.Trailing => Options.none.map(_.toString) - case SegmentCodec.Empty => Options.none.map(_.toString) + case SegmentCodec.Text(name) => Options.text(name) + case SegmentCodec.IntSeg(name) => Options.integer(name).map(_.toInt).map(_.toString) + case SegmentCodec.LongSeg(name) => Options.integer(name).map(_.toInt).map(_.toString) + case SegmentCodec.BoolSeg(name) => Options.boolean(name).map(_.toString) + case SegmentCodec.Literal(value) => Options.Empty.map(_ => value) + case SegmentCodec.Trailing => Options.none.map(_.toString) + case SegmentCodec.Empty => Options.none.map(_.toString) + case SegmentCodec.Combined(_, _, _) => throw new IllegalArgumentException("Combined segment not supported") } fromSegment(segment) diff --git a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala index 395fab438..ac69bffc7 100644 --- a/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala +++ b/zio-http-cli/src/test/scala/zio/http/endpoint/cli/CommandGen.scala @@ -20,14 +20,16 @@ object CommandGen { def getSegment(segment: SegmentCodec[_]): (String, String) = { def fromSegment[A](segment: SegmentCodec[A]): (String, String) = segment match { - case SegmentCodec.UUID(name) => (name, "text") - case SegmentCodec.Text(name) => (name, "text") - case SegmentCodec.IntSeg(name) => (name, "integer") - case SegmentCodec.LongSeg(name) => (name, "integer") - case SegmentCodec.BoolSeg(name) => (name, "boolean") - case SegmentCodec.Literal(_) => ("", "") - case SegmentCodec.Trailing => ("", "") - case SegmentCodec.Empty => ("", "") + case SegmentCodec.UUID(name) => (name, "text") + case SegmentCodec.Text(name) => (name, "text") + case SegmentCodec.IntSeg(name) => (name, "integer") + case SegmentCodec.LongSeg(name) => (name, "integer") + case SegmentCodec.BoolSeg(name) => (name, "boolean") + case SegmentCodec.Literal(_) => ("", "") + case SegmentCodec.Trailing => ("", "") + case SegmentCodec.Empty => ("", "") + case SegmentCodec.Combined(left, right, combiner) => + ??? } fromSegment(segment) diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index f7575e841..be1ce8407 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -182,7 +182,6 @@ object CodeGenSpec extends ZIOSpecDefault { val code = EndpointGen.fromOpenAPI(openAPI) val tempDir = Files.createTempDirectory("codegen") - println(tempDir) CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) fileShouldBe( @@ -240,7 +239,6 @@ object CodeGenSpec extends ZIOSpecDefault { val code = EndpointGen.fromOpenAPI(openAPI) val tempDir = Files.createTempDirectory("codegen") - println(tempDir) CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath)) fileShouldBe( diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala index 04455169a..9e68e07a1 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala @@ -55,6 +55,7 @@ object NettyResponse { onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status))) Response(status, headers, Body.empty) } else { + val contentType = headers.get(Header.ContentType) val responseHandler = new ClientResponseStreamHandler(onComplete, keepAlive, status) ctx .pipeline() @@ -64,7 +65,11 @@ object NettyResponse { responseHandler, ): Unit - val data = NettyBody.fromAsync(callback => responseHandler.connect(callback), knownContentLength) + val data = NettyBody.fromAsync( + callback => responseHandler.connect(callback), + knownContentLength, + contentType.map(_.renderedValue), + ) Response(status, headers, data) } } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyRequestEncoder.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyRequestEncoder.scala index 2bccb58f3..f20a8d90f 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyRequestEncoder.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyRequestEncoder.scala @@ -16,14 +16,13 @@ package zio.http.netty.client +import zio.Unsafe import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{Task, Trace, Unsafe, ZIO} -import zio.http.netty._ import zio.http.netty.model.Conversions import zio.http.{Body, Request} -import io.netty.buffer.{ByteBuf, EmptyByteBuf, Unpooled} +import io.netty.buffer.Unpooled import io.netty.handler.codec.http.{DefaultFullHttpRequest, DefaultHttpRequest, HttpHeaderNames, HttpRequest} private[zio] object NettyRequestEncoder { @@ -34,12 +33,7 @@ private[zio] object NettyRequestEncoder { def encode(req: Request): HttpRequest = { val method = Conversions.methodToNetty(req.method) val jVersion = Conversions.versionToNetty(req.version) - - def replaceEmptyPathWithSlash(url: zio.http.URL) = if (url.path.isEmpty) url.addLeadingSlash else url - - // As per the spec, the path should contain only the relative path. - // Host and port information should be in the headers. - val path = replaceEmptyPathWithSlash(req.url).relative.addLeadingSlash.encode + val path = Conversions.urlToNetty(req.url) val headers = Conversions.headersToNetty(req.allHeaders) @@ -69,4 +63,5 @@ private[zio] object NettyRequestEncoder { new DefaultHttpRequest(jVersion, method, path, headers) } } + } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala index 9658e437e..76ef429a4 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/model/Conversions.scala @@ -64,6 +64,13 @@ private[netty] object Conversions { case Headers.Empty => new DefaultHttpHeaders() } + def urlToNetty(url: URL): String = { + // As per the spec, the path should contain only the relative path. + // Host and port information should be in the headers. + val url0 = if (url.path.isEmpty) url.addLeadingSlash else url + url0.relative.addLeadingSlash.encode + } + private def nettyHeadersIterator(headers: HttpHeaders): Iterator[Header] = new AbstractIterator[Header] { private val nettyIterator = headers.iteratorCharSequence() diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 590180e37..4df974e94 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -20,7 +20,6 @@ import java.io.IOException import java.net.InetSocketAddress import java.util.concurrent.atomic.LongAdder -import scala.annotation.tailrec import scala.util.control.NonFatal import zio._ @@ -29,7 +28,6 @@ import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.http.Body.WebsocketBody import zio.http._ import zio.http.netty._ -import zio.http.netty.client.NettyRequestEncoder import zio.http.netty.model.Conversions import zio.http.netty.socket.NettySocketProtocol @@ -287,7 +285,12 @@ private[zio] final case class ServerInboundHandler( ) .addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, None)) - val jReq = NettyRequestEncoder.encode(request) + val jReq = new DefaultFullHttpRequest( + Conversions.versionToNetty(request.version), + Conversions.methodToNetty(request.method), + Conversions.urlToNetty(request.url), + ) + jReq.headers().setAll(Conversions.headersToNetty(request.allHeaders)) ctx.channel().eventLoop().submit { () => ctx.fireChannelRead(jReq) }: Unit } } diff --git a/zio-http/jvm/src/test/scala/zio/http/ClientStreamingSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ClientStreamingSpec.scala index f18338092..e354b2961 100644 --- a/zio-http/jvm/src/test/scala/zio/http/ClientStreamingSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/ClientStreamingSpec.scala @@ -208,7 +208,7 @@ object ClientStreamingSpec extends HttpRunnableSpec { port <- server(streamingServer) client <- ZIO.service[Client] result <- check(Gen.int(1, N)) { chunkSize => - (for { + for { bytes <- Random.nextBytes(N) form = Form( Chunk( @@ -233,7 +233,7 @@ object ClientStreamingSpec extends HttpRunnableSpec { collected.map.contains("file"), collected.map.contains("foo"), collected.get("file").get.asInstanceOf[FormField.Binary].data == bytes, - )).tapErrorCause(cause => ZIO.debug(cause.prettyPrint)) + ) } } yield result } @@ samples(20) @@ TestAspect.ifEnvNotSet("CI"), diff --git a/zio-http/jvm/src/test/scala/zio/http/DualSSLSpec.scala b/zio-http/jvm/src/test/scala/zio/http/DualSSLSpec.scala index 18cf60957..8087f86c6 100644 --- a/zio-http/jvm/src/test/scala/zio/http/DualSSLSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/DualSSLSpec.scala @@ -50,7 +50,7 @@ object DualSSLSpec extends ZIOHttpSpec { includeClientCert = true, ) - val config = Server.Config.default.port(8073).ssl(sslConfigWithTrustedClient) + val config = Server.Config.default.port(8073).ssl(sslConfigWithTrustedClient).logWarningOnFatalError(false) val payload = Gen.alphaNumericStringBounded(10000, 20000) diff --git a/zio-http/jvm/src/test/scala/zio/http/MultipartMixedSpec.scala b/zio-http/jvm/src/test/scala/zio/http/MultipartMixedSpec.scala index 6192a7c02..8a6a4a97c 100644 --- a/zio-http/jvm/src/test/scala/zio/http/MultipartMixedSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/MultipartMixedSpec.scala @@ -210,7 +210,7 @@ object MultipartMixedSpec extends ZIOHttpSpec { test("property") { check(gens.genTestCase) { testCase => - zio.Console.printLine(testCase) *> testCase.runTests + testCase.runTests } } @@ TestAspect.shrinks(0) @@ -248,8 +248,8 @@ object MultipartMixedSpec extends ZIOHttpSpec { gens.breaker.fixed(512), ) - val innerTests = inner.runTests.map(_.label("inner")).debug("inner") - val outerTests = outer.runTests.map(_.label("outer")).debug("outer") + val innerTests = inner.runTests.map(_.label("inner")) + val outerTests = outer.runTests.map(_.label("outer")) val nestedTests = { val expectedNested = Nested.Multi( @@ -262,7 +262,6 @@ object MultipartMixedSpec extends ZIOHttpSpec { outer.partsToNested.map { collected => zio.test.assert(collected)(Assertion.equalTo(expectedNested)).label("nestedTests") } - .debug("nestedTests") } (innerTests <*> outerTests <*> nestedTests).map { case (i, o, n) => @@ -308,8 +307,8 @@ object MultipartMixedSpec extends ZIOHttpSpec { gens.breaker.fixed(512), ) - val innerTests = inner.runTests.map(_.label("inner")).debug("inner") - val outerTests = outer.runTests.map(_.label("outer")).debug("outer") + val innerTests = inner.runTests.map(_.label("inner")) + val outerTests = outer.runTests.map(_.label("outer")) val nestedTests = { val expectedNested = Nested.Multi( @@ -322,7 +321,6 @@ object MultipartMixedSpec extends ZIOHttpSpec { outer.partsToNested.map { collected => zio.test.assert(collected)(Assertion.equalTo(expectedNested)).label("nestedTests") } - .debug("nestedTests") } (innerTests <*> outerTests <*> nestedTests).map { case (i, o, n) => @@ -364,8 +362,8 @@ object MultipartMixedSpec extends ZIOHttpSpec { gens.breaker.fixed(512), ) - val innerTests = inner.runTests.map(_.label("inner")).debug("inner") - val outerTests = outer.runTests.map(_.label("outer")).debug("outer") + val innerTests = inner.runTests.map(_.label("inner")) + val outerTests = outer.runTests.map(_.label("outer")) val nestedTests = { val expectedNested = Nested.Multi( @@ -379,7 +377,6 @@ object MultipartMixedSpec extends ZIOHttpSpec { outer.partsToNested.map { collected => zio.test.assert(collected)(Assertion.equalTo(expectedNested)).label("nestedTests") } - .debug("nestedTests") } (innerTests <*> outerTests <*> nestedTests).map { case (i, o, n) => diff --git a/zio-http/jvm/src/test/scala/zio/http/RequestStreamingServerSpec.scala b/zio-http/jvm/src/test/scala/zio/http/RequestStreamingServerSpec.scala index 69647acb2..0a97d09f5 100644 --- a/zio-http/jvm/src/test/scala/zio/http/RequestStreamingServerSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/RequestStreamingServerSpec.scala @@ -78,16 +78,14 @@ object RequestStreamingServerSpec extends HttpRunnableSpec { val host = req.headers.get(Header.Host).get val newRequest = req.copy(url = req.url.path("/2").host(host.hostAddress).port(host.port.getOrElse(80))) - ZIO.debug(s"#1: got response, forwarding") *> - ZIO.serviceWithZIO[Client] { client => - client.request(newRequest) - } + ZIO.serviceWithZIO[Client] { client => + client.request(newRequest) + } }, Method.POST / "2" -> handler { (req: Request) => - ZIO.debug("#2: got response, collecting") *> - req.body.asChunk.map { body => - Response.text(body.length.toString) - } + req.body.asChunk.map { body => + Response.text(body.length.toString) + } }, ).sandbox val sizes = Chunk(0, 8192, 1024 * 1024) diff --git a/zio-http/jvm/src/test/scala/zio/http/SSLSpec.scala b/zio-http/jvm/src/test/scala/zio/http/SSLSpec.scala index 16cb728ef..80bc8bbd5 100644 --- a/zio-http/jvm/src/test/scala/zio/http/SSLSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/SSLSpec.scala @@ -27,7 +27,7 @@ import zio.http.netty.client.NettyClientDriver object SSLSpec extends ZIOHttpSpec { val sslConfig = SSLConfig.fromResource("server.crt", "server.key") - val config = Server.Config.default.port(8073).ssl(sslConfig) + val config = Server.Config.default.port(8073).ssl(sslConfig).logWarningOnFatalError(false) val clientSSL1 = ClientSSLConfig.FromCertResource("server.crt") val clientSSL2 = ClientSSLConfig.FromCertResource("ss2.crt.pem") diff --git a/zio-http/jvm/src/test/scala/zio/http/ServerStartSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ServerStartSpec.scala index d2806b0d1..000266f5a 100644 --- a/zio-http/jvm/src/test/scala/zio/http/ServerStartSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/ServerStartSpec.scala @@ -51,6 +51,15 @@ object ServerStartSpec extends HttpRunnableSpec { ZLayer.succeed(NettyConfig.defaultWithFastShutdown), ) }, + test("application can shutdown if server is not started") { + ZIO + .succeed(assertCompletes) + .provide( + Server.customized.unit, + ZLayer.succeed(Server.Config.default.port(8089)), + ZLayer.succeed(NettyConfig.defaultWithFastShutdown), + ) + }, ) override def spec: Spec[TestEnvironment with Scope, Any] = serverStartSpec @@ withLiveClock diff --git a/zio-http/jvm/src/test/scala/zio/http/StaticFileServerSpec.scala b/zio-http/jvm/src/test/scala/zio/http/StaticFileServerSpec.scala index 921e75b71..cddbd2525 100644 --- a/zio-http/jvm/src/test/scala/zio/http/StaticFileServerSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/StaticFileServerSpec.scala @@ -66,7 +66,7 @@ object StaticFileServerSpec extends HttpRunnableSpec { assertZIO(res)(equalTo("foo\nbar")) }, test("should have content-type") { - val res = fileOk.run().debug("fileOk").map(_.header(Header.ContentType)) + val res = fileOk.run().map(_.header(Header.ContentType)) assertZIO(res)(isSome(equalTo(Header.ContentType(MediaType.text.plain, charset = Some(Charsets.Utf8))))) }, test("should respond with empty if file not found") { @@ -121,7 +121,7 @@ object StaticFileServerSpec extends HttpRunnableSpec { assertZIO(res)(isSome(equalTo(Header.ContentType(MediaType.text.plain, charset = Some(Charsets.Utf8))))) }, test("should respond with empty if not found") { - val res = resourceNotFound.run().debug("not found").map(_.status) + val res = resourceNotFound.run().map(_.status) assertZIO(res)(equalTo(Status.NotFound)) }, ), diff --git a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala index 3f87b59f7..3ea000d99 100644 --- a/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/WebSocketSpec.scala @@ -23,197 +23,205 @@ import zio.test.{TestClock, assertCompletes, assertTrue, assertZIO, testClock} import zio.http.ChannelEvent.UserEvent.HandshakeComplete import zio.http.ChannelEvent.{Read, Unregistered, UserEvent, UserEventTriggered} -import zio.http.internal.{DynamicServer, HttpRunnableSpec, serverTestLayer} +import zio.http.internal.{DynamicServer, HttpRunnableSpec, serverTestLayer, testNettyServerConfig, testServerConfig} object WebSocketSpec extends HttpRunnableSpec { - private val websocketSpec = suite("WebsocketSpec")( - test("channel events between client and server") { - for { - msg <- MessageCollector.make[WebSocketChannelEvent] - url <- DynamicServer.wsURL - id <- DynamicServer.deploy { - Handler.webSocket { channel => + private val websocketSpec = + List( + test("channel events between client and server") { + for { + msg <- MessageCollector.make[WebSocketChannelEvent] + url <- DynamicServer.wsURL + id <- DynamicServer.deploy { + Handler.webSocket { channel => + channel.receiveAll { + case event @ Read(frame) => channel.send(Read(frame)) *> msg.add(event) + case event: Unregistered.type => msg.add(event, isDone = true) + case event => msg.add(event) + } + }.toRoutes + } + + res <- ZIO.scoped { + Handler.webSocket { channel => + channel.receiveAll { + case UserEventTriggered(HandshakeComplete) => + channel.send(Read(WebSocketFrame.text("FOO"))) + case Read(WebSocketFrame.Text("FOO")) => + channel.send(Read(WebSocketFrame.text("BAR"))) + case Read(WebSocketFrame.Text("BAR")) => + channel.shutdown + case _ => + ZIO.unit + } + }.connect(url, Headers(DynamicServer.APP_ID, id)) *> { + for { + events <- msg.await + expected = List( + UserEventTriggered(HandshakeComplete), + Read(WebSocketFrame.text("FOO")), + Read(WebSocketFrame.text("BAR")), + Unregistered, + ) + } yield assertTrue(events == expected) + } + } + } yield res + }, + test("on close interruptibility") { + for { + + // Maintain a flag to check if the close handler was completed + isSet <- Promise.make[Nothing, Unit] + isStarted <- Promise.make[Nothing, Unit] + clock <- testClock + + // Setup websocket server + + serverHttp = Handler.webSocket { channel => channel.receiveAll { - case event @ Read(frame) => channel.send(Read(frame)) *> msg.add(event) - case event: Unregistered.type => msg.add(event, isDone = true) - case event => msg.add(event) + case Unregistered => + isStarted.succeed(()) <&> isSet.succeed(()).delay(5 seconds).withClock(clock) + case _ => + ZIO.unit } - }.toRoutes - } + }.toRoutes.deployWS - res <- ZIO.scoped { - Handler.webSocket { channel => + // Setup Client + // Client closes the connection after 1 second + clientSocket = Handler.webSocket { channel => channel.receiveAll { case UserEventTriggered(HandshakeComplete) => - channel.send(Read(WebSocketFrame.text("FOO"))) - case Read(WebSocketFrame.Text("FOO")) => - channel.send(Read(WebSocketFrame.text("BAR"))) - case Read(WebSocketFrame.Text("BAR")) => - channel.shutdown + channel.send(Read(WebSocketFrame.close(1000))).delay(1 second).withClock(clock) case _ => ZIO.unit } - }.connect(url, Headers(DynamicServer.APP_ID, id)) *> { - for { - events <- msg.await - expected = List( - UserEventTriggered(HandshakeComplete), - Read(WebSocketFrame.text("FOO")), - Read(WebSocketFrame.text("BAR")), - Unregistered, - ) - } yield assertTrue(events == expected) } - } - } yield res - }, - test("on close interruptibility") { - for { - - // Maintain a flag to check if the close handler was completed - isSet <- Promise.make[Nothing, Unit] - isStarted <- Promise.make[Nothing, Unit] - clock <- testClock - - // Setup websocket server - - serverHttp = Handler.webSocket { channel => - channel.receiveAll { - case Unregistered => - isStarted.succeed(()) <&> isSet.succeed(()).delay(5 seconds).withClock(clock) - case _ => - ZIO.unit - } - }.toRoutes.deployWS - - // Setup Client - // Client closes the connection after 1 second - clientSocket = Handler.webSocket { channel => - channel.receiveAll { - case UserEventTriggered(HandshakeComplete) => - channel.send(Read(WebSocketFrame.close(1000))).delay(1 second).withClock(clock) - case _ => - ZIO.unit + + // Deploy the server and send it a socket request + _ <- serverHttp.runZIO(clientSocket) + + // Wait for the close handler to complete + _ <- TestClock.adjust(2 seconds) + _ <- isStarted.await + _ <- TestClock.adjust(5 seconds) + _ <- isSet.await + + // Check if the close handler was completed + } yield assertCompletes + } @@ nonFlaky(25), + test("Multiple websocket upgrades") { + val app = + Handler.webSocket(channel => channel.send(ChannelEvent.Read(WebSocketFrame.text("BAR")))).toRoutes.deployWS + val codes = ZIO + .foreach(1 to 1024)(_ => app.runZIO(WebSocketApp.unit).map(_.status)) + .map(_.count(_ == Status.SwitchingProtocols)) + + assertZIO(codes)(equalTo(1024)) + } @@ ignore, + test("channel events between client and server when the provided URL is HTTP") { + for { + msg <- MessageCollector.make[WebSocketChannelEvent] + url <- DynamicServer.httpURL + id <- DynamicServer.deploy { + Handler.webSocket { channel => + channel.receiveAll { + case event @ Read(frame) => channel.send(Read(frame)) *> msg.add(event) + case event: Unregistered.type => msg.add(event, isDone = true) + case event => msg.add(event) + } + }.toRoutes } - } - - // Deploy the server and send it a socket request - _ <- serverHttp.runZIO(clientSocket) - - // Wait for the close handler to complete - _ <- TestClock.adjust(2 seconds) - _ <- isStarted.await - _ <- TestClock.adjust(5 seconds) - _ <- isSet.await - - // Check if the close handler was completed - } yield assertCompletes - } @@ nonFlaky, - test("Multiple websocket upgrades") { - val app = - Handler.webSocket(channel => channel.send(ChannelEvent.Read(WebSocketFrame.text("BAR")))).toRoutes.deployWS - val codes = ZIO - .foreach(1 to 1024)(_ => app.runZIO(WebSocketApp.unit).map(_.status)) - .map(_.count(_ == Status.SwitchingProtocols)) - - assertZIO(codes)(equalTo(1024)) - } @@ ignore, - test("channel events between client and server when the provided URL is HTTP") { - for { - msg <- MessageCollector.make[WebSocketChannelEvent] - url <- DynamicServer.httpURL - id <- DynamicServer.deploy { - Handler.webSocket { channel => - channel.receiveAll { - case event @ Read(frame) => channel.send(Read(frame)) *> msg.add(event) - case event: Unregistered.type => msg.add(event, isDone = true) - case event => msg.add(event) - } - }.toRoutes - } - res <- ZIO.scoped { - Handler.webSocket { channel => - channel.receiveAll { - case UserEventTriggered(HandshakeComplete) => - channel.send(Read(WebSocketFrame.text("FOO"))) - case Read(WebSocketFrame.Text("FOO")) => - channel.send(Read(WebSocketFrame.text("BAR"))) - case Read(WebSocketFrame.Text("BAR")) => - channel.shutdown - case _ => - ZIO.unit + res <- ZIO.scoped { + Handler.webSocket { channel => + channel.receiveAll { + case UserEventTriggered(HandshakeComplete) => + channel.send(Read(WebSocketFrame.text("FOO"))) + case Read(WebSocketFrame.Text("FOO")) => + channel.send(Read(WebSocketFrame.text("BAR"))) + case Read(WebSocketFrame.Text("BAR")) => + channel.shutdown + case _ => + ZIO.unit + } + }.connect(url, Headers(DynamicServer.APP_ID, id)) *> { + for { + events <- msg.await + expected = List( + UserEventTriggered(HandshakeComplete), + Read(WebSocketFrame.text("FOO")), + Read(WebSocketFrame.text("BAR")), + Unregistered, + ) + } yield assertTrue(events == expected) } - }.connect(url, Headers(DynamicServer.APP_ID, id)) *> { - for { - events <- msg.await - expected = List( - UserEventTriggered(HandshakeComplete), - Read(WebSocketFrame.text("FOO")), - Read(WebSocketFrame.text("BAR")), - Unregistered, - ) - } yield assertTrue(events == expected) } - } - } yield res - }, - test("Client connection is interruptible") { - for { - url <- DynamicServer.httpURL - id <- DynamicServer.deploy { - Handler.webSocket { channel => - ZIO.debug("receiveAll") *> - channel.receiveAll { evt => - println(evt) - evt match { - case ChannelEvent.UserEventTriggered(UserEvent.HandshakeComplete) => - ZIO.debug("registered") *> - ZIO - .foreachDiscard(1 to 100) { idx => - ZIO.debug(s"sending $idx") *> - channel.send(Read(WebSocketFrame.text(idx.toString))) *> ZIO.sleep(100.millis) - } - .forkScoped - case _ => ZIO.unit - } + } yield res + }, + test("Client connection is interruptible") { + for { + url <- DynamicServer.httpURL + id <- DynamicServer.deploy { + Handler.webSocket { channel => + channel.receiveAll { + case ChannelEvent.UserEventTriggered(UserEvent.HandshakeComplete) => + ZIO + .foreachDiscard(1 to 100) { idx => + channel.send(Read(WebSocketFrame.text(idx.toString))) *> ZIO.sleep(100.millis) + } + .forkScoped + case _ => ZIO.unit } - }.toRoutes - } + }.toRoutes + } - queue1 <- Queue.unbounded[String] - queue2 <- Queue.unbounded[String] - _ <- ZIO.scoped { - Handler.webSocket { channel => - channel.receiveAll { - case Read(WebSocketFrame.Text(s)) => - println(s"read $s") - queue1.offer(s) - case _ => - ZIO.unit - }.onInterrupt(ZIO.debug("ws interrupted")) - }.connect(url, Headers(DynamicServer.APP_ID, id)) *> - queue1.take - .tap(s => - ZIO.debug(s"got $s") *> - queue2.offer(s), - ) - .repeatUntil(_ == "5") - .timeout(1.second) - .debug - } - result <- queue2.takeAll - } yield assertTrue(result == Chunk("1", "2", "3", "4", "5")) - }, - ) - - override def spec = suite("Server") { - serve.as(List(websocketSpec)) - } - .provideSome[DynamicServer & Server & Client](Scope.default) - .provideShared(DynamicServer.live, serverTestLayer, Client.default) @@ - diagnose(30.seconds) @@ withLiveClock @@ sequential + queue1 <- Queue.unbounded[String] + queue2 <- Queue.unbounded[String] + _ <- ZIO.scoped { + Handler.webSocket { channel => + channel.receiveAll { + case Read(WebSocketFrame.Text(s)) => + queue1.offer(s) + case _ => + ZIO.unit + } + }.connect(url, Headers(DynamicServer.APP_ID, id)) *> + queue1.take + .tap(s => queue2.offer(s)) + .repeatUntil(_ == "5") + .timeout(1.second) + } + result <- queue2.takeAll + } yield assertTrue(result == Chunk("1", "2", "3", "4", "5")) + }, + ) + + private val withStreamingEnabled = + suite("request streaming enabled")( + serve.as(websocketSpec), + ) + .provideSome[DynamicServer & Server & Client](Scope.default) + .provideShared( + DynamicServer.live, + ZLayer.succeed(Server.Config.default.onAnyOpenPort.enableRequestStreaming), + testNettyServerConfig, + Server.customized, + Client.default, + ) @@ sequential + + private val withStreamingDisabled = + suite("request streaming disabled")( + serve.as(websocketSpec), + ) + .provideSome[DynamicServer & Server & Client](Scope.default) + .provideShared(DynamicServer.live, serverTestLayer, Client.default) @@ sequential + + override def spec = suite("WebSocketSpec")( + withStreamingDisabled, + withStreamingEnabled, + ) @@ diagnose(30.seconds) @@ withLiveClock final class MessageCollector[A](ref: Ref[List[A]], promise: Promise[Nothing, Unit]) { def add(a: A, isDone: Boolean = false): UIO[Unit] = ref.update(_ :+ a) <* promise.succeed(()).when(isDone) diff --git a/zio-http/jvm/src/test/scala/zio/http/ZIOHttpSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ZIOHttpSpec.scala index 2a70d0bd1..29cace490 100644 --- a/zio-http/jvm/src/test/scala/zio/http/ZIOHttpSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/ZIOHttpSpec.scala @@ -5,5 +5,5 @@ import zio.test._ trait ZIOHttpSpec extends ZIOSpecDefault { override def aspects: Chunk[TestAspectPoly] = - Chunk(TestAspect.timeout(60.seconds), TestAspect.timed) + Chunk(TestAspect.timeout(60.seconds), TestAspect.timed, TestAspect.silentLogging, TestAspect.silent) } diff --git a/zio-http/jvm/src/test/scala/zio/http/codec/PathCodecSpec.scala b/zio-http/jvm/src/test/scala/zio/http/codec/PathCodecSpec.scala index 68216f7d8..6dfdba639 100644 --- a/zio-http/jvm/src/test/scala/zio/http/codec/PathCodecSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/codec/PathCodecSpec.scala @@ -24,6 +24,7 @@ import zio._ import zio.test._ import zio.http._ +import zio.http.codec.PathCodec.Segment import zio.http.codec._ object PathCodecSpec extends ZIOHttpSpec { @@ -115,6 +116,67 @@ object PathCodecSpec extends ZIOHttpSpec { ) }, ), + suite("decoding with sub-segment codecs")( + test("int") { + val codec = PathCodec.empty / + string("foo") / + "instances" / + int("a") ~ "_" ~ int("b") / + "bar" / + int("baz") + + assertTrue(codec.decode(Path("/abc/instances/123_13/bar/42")) == Right(("abc", 123, 13, 42))) + }, + test("uuid") { + val codec = PathCodec.empty / + string("foo") / + "foo" / + uuid("a") ~ "__" ~ int("b") / + "bar" / + int("baz") + + val id = UUID.randomUUID() + val p = s"/abc/foo/${id}__13/bar/42" + assertTrue(codec.decode(Path(p)) == Right(("abc", id, 13, 42))) + }, + test("string before literal") { + val codec = PathCodec.empty / + string("foo") / + "foo" / + string("a") ~ "__" ~ int("b") / + "bar" / + int("baz") + assertTrue(codec.decode(Path("/abc/foo/cba__13/bar/42")) == Right(("abc", "cba", 13, 42))) + }, + test("string before int") { + val codec = PathCodec.empty / + string("foo") / + "foo" / + string("a") ~ int("b") / + "bar" / + int("baz") + assertTrue(codec.decode(Path("/abc/foo/cba13/bar/42")) == Right(("abc", "cba", 13, 42))) + }, + test("string before long") { + val codec = PathCodec.empty / + string("foo") / + "foo" / + string("a") ~ long("b") / + "bar" / + int("baz") + assertTrue(codec.decode(Path("/abc/foo/cba133333333333/bar/42")) == Right(("abc", "cba", 133333333333L, 42))) + }, + test("trailing literal") { + val codec = PathCodec.empty / + string("foo") / + "instances" / + int("a") ~ "what" / + "bar" / + int("baz") + + assertTrue(codec.decode(Path("/abc/instances/123what/bar/42")) == Right(("abc", 123, 42))) + }, + ), suite("representation")( test("empty") { val codec = PathCodec.empty @@ -149,6 +211,13 @@ object PathCodecSpec extends ZIOHttpSpec { assertTrue(codec.render == "/users/{user-id}/posts/{post-id}") }, + test("/users/{first-name}_{last-name}") { + val codec = + PathCodec.empty / PathCodec.literal("users") / + string("first-name") ~ "_" ~ string("last-name") + + assertTrue(codec.render == "/users/{first-name}_{last-name}") + }, test("transformed") { val codec = PathCodec.path("/users") / diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala index 65ba04bce..aef2838f4 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/RoundtripSpec.scala @@ -42,7 +42,7 @@ object RoundtripSpec extends ZIOHttpSpec { ZLayer.make[Server & Client & Scope]( Server.customized, ZLayer.succeed(Server.Config.default.onAnyOpenPort.enableRequestStreaming), - Client.customized.map(env => ZEnvironment(env.get @@ ZClientAspect.debug)), + Client.customized.map(env => ZEnvironment(env.get)), ClientDriver.shared, // NettyDriver.customized, ZLayer.succeed(NettyConfig.defaultWithFastShutdown), diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala index ef5e7c550..26b26f642 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -113,6 +113,24 @@ object OpenAPIGenSpec extends ZIOSpecDefault { case class NestedThree(name: String) extends SimpleNestedSealedTrait } + @description("A recursive structure") + case class Recursive( + nestedOption: Option[Recursive], + nestedList: List[Recursive], + nestedMap: Map[String, Recursive], + nestedSet: Set[Recursive], + nestedEither: Either[Recursive, Recursive], + nestedTuple: (Recursive, Recursive), + nestedOverAnother: NestedRecursive, + ) + object Recursive { + implicit val schema: Schema[Recursive] = DeriveSchema.gen[Recursive] + } + case class NestedRecursive(next: Recursive) + object NestedRecursive { + implicit val schema: Schema[NestedRecursive] = DeriveSchema.gen[NestedRecursive] + } + @description("A simple payload") case class Payload(content: String) @@ -2440,6 +2458,120 @@ object OpenAPIGenSpec extends ZIOSpecDefault { SwaggerUI.routes("docs/openapi", OpenAPIGen.fromEndpoints(endpoint)) assertCompletes }, + test("Recursive schema") { + val endpoint = Endpoint(RoutePattern.POST / "folder") + .out[Recursive] + val openApi = OpenAPIGen.fromEndpoints(endpoint) + val json = toJsonAst(openApi) + val expectedJson = + """ + |{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "", + | "version" : "" + | }, + | "paths" : { + | "/folder" : { + | "post" : { + | "responses" : { + | "200" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/Recursive" + | } + | } + | } + | } + | } + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NestedRecursive" : + | { + | "type" : + | "object", + | "properties" : { + | "next" : { + | "$ref" : "#/components/schemas/Recursive" + | } + | }, + | "required" : [ + | "next" + | ] + | }, + | "Recursive" : + | { + | "type" : + | "object", + | "properties" : { + | "nestedSet" : { + | "type" : + | "array", + | "items" : { + | "$ref" : "#/components/schemas/Recursive" + | } + | }, + | "nestedEither" : { + | "oneOf" : [ + | { + | "$ref" : "#/components/schemas/Recursive" + | }, + | { + | "$ref" : "#/components/schemas/Recursive" + | } + | ] + | }, + | "nestedTuple" : { + | "allOf" : [ + | { + | "$ref" : "#/components/schemas/Recursive" + | }, + | { + | "$ref" : "#/components/schemas/Recursive" + | } + | ] + | }, + | "nestedOption" : { + | "$ref" : "#/components/schemas/Recursive" + | }, + | "nestedList" : { + | "type" : + | "array", + | "items" : { + | "$ref" : "#/components/schemas/Recursive" + | } + | }, + | "nestedOverAnother" : { + | "$ref" : "#/components/schemas/NestedRecursive" + | } + | }, + | "additionalProperties" : + | { + | "$ref" : "#/components/schemas/Recursive" + | }, + | "required" : [ + | "nestedOption", + | "nestedList", + | "nestedMap", + | "nestedSet", + | "nestedEither", + | "nestedTuple", + | "nestedOverAnother" + | ], + | "description" : "A recursive structure" + | } + | } + | } + |} + |""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, ) } diff --git a/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/AuthSpec.scala b/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/AuthSpec.scala index 4b8179bf6..3242734c3 100644 --- a/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/AuthSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/AuthSpec.scala @@ -119,11 +119,11 @@ object AuthSpec extends ZIOHttpSpec with HttpAppTestExtensions { val app = secureRoutes for { s1 <- app.runZIO(Request.get(URL(Path.root / "a")).copy(headers = successBasicHeader)) - s1Body <- s1.body.asString.debug("s1Body") + s1Body <- s1.body.asString s2 <- app.runZIO(Request.get(URL(Path.root / "b" / "1")).copy(headers = successBasicHeader)) - s2Body <- s2.body.asString.debug("s2Body") + s2Body <- s2.body.asString s3 <- app.runZIO(Request.get(URL(Path.root / "c" / "name")).copy(headers = successBasicHeader)) - s3Body <- s3.body.asString.debug("s3Body") + s3Body <- s3.body.asString resultStatus = s1.status == Status.Ok && s2.status == Status.Ok && s3.status == Status.Ok resultBody = s1Body == "user" && s2Body == "for id: 1: user" && s3Body == "for name: name: user" } yield assertTrue(resultStatus, resultBody) diff --git a/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/WebSpec.scala b/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/WebSpec.scala index 42257b1e4..2237a3bd8 100644 --- a/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/WebSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/internal/middlewares/WebSpec.scala @@ -249,8 +249,7 @@ object WebSpec extends ZIOHttpSpec with HttpAppTestExtensions { self => for { url <- ZIO.fromEither(URL.decode(url)) - response <- app.runZIO(Request.get(url = url)).debug("response") - _ <- ZIO.debug(response.headerOrFail(Header.Location)) + response <- app.runZIO(Request.get(url = url)) } yield assertTrue( extractStatus(response) == status, response.header(Header.Location) == location.map(l => Header.Location(URL.decode(l).toOption.get)), diff --git a/zio-http/jvm/src/test/scala/zio/http/internal/package.scala b/zio-http/jvm/src/test/scala/zio/http/internal/package.scala index af499564d..c0d6d0ee2 100644 --- a/zio-http/jvm/src/test/scala/zio/http/internal/package.scala +++ b/zio-http/jvm/src/test/scala/zio/http/internal/package.scala @@ -25,7 +25,7 @@ import zio.http.netty.client.NettyClientDriver package object internal { val testServerConfig: ZLayer[Any, Nothing, Server.Config] = - ZLayer.succeed(Server.Config.default.onAnyOpenPort) + ZLayer.succeed(Server.Config.default.onAnyOpenPort.logWarningOnFatalError(false)) val testNettyServerConfig: ZLayer[Any, Nothing, NettyConfig] = ZLayer.succeed( diff --git a/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala b/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala index ad4acef18..5b0e09786 100644 --- a/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/netty/NettyStreamBodySpec.scala @@ -4,12 +4,14 @@ import zio._ import zio.test.TestAspect.withLiveClock import zio.test.{Spec, TestEnvironment, assertTrue} -import zio.stream.{ZStream, ZStreamAspect} +import zio.stream.{ZPipeline, ZStream, ZStreamAspect} import zio.http.ZClient.Config import zio.http._ import zio.http.internal.HttpRunnableSpec +import zio.http.multipart.mixed.MultipartMixed import zio.http.netty.NettyConfig.LeakDetectionLevel +import zio.http.netty.NettyStreamBodySpec.app object NettyStreamBodySpec extends HttpRunnableSpec { @@ -101,6 +103,82 @@ object NettyStreamBodySpec extends HttpRunnableSpec { ) } }, + test("properly decodes body's boundary") { + def trackablePart(content: String): ZIO[Any, Nothing, (MultipartMixed.Part, Promise[Nothing, Boolean])] = { + zio.Promise.make[Nothing, Boolean].map { p => + MultipartMixed.Part( + Headers(Header.ContentType(MediaType.text.`plain`)), + ZStream(content) + .via(ZPipeline.utf8Encode) + .ensuring(p.succeed(true)), + ) -> + p + } + } + def trackableMultipartMixed( + b: Boundary, + )(partsContents: String*): ZIO[Any, Nothing, (MultipartMixed, Seq[Promise[Nothing, Boolean]])] = { + ZIO + .foreach(partsContents)(trackablePart) + .map { tps => + val (parts, promisises) = tps.unzip + val mpm = MultipartMixed.fromParts(ZStream.fromIterable(parts), b, 1) + (mpm, promisises) + } + } + + def serve(resp: Response): ZIO[Any, Throwable, RuntimeFlags] = { + val app = Routes(Method.GET / "it" -> handler(resp)) + for { + portPromise <- Promise.make[Throwable, Int] + _ <- Server + .install(app) + .intoPromise(portPromise) + .zipRight(ZIO.never) + .provide( + ZLayer.succeed(NettyConfig.defaultWithFastShutdown.leakDetection(LeakDetectionLevel.PARANOID)), + ZLayer.succeed(Server.Config.default.onAnyOpenPort), + Server.customized, + ) + .fork + port <- portPromise.await + } yield port + } + + for { + mpmAndPromises <- trackableMultipartMixed(Boundary("this_is_a_boundary"))( + "this is the boring part 1", + "and this is the boring part two", + ) + (mpm, promises) = mpmAndPromises + resp = Response(body = + Body.fromStreamChunked(mpm.source).contentType(MediaType.multipart.`mixed`, mpm.boundary), + ) + .addHeader(Header.ContentType(MediaType.multipart.`mixed`, Some(mpm.boundary))) + port <- serve(resp) + client <- ZIO.service[Client] + req = Request.get(s"http://localhost:$port/it") + actualResp <- client(req) + actualMpm <- actualResp.body.asMultipartMixed + partsResults <- actualMpm.parts.zipWithIndex.mapZIO { case (part, idx) => + val pr = promises(idx.toInt) + // todo: due to server side buffering can't really expect the promises to be uncompleted BEFORE pulling on the client side + part.toBody.asString <*> + pr.isDone + }.runCollect + } yield { + zio.test.assertTrue { + actualResp.headers(Header.ContentType) == resp.headers(Header.ContentType) && + actualResp.body.boundary == Some(mpm.boundary) && + actualMpm.boundary == mpm.boundary && + partsResults == Chunk( + // todo: due to server side buffering can't really expect the promises to be uncompleted BEFORE pulling on the client side + ("this is the boring part 1", true), + ("and this is the boring part two", true), + ) + } + } + }, ).provide( singleConnectionClient, Scope.default, diff --git a/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala b/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala index e8f841aa7..532ba26aa 100644 --- a/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/netty/client/NettyConnectionPoolSpec.scala @@ -195,7 +195,7 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec { }.provideSome[Client & Scope]( ZLayer(appKeepAliveEnabled.unit), DynamicServer.live, - ZLayer.succeed(Server.Config.default.idleTimeout(500.millis).onAnyOpenPort), + ZLayer.succeed(Server.Config.default.idleTimeout(500.millis).onAnyOpenPort.logWarningOnFatalError(false)), testNettyServerConfig, Server.customized, ) @@ withLiveClock @@ -211,7 +211,7 @@ object NettyConnectionPoolSpec extends HttpRunnableSpec { }.provideSome[Scope]( ZLayer(appKeepAliveEnabled.unit), DynamicServer.live, - ZLayer.succeed(Server.Config.default.idleTimeout(500.millis).onAnyOpenPort), + ZLayer.succeed(Server.Config.default.idleTimeout(500.millis).onAnyOpenPort.logWarningOnFatalError(false)), testNettyServerConfig, Server.customized, Client.live, diff --git a/zio-http/shared/src/main/scala/zio/http/Server.scala b/zio-http/shared/src/main/scala/zio/http/Server.scala index c1406afbb..2d5912552 100644 --- a/zio-http/shared/src/main/scala/zio/http/Server.scala +++ b/zio-http/shared/src/main/scala/zio/http/Server.scala @@ -446,18 +446,16 @@ object Server extends ServerPlatformSpecific { initialInstall <- Promise.make[Nothing, Unit] serverStarted <- Promise.make[Throwable, Int] _ <- - ( - initialInstall.await *> - driver.start.flatMap { result => - inFlightRequests.succeed(result.inFlightRequests) &> - serverStarted.succeed(result.port) - } - .catchAll(serverStarted.fail) - ) + (for { + _ <- initialInstall.await.interruptible + result <- driver.start + _ <- inFlightRequests.succeed(result.inFlightRequests) + _ <- serverStarted.succeed(result.port) + } yield ()) // In the case of failure of `Driver#.start` or interruption while we are waiting to be - // installed for the first time, we should should always fail the `serverStarted` - // promise to allow the finalizers to make progress. - .catchAllCause(cause => inFlightRequests.failCause(cause)) + // installed for the first time, we should always fail the `serverStarted` and 'inFlightRequests' + // promises to allow the finalizers to make progress. + .onError(c => inFlightRequests.refailCause(c) *> serverStarted.refailCause(c)) .forkScoped } yield ServerLive(driver, initialInstall, serverStarted) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala index bfcd0ebf6..9df735cec 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala @@ -18,6 +18,7 @@ package zio.http.codec import scala.annotation.tailrec import scala.collection.immutable.ListMap +import scala.collection.mutable import scala.language.implicitConversions import zio._ @@ -209,7 +210,7 @@ sealed trait PathCodec[A] { self => val segment = segments(j) j = j + 1 try { - stack.push(java.util.UUID.fromString(segment.toString)) + stack.push(java.util.UUID.fromString(segment)) } catch { case _: IllegalArgumentException => fail = s"Expected UUID path segment but found ${segment}" @@ -225,9 +226,9 @@ sealed trait PathCodec[A] { self => val segment = segments(j) j = j + 1 - if (segment == "true") { + if (segment.equalsIgnoreCase("true")) { stack.push(true) - } else if (segment == "false") { + } else if (segment.equalsIgnoreCase("false")) { stack.push(false) } else { fail = s"Expected boolean path segment but found ${segment}" @@ -263,6 +264,15 @@ sealed trait PathCodec[A] { self => case Right(value) => stack.push(value) } + + case SubSegmentOpts(ops) => + val error = decodeSubstring(segments(j), ops, stack) + if (error != null) { + fail = error + i = instructions.length + } else { + j += 1 + } } i = i + 1 @@ -278,6 +288,163 @@ sealed trait PathCodec[A] { self => } } + private def decodeSubstring( + value: String, + instructions: Array[Opt], + stack: java.util.Deque[Any], + ): String = { + import Opt._ + + var i = 0 + var j = 0 + val size = value.length + while (i < instructions.length) { + val opt = instructions(i) + opt match { + case Match(toMatch) => + val size0 = toMatch.length + if ((size - j) < size0) { + return "Expected \"" + toMatch + "\" in segment " + value + " but found end of segment" + } else if (value.startsWith(toMatch, j)) { + stack.push(()) + j += size0 + } else { + return "Expected \"" + toMatch + "\" in segment " + value + " but found: " + value.substring(j) + } + case Combine(combiner0) => + val combiner = combiner0.asInstanceOf[Combiner[Any, Any]] + val right = stack.pop() + val left = stack.pop() + stack.push(combiner.combine(left, right)) + case StringOpt => + // Here things get "interesting" (aka annoying). We don't have a way of knowing when a string ends, + // so we have to look ahead to the next operator and figure out where it begins + val end = indexOfNextCodec(value, instructions, i, j) + if (end == -1) { // If this wasn't the last codec, let the error handler of the next codec handle this + stack.push(value.substring(j)) + j = size + } else { + stack.push(value.substring(j, end)) + j = end + } + case IntOpt => + val isNegative = value(j) == '-' + if (isNegative) j += 1 + var end = j + while (end < size && value(end).isDigit) end += 1 + if (end == j) { + return "Expected integer path segment but found end of segment" + } else if (end - j > 10) { + return "Expected integer path segment but found: " + value.substring(j, end) + } else { + + try { + val int = Integer.parseInt(value, j, end, 10) + j = end + if (isNegative) stack.push(-int) else stack.push(int) + } catch { + case _: NumberFormatException => + return "Expected integer path segment but found: " + value.substring(j, end) + } + } + case LongOpt => + val isNegative = value(j) == '-' + if (isNegative) j += 1 + var end = j + while (end < size && value(end).isDigit) end += 1 + if (end == j) { + return "Expected long path segment but found end of segment" + } else if (end - j > 19) { + return "Expected long path segment but found: " + value.substring(j, end) + } else { + try { + val long = java.lang.Long.parseLong(value, j, end, 10) + j = end + if (isNegative) stack.push(-long) else stack.push(long) + } catch { + case _: NumberFormatException => return "Expected long path segment but found: " + value.substring(j, end) + } + } + case UUIDOpt => + if ((size - j) < 36) { + return "Remaining path segment " + value.substring(j) + " is too short to be a UUID" + } else { + val sub = value.substring(j, j + 36) + try { + stack.push(java.util.UUID.fromString(sub)) + } catch { + case _: IllegalArgumentException => return "Expected UUID path segment but found: " + sub + } + j += 36 + } + case BoolOpt => + if (value.regionMatches(true, j, "true", 0, 4)) { + stack.push(true) + j += 4 + } else if (value.regionMatches(true, j, "false", 0, 5)) { + stack.push(false) + j += 5 + } else { + return "Expected boolean path segment but found end of segment" + } + case TrailingOpt => + // TrailingOpt must be invalid, since it wants to extract a path, + // which is not possible in a sub part of a segment. + // The equivalent of trailing here is just StringOpt + throw new IllegalStateException("TrailingOpt is not allowed in a sub segment") + case _ => + throw new IllegalStateException("Unexpected instruction in substring decoder") + } + i += 1 + } + if (j != size) "Expected end of segment but found: " + value.substring(j) + else null + } + + private def indexOfNextCodec(value: String, instructions: Array[Opt], fromI: Int, idx: Int): Int = { + import Opt._ + + var nextOpt = null.asInstanceOf[Opt] + var j1 = fromI + 1 + + while ((nextOpt eq null) && j1 < instructions.length) { + instructions(j1) match { + case op @ (Match(_) | IntOpt | LongOpt | UUIDOpt | BoolOpt) => + nextOpt = op + case _ => + j1 += 1 + } + } + + nextOpt match { + case null => + -1 + case Match(toMatch) => + if (idx + toMatch.length > value.length) -1 + else if (toMatch.length == 1) value.indexOf(toMatch.charAt(0).toInt, idx) + else value.indexOf(toMatch, idx) + case IntOpt | LongOpt => + value.indexWhere(_.isDigit, idx) + case BoolOpt => + val t = value.regionMatches(true, idx, "true", 0, 4) + if (t) idx + 4 else if (value.regionMatches(true, idx, "false", 0, 5)) idx + 5 else -1 + case UUIDOpt => + val until = SegmentCodec.UUID.isUUIDUntil(value, idx) + if (until == -1) -1 else idx + until + case MatchAny(values) => + var end = -1 + val valuesIt = values.iterator + while (valuesIt.hasNext && end == -1) { + val value = valuesIt.next() + val index = value.indexOf(value, idx) + if (index != -1) end = index + } + end + case _ => + throw new IllegalStateException("Unexpected instruction in substring decoder: " + nextOpt) + } + } + /** * Returns the documentation for the path codec, if any. */ @@ -346,33 +513,47 @@ sealed trait PathCodec[A] { self => private var _optimize: Array[Opt] = null.asInstanceOf[Array[Opt]] private[http] def optimize: Array[Opt] = { - def loop(pattern: PathCodec[_]): Chunk[Opt] = + + def loopSegment(segment: SegmentCodec[_], fresh: Boolean)(implicit b: mutable.ArrayBuilder[Opt]): Unit = + segment match { + case SegmentCodec.Empty => b += Opt.Unit + case SegmentCodec.Literal(value) => b += Opt.Match(value) + case SegmentCodec.IntSeg(_) => b += Opt.IntOpt + case SegmentCodec.LongSeg(_) => b += Opt.LongOpt + case SegmentCodec.Text(_) => b += Opt.StringOpt + case SegmentCodec.UUID(_) => b += Opt.UUIDOpt + case SegmentCodec.BoolSeg(_) => b += Opt.BoolOpt + case SegmentCodec.Trailing => b += Opt.TrailingOpt + case SegmentCodec.Combined(left, right, combiner) => + val ab = if (fresh) mutable.ArrayBuilder.make[Opt] else b + loopSegment(left, fresh = false)(ab) + loopSegment(right, fresh = false)(ab) + ab += Opt.Combine(combiner) + if (fresh) b += Opt.SubSegmentOpts(ab.result().asInstanceOf[Array[Opt]]) + } + + def loop(pattern: PathCodec[_])(implicit b: mutable.ArrayBuilder[Opt]): Unit = pattern match { case PathCodec.Annotated(codec, _) => loop(codec) case PathCodec.Segment(segment) => - Chunk(segment.asInstanceOf[SegmentCodec[_]] match { - case SegmentCodec.Empty => Opt.Unit - case SegmentCodec.Literal(value) => Opt.Match(value) - case SegmentCodec.IntSeg(_) => Opt.IntOpt - case SegmentCodec.LongSeg(_) => Opt.LongOpt - case SegmentCodec.Text(_) => Opt.StringOpt - case SegmentCodec.UUID(_) => Opt.UUIDOpt - case SegmentCodec.BoolSeg(_) => Opt.BoolOpt - case SegmentCodec.Trailing => Opt.TrailingOpt - }) - - case f: Fallback[_] => - Chunk(Opt.MatchAny(fallbacks(f))) - + loopSegment(segment, fresh = true) + case f: Fallback[_] => + b += Opt.MatchAny(fallbacks(f)) case Concat(left, right, combiner) => - loop(left) ++ loop(right) ++ Chunk(Opt.Combine(combiner)) - - case TransformOrFail(api, f, _) => - loop(api) :+ Opt.MapOrFail(f.asInstanceOf[Any => Either[String, Any]]) + loop(left) + loop(right) + b += Opt.Combine(combiner) + case TransformOrFail(api, f, _) => + loop(api) + b += Opt.MapOrFail(f.asInstanceOf[Any => Either[String, Any]]) } - if (_optimize eq null) _optimize = loop(self).toArray + if (_optimize eq null) { + val b: mutable.ArrayBuilder[Opt] = mutable.ArrayBuilder.make[Opt] + loop(self)(b) + _optimize = b.result() + } _optimize } @@ -409,17 +590,15 @@ sealed trait PathCodec[A] { self => */ def render(prefix: String, suffix: String): String = { def loop(path: PathCodec[_]): String = path match { - case PathCodec.Annotated(codec, _) => + case PathCodec.Annotated(codec, _) => loop(codec) - case PathCodec.Concat(left, right, _) => + case PathCodec.Concat(left, right, _) => loop(left) + loop(right) - - case PathCodec.Segment(segment) => segment.render(prefix, suffix) - + case PathCodec.Segment(segment) => + segment.render(prefix, suffix) case PathCodec.TransformOrFail(api, _, _) => loop(api) - - case PathCodec.Fallback(left, _) => + case PathCodec.Fallback(left, _) => loop(left) } @@ -517,6 +696,8 @@ object PathCodec { implicit def path(value: String): PathCodec[Unit] = apply(value) + implicit def segment[A](codec: SegmentCodec[A]): PathCodec[A] = Segment(codec) + def string(name: String): PathCodec[String] = Segment(SegmentCodec.string(name)) def trailing: PathCodec[Path] = Segment(SegmentCodec.Trailing) @@ -574,6 +755,7 @@ object PathCodec { case object BoolOpt extends Opt case object TrailingOpt extends Opt case object Unit extends Opt + final case class SubSegmentOpts(ops: Array[Opt]) extends Opt final case class MapOrFail(f: Any => Either[String, Any]) extends Opt } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/SegmentCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/SegmentCodec.scala index 102d57add..d5e17bedd 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/SegmentCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/SegmentCodec.scala @@ -15,11 +15,15 @@ */ package zio.http.codec +import scala.annotation.implicitNotFound import scala.language.implicitConversions import zio.Chunk import zio.http.Path +import zio.http.codec.Combiner.WithOut +import zio.http.codec.PathCodec.MetaData +import zio.http.codec.SegmentCodec._ sealed trait SegmentCodec[A] { self => private var _hashCode: Int = 0 @@ -32,6 +36,12 @@ sealed trait SegmentCodec[A] { self => case _ => false } + final def example(name: String, example: A): PathCodec[A] = + PathCodec.segment(self).annotate(MetaData.Examples(Map(name -> example))) + + final def examples(examples: (String, A)*): PathCodec[A] = + PathCodec.segment(self).annotate(MetaData.Examples(examples.toMap)) + def format(value: A): Path override val hashCode: Int = { @@ -44,9 +54,22 @@ sealed trait SegmentCodec[A] { self => case _ => false } + final def ??(doc: Doc): PathCodec[A] = PathCodec.Segment(self).??(doc) + + final def ~[B]( + that: SegmentCodec[B], + )(implicit combiner: Combiner[A, B], combinable: Combinable[B, SegmentCodec[B]]): SegmentCodec[combiner.Out] = + combinable.combine(self, that) + + final def ~(that: String)(implicit combiner: Combiner[A, Unit]): SegmentCodec[combiner.Out] = + self.~(SegmentCodec.literal(that))(combiner, Combinable.combinableLiteral) + // Returns number of segments matched, or -1 if not matched: def matches(segments: Chunk[String], index: Int): Int + // Returns the last index of the subsegment matched, or -1 if not matched + def inSegmentUntil(segment: String, from: Int): Int + final def nonEmpty: Boolean = !isEmpty final def render: String = { @@ -54,17 +77,45 @@ sealed trait SegmentCodec[A] { self => _render } - final def render(prefix: String, suffix: String): String = - self.asInstanceOf[SegmentCodec[_]] match { - case _: SegmentCodec.Empty.type => s"" - case SegmentCodec.Literal(value) => s"/$value" - case SegmentCodec.IntSeg(name) => s"/$prefix$name$suffix" - case SegmentCodec.LongSeg(name) => s"/$prefix$name$suffix" - case SegmentCodec.Text(name) => s"/$prefix$name$suffix" - case SegmentCodec.BoolSeg(name) => s"/$prefix$name$suffix" - case SegmentCodec.UUID(name) => s"/$prefix$name$suffix" - case _: SegmentCodec.Trailing.type => s"/..." + final def render(prefix: String, suffix: String): String = { + val b = new StringBuilder + + def loop(s: SegmentCodec[_]): Unit = { + s match { + case _: SegmentCodec.Empty.type => () + case SegmentCodec.Literal(value) => + b.appendAll(value) + case SegmentCodec.IntSeg(name) => + b.appendAll(prefix) + b.appendAll(name) + b.appendAll(suffix) + case SegmentCodec.LongSeg(name) => + b.appendAll(prefix) + b.appendAll(name) + b.appendAll(suffix) + case SegmentCodec.Text(name) => + b.appendAll(prefix) + b.appendAll(name) + b.appendAll(suffix) + case SegmentCodec.BoolSeg(name) => + b.appendAll(prefix) + b.appendAll(name) + b.appendAll(suffix) + case SegmentCodec.UUID(name) => + b.appendAll(prefix) + b.appendAll(name) + b.appendAll(suffix) + case SegmentCodec.Combined(left, right, _) => + loop(left) + loop(right) + case _: SegmentCodec.Trailing.type => + b.appendAll("...") + } } + if (self ne SegmentCodec.Empty) b.append('/') + loop(self.asInstanceOf[SegmentCodec[_]]) + b.result() + } final def transform[A2](f: A => A2)(g: A2 => A): PathCodec[A2] = PathCodec.Segment(self).transform(f)(g) @@ -79,6 +130,179 @@ sealed trait SegmentCodec[A] { self => PathCodec.Segment(self).transformOrFailRight(f)(g) } object SegmentCodec { + + @implicitNotFound("Segments of type ${B} cannot be appended to a multi-value segment") + sealed trait Combinable[B, S <: SegmentCodec[B]] { + def combine[A](self: SegmentCodec[A], that: SegmentCodec[B])(implicit + combiner: Combiner[A, B], + ): SegmentCodec[combiner.Out] + } + private[codec] object Combinable { + + implicit val combinableString: Combinable[String, SegmentCodec[String]] = + new Combinable[String, SegmentCodec[String]] { + override def combine[A](self: SegmentCodec[A], that: SegmentCodec[String])(implicit + combiner: Combiner[A, String], + ): SegmentCodec[combiner.Out] = { + self match { + case SegmentCodec.Empty => that.asInstanceOf[SegmentCodec[combiner.Out]] + case SegmentCodec.Text(name) => + throw new IllegalArgumentException( + "Cannot combine two string segments. Their names are " + name + " and " + that + .asInstanceOf[SegmentCodec.Text] + .name, + ) + case c: SegmentCodec.Combined[_, _, _] => + val last = c.flattened.last + last match { + case text: SegmentCodec.Text => + throw new IllegalArgumentException( + "Cannot combine two string segments. Their names are" + text.name + " and " + that + .asInstanceOf[Text] + .name, + ) + case _ => + SegmentCodec.Combined(self, that, combiner.asInstanceOf[WithOut[A, String, combiner.Out]]) + } + case _ => + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, String, combiner.Out]]) + } + } + } + implicit val combinableInt: Combinable[Int, SegmentCodec[Int]] = + new Combinable[Int, SegmentCodec[Int]] { + override def combine[A](self: SegmentCodec[A], that: SegmentCodec[Int])(implicit + combiner: Combiner[A, Int], + ): SegmentCodec[combiner.Out] = { + self match { + case SegmentCodec.Empty => that.asInstanceOf[SegmentCodec[combiner.Out]] + case SegmentCodec.IntSeg(name) => + throw new IllegalArgumentException( + "Cannot combine two numeric segments. Their names are " + name + " and " + that + .asInstanceOf[SegmentCodec.IntSeg] + .name, + ) + case SegmentCodec.LongSeg(name) => + throw new IllegalArgumentException( + "Cannot combine two numeric segments. Their names are " + name + " and " + that + .asInstanceOf[SegmentCodec.IntSeg] + .name, + ) + case c: SegmentCodec.Combined[_, _, _] => + val last = c.flattened.last + if (last.isInstanceOf[SegmentCodec.IntSeg] || last.isInstanceOf[SegmentCodec.LongSeg]) { + val lastName = + last match { + case SegmentCodec.IntSeg(name) => name + case SegmentCodec.LongSeg(name) => name + case _ => "" + } + throw new IllegalArgumentException( + "Cannot combine two numeric segments. Their names are " + lastName + " and " + that + .asInstanceOf[SegmentCodec.IntSeg] + .name, + ) + } else { + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, Int, combiner.Out]]) + } + case _ => + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, Int, combiner.Out]]) + } + } + } + implicit val combinableLong: Combinable[Long, SegmentCodec[Long]] = + new Combinable[Long, SegmentCodec[Long]] { + override def combine[A](self: SegmentCodec[A], that: SegmentCodec[Long])(implicit + combiner: Combiner[A, Long], + ): SegmentCodec[combiner.Out] = { + self match { + case SegmentCodec.Empty => that.asInstanceOf[SegmentCodec[combiner.Out]] + case SegmentCodec.IntSeg(name) => + throw new IllegalArgumentException( + "Cannot combine two numeric segments. Their names are " + name + " and " + that + .asInstanceOf[SegmentCodec.LongSeg] + .name, + ) + case SegmentCodec.LongSeg(name) => + throw new IllegalArgumentException( + "Cannot combine two numeric segments. Their names are " + name + " and " + that + .asInstanceOf[SegmentCodec.LongSeg] + .name, + ) + case c: SegmentCodec.Combined[_, _, _] => + val last = c.flattened.last + if (last.isInstanceOf[SegmentCodec.IntSeg] || last.isInstanceOf[SegmentCodec.LongSeg]) { + val lastName = + last match { + case SegmentCodec.IntSeg(name) => name + case SegmentCodec.LongSeg(name) => name + case _ => "" + } + throw new IllegalArgumentException( + "Cannot combine two numeric segments. Their names are " + lastName + " and " + that + .asInstanceOf[SegmentCodec.LongSeg] + .name, + ) + } else { + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, Long, combiner.Out]]) + } + case _ => + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, Long, combiner.Out]]) + } + } + } + implicit val combinableBool: Combinable[Boolean, SegmentCodec[Boolean]] = + new Combinable[Boolean, SegmentCodec[Boolean]] { + override def combine[A](self: SegmentCodec[A], that: SegmentCodec[Boolean])(implicit + combiner: Combiner[A, Boolean], + ): SegmentCodec[combiner.Out] = { + self match { + case SegmentCodec.Empty => that.asInstanceOf[SegmentCodec[combiner.Out]] + case _ => + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, Boolean, combiner.Out]]) + } + } + } + implicit val combinableUUID: Combinable[UUID, SegmentCodec[UUID]] = + new Combinable[UUID, SegmentCodec[UUID]] { + override def combine[A](self: SegmentCodec[A], that: SegmentCodec[UUID])(implicit + combiner: Combiner[A, UUID], + ): SegmentCodec[combiner.Out] = { + self match { + case SegmentCodec.Empty => that.asInstanceOf[SegmentCodec[combiner.Out]] + case _ => + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, UUID, combiner.Out]]) + } + } + } + implicit val combinableLiteral: Combinable[Unit, SegmentCodec[Unit]] = + new Combinable[Unit, SegmentCodec[Unit]] { + override def combine[A](self: SegmentCodec[A], that: SegmentCodec[Unit])(implicit + combiner: Combiner[A, Unit], + ): SegmentCodec[combiner.Out] = { + self match { + case SegmentCodec.Empty => that.asInstanceOf[SegmentCodec[combiner.Out]] + case SegmentCodec.Literal(value) => + SegmentCodec + .Literal(value + that.asInstanceOf[SegmentCodec.Literal].value) + .asInstanceOf[SegmentCodec[combiner.Out]] + case SegmentCodec.Combined(l, r, c) if r.isInstanceOf[SegmentCodec.Literal] => + SegmentCodec + .Combined( + l.asInstanceOf[SegmentCodec[Any]], + SegmentCodec + .Literal(r.asInstanceOf[SegmentCodec.Literal].value + that.asInstanceOf[SegmentCodec.Literal].value) + .asInstanceOf[SegmentCodec[Any]], + c.asInstanceOf[Combiner.WithOut[Any, Any, Any]], + ) + .asInstanceOf[SegmentCodec[combiner.Out]] + case _ => + SegmentCodec.Combined(self, that, combiner.asInstanceOf[Combiner.WithOut[A, Unit, combiner.Out]]) + } + } + } + } + def bool(name: String): SegmentCodec[Boolean] = SegmentCodec.BoolSeg(name) val empty: SegmentCodec[Unit] = SegmentCodec.Empty @@ -101,6 +325,9 @@ object SegmentCodec { def format(unit: Unit): Path = Path(s"") def matches(segments: Chunk[String], index: Int): Int = 0 + + override def inSegmentUntil(segment: String, from: Int): Int = from + } private[http] final case class Literal(value: String) extends SegmentCodec[Unit] { @@ -112,7 +339,13 @@ object SegmentCodec { else if (value == segments(index)) 1 else -1 } + + override def inSegmentUntil(segment: String, from: Int): Int = + if (segment.startsWith(value, from)) from + value.length + else -1 + } + private[http] final case class BoolSeg(name: String) extends SegmentCodec[Boolean] { def format(value: Boolean): Path = Path(s"/$value") @@ -124,29 +357,41 @@ object SegmentCodec { if (segment == "true" || segment == "false") 1 else -1 } + + override def inSegmentUntil(segment: String, from: Int): Int = + if (segment.startsWith("true", from)) from + 4 + else if (segment.startsWith("false", from)) from + 5 + else -1 + } + private[http] final case class IntSeg(name: String) extends SegmentCodec[Int] { def format(value: Int): Path = Path(s"/$value") - def matches(segments: Chunk[String], index: Int): Int = { + def matches(segments: Chunk[String], index: Int): Int = if (index < 0 || index >= segments.length) -1 else { - val SegmentCodec = segments(index) - var i = 0 - var defined = true - if (SegmentCodec.length > 1 && SegmentCodec.charAt(0) == '-') i += 1 - while (i < SegmentCodec.length) { - if (!SegmentCodec.charAt(i).isDigit) { - defined = false - i = SegmentCodec.length - } - i += 1 - } - if (defined && i >= 1) 1 else -1 + val lastIndex = inSegmentUntil(segments(index), 0) + if (lastIndex == -1 || lastIndex + 1 != segments(index).length) -1 + else 1 } - } + + override def inSegmentUntil(segment: String, from: Int): Int = + if (segment.isEmpty || from >= segment.length) { + -1 + } else { + var i = from + val isNegative = segment.charAt(i) == '-' + // 10 digits is the maximum for an Int + val maxDigits = if (isNegative) 11 else 10 + if (segment.length > 1 && isNegative) i += 1 + while (i + 1 < segment.length && i - from < maxDigits && segment.charAt(i).isDigit) i += 1 + i + } + } + private[http] final case class LongSeg(name: String) extends SegmentCodec[Long] { def format(value: Long): Path = Path(s"/$value") @@ -154,20 +399,26 @@ object SegmentCodec { def matches(segments: Chunk[String], index: Int): Int = { if (index < 0 || index >= segments.length) -1 else { - val SegmentCodec = segments(index) - var i = 0 - var defined = true - if (SegmentCodec.length > 1 && SegmentCodec.charAt(0) == '-') i += 1 - while (i < SegmentCodec.length) { - if (!SegmentCodec.charAt(i).isDigit) { - defined = false - i = SegmentCodec.length - } - i += 1 - } - if (defined && i >= 1) 1 else -1 + val lastIndex = inSegmentUntil(segments(index), 0) + if (lastIndex == -1 || lastIndex + 1 != segments(index).length) -1 + else 1 } } + + override def inSegmentUntil(segment: String, from: Int): Int = { + if (segment.isEmpty || from >= segment.length) { + -1 + } else { + var i = from + val isNegative = segment.charAt(i) == '-' + // 19 digits is the maximum for a Long + val maxDigits = if (isNegative) 20 else 19 + if (segment.length > 1 && isNegative) i += 1 + while (i + 1 < segment.length && i - from < maxDigits && segment.charAt(i).isDigit) i += 1 + i + } + } + } private[http] final case class Text(name: String) extends SegmentCodec[String] { @@ -176,6 +427,10 @@ object SegmentCodec { def matches(segments: Chunk[String], index: Int): Int = if (index < 0 || index >= segments.length) -1 else 1 + + override def inSegmentUntil(segment: String, from: Int): Int = + segment.length + } private[http] final case class UUID(name: String) extends SegmentCodec[java.util.UUID] { @@ -184,34 +439,97 @@ object SegmentCodec { def matches(segments: Chunk[String], index: Int): Int = { if (index < 0 || index >= segments.length) -1 else { - val SegmentCodec = segments(index) + val lastIndex = inSegmentUntil(segments(index), 0) + if (lastIndex == -1 || lastIndex + 1 != segments(index).length) -1 + else 1 + } + } - var i = 0 - var defined = true - var group = 0 - var count = 0 - while (i < SegmentCodec.length) { - val char = SegmentCodec.charAt(i) - if ((char >= 48 && char <= 57) || (char >= 65 && char <= 70) || (char >= 97 && char <= 102)) - count += 1 - else if (char == 45) { - if ( - group > 4 || (group == 0 && count != 8) || ((group == 1 || group == 2 || group == 3) && count != 4) || (group == 4 && count != 12) - ) { - defined = false - i = SegmentCodec.length - } - count = 0 - group += 1 - } else { + override def inSegmentUntil(segment: String, from: Int): Int = + UUID.isUUIDUntil(segment, from) + } + + private[http] object UUID { + def isUUIDUntil(segment: String, from: Int): Int = { + var i = from + var defined = true + var group = 0 + var count = 0 + if (segment.length + from < 36) return -1 + while (i < 36 && defined) { + val char = segment.charAt(i) + if ((char >= 48 && char <= 57) || (char >= 65 && char <= 70) || (char >= 97 && char <= 102)) + count += 1 + else if (char == 45) { + if ( + group > 4 || (group == 0 && count != 8) || ((group == 1 || group == 2 || group == 3) && count != 4) || (group == 4 && count != 12) + ) { defined = false - i = SegmentCodec.length + i = segment.length } + count = 0 + group += 1 + } else { + defined = false + i = segment.length + } + i += 1 + } + if (defined && from + 36 == i) i else -1 + } + + } + + private[http] final case class Combined[A, B, C]( + left: SegmentCodec[A], + right: SegmentCodec[B], + combiner: Combiner.WithOut[A, B, C], + ) extends SegmentCodec[C] { self => + val flattened: List[SegmentCodec[_]] = { + def loop(s: SegmentCodec[_]): List[SegmentCodec[_]] = s match { + case SegmentCodec.Combined(l, r, _) => loop(l) ++ loop(r) + case _ => List(s) + } + loop(self) + } + override def format(value: C): Path = { + val (l, r) = combiner.separate(value) + val lf = left.format(l) + val rf = right.format(r) + lf ++ rf + } + + override def matches(segments: Chunk[String], index: Int): Int = + if (index < 0 || index >= segments.length) -1 + else { + val segment = segments(index) + val length = segment.length + var from = 0 + var i = 0 + while (i < flattened.length) { + if (from >= length) return -1 + val codec = flattened(i) + val s = codec.inSegmentUntil(segment, from) + if (s == -1) return -1 + from = s i += 1 } - if (defined && i == 36) 1 else -1 + 1 } + + override def inSegmentUntil(segment: String, from: Int): Int = { + var i = from + var j = 0 + while (j < flattened.length) { + val codec = flattened(j) + val s = codec.inSegmentUntil(segment, i) + if (s == -1) return -1 + i = s + j += 1 + } + i } + } case object Trailing extends SegmentCodec[Path] { self => @@ -219,5 +537,9 @@ object SegmentCodec { def matches(segments: Chunk[String], index: Int): Int = (segments.length - index).max(0) + + override def inSegmentUntil(segment: String, from: Int): Int = + segment.length } + } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala index 933c52e13..22cd74e24 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala @@ -1,6 +1,6 @@ package zio.http.endpoint.openapi -import scala.annotation.nowarn +import scala.annotation.{nowarn, tailrec} import zio._ import zio.json.ast.Json @@ -307,130 +307,142 @@ object JsonSchema { def fromSegmentCodec(codec: SegmentCodec[_]): JsonSchema = codec match { - case SegmentCodec.BoolSeg(_) => JsonSchema.Boolean - case SegmentCodec.IntSeg(_) => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) - case SegmentCodec.LongSeg(_) => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) - case SegmentCodec.Text(_) => JsonSchema.String() - case SegmentCodec.UUID(_) => JsonSchema.String(JsonSchema.StringFormat.UUID) - case SegmentCodec.Literal(_) => throw new IllegalArgumentException("Literal segment is not supported.") - case SegmentCodec.Empty => throw new IllegalArgumentException("Empty segment is not supported.") - case SegmentCodec.Trailing => throw new IllegalArgumentException("Trailing segment is not supported.") + case SegmentCodec.BoolSeg(_) => JsonSchema.Boolean + case SegmentCodec.IntSeg(_) => JsonSchema.Integer(JsonSchema.IntegerFormat.Int32) + case SegmentCodec.LongSeg(_) => JsonSchema.Integer(JsonSchema.IntegerFormat.Int64) + case SegmentCodec.Text(_) => JsonSchema.String() + case SegmentCodec.UUID(_) => JsonSchema.String(JsonSchema.StringFormat.UUID) + case SegmentCodec.Literal(_) => throw new IllegalArgumentException("Literal segment is not supported.") + case SegmentCodec.Empty => throw new IllegalArgumentException("Empty segment is not supported.") + case SegmentCodec.Trailing => throw new IllegalArgumentException("Trailing segment is not supported.") + case SegmentCodec.Combined(_, _, _) => throw new IllegalArgumentException("Combined segment is not supported.") } - def fromZSchemaMulti(schema: Schema[_], refType: SchemaStyle = SchemaStyle.Inline): JsonSchemas = { + def fromZSchemaMulti( + schema: Schema[_], + refType: SchemaStyle = SchemaStyle.Inline, + seen: Set[java.lang.String] = Set.empty, + ): JsonSchemas = { val ref = nominal(schema, refType) - schema match { - case enum0: Schema.Enum[_] if enum0.cases.forall(_.schema.isInstanceOf[CaseClass0[_]]) => - JsonSchemas(fromZSchema(enum0, SchemaStyle.Inline), ref, Map.empty) - case enum0: Schema.Enum[_] => - JsonSchemas( - fromZSchema(enum0, SchemaStyle.Inline), - ref, - enum0.cases - .filterNot(_.annotations.exists(_.isInstanceOf[transientCase])) - .flatMap { c => - val key = - nominal(c.schema, refType) - .orElse(nominal(c.schema, SchemaStyle.Compact)) + if (ref.exists(seen.contains)) { + JsonSchemas(RefSchema(ref.get), ref, Map.empty) + } else { + val seenWithCurrent = seen ++ ref + schema match { + case enum0: Schema.Enum[_] if enum0.cases.forall(_.schema.isInstanceOf[CaseClass0[_]]) => + JsonSchemas(fromZSchema(enum0, SchemaStyle.Inline), ref, Map.empty) + case enum0: Schema.Enum[_] => + JsonSchemas( + fromZSchema(enum0, SchemaStyle.Inline), + ref, + enum0.cases + .filterNot(_.annotations.exists(_.isInstanceOf[transientCase])) + .flatMap { c => + val key = + nominal(c.schema, refType) + .orElse(nominal(c.schema, SchemaStyle.Compact)) + val nested = fromZSchemaMulti( + c.schema, + refType, + seenWithCurrent, + ) + nested.children ++ key.map(_ -> nested.root) + } + .toMap, + ) + case record: Schema.Record[_] => + val children = record.fields + .filterNot(_.annotations.exists(_.isInstanceOf[transientField])) + .flatMap { field => val nested = fromZSchemaMulti( - c.schema, + field.schema, refType, + seenWithCurrent, ) - nested.children ++ key.map(_ -> nested.root) + nested.rootRef.map(k => nested.children + (k -> nested.root)).getOrElse(nested.children) } - .toMap, - ) - case record: Schema.Record[_] => - val children = record.fields - .filterNot(_.annotations.exists(_.isInstanceOf[transientField])) - .flatMap { field => - val nested = fromZSchemaMulti( - field.schema, - refType, - ) - nested.rootRef.map(k => nested.children + (k -> nested.root)).getOrElse(nested.children) + .toMap + JsonSchemas(fromZSchema(record, SchemaStyle.Inline), ref, children) + case collection: Schema.Collection[_, _] => + collection match { + case Schema.Sequence(elementSchema, _, _, _, _) => + arraySchemaMulti(refType, ref, elementSchema, seenWithCurrent) + case Schema.Map(_, valueSchema, _) => + val nested = fromZSchemaMulti(valueSchema, refType, seenWithCurrent) + if (valueSchema.isInstanceOf[Schema.Primitive[_]]) { + JsonSchemas( + JsonSchema.Object( + Map.empty, + Right(nested.root), + Chunk.empty, + ), + ref, + nested.children, + ) + } else { + JsonSchemas( + JsonSchema.Object( + Map.empty, + Right(nested.root), + Chunk.empty, + ), + ref, + nested.children + (nested.rootRef.get -> nested.root), + ) + } + case Schema.Set(elementSchema, _) => + arraySchemaMulti(refType, ref, elementSchema, seenWithCurrent) } - .toMap - JsonSchemas(fromZSchema(record, SchemaStyle.Inline), ref, children) - case collection: Schema.Collection[_, _] => - collection match { - case Schema.Sequence(elementSchema, _, _, _, _) => - arraySchemaMulti(refType, ref, elementSchema) - case Schema.Map(_, valueSchema, _) => - val nested = fromZSchemaMulti(valueSchema, refType) - if (valueSchema.isInstanceOf[Schema.Primitive[_]]) { - JsonSchemas( - JsonSchema.Object( - Map.empty, - Right(nested.root), - Chunk.empty, - ), - ref, - nested.children, + case Schema.Transform(schema, _, _, _, _) => + fromZSchemaMulti(schema, refType, seenWithCurrent) + case Schema.Primitive(_, _) => + JsonSchemas(fromZSchema(schema, SchemaStyle.Inline), ref, Map.empty) + case Schema.Optional(schema, _) => + fromZSchemaMulti(schema, refType, seenWithCurrent) + case Schema.Fail(_, _) => + throw new IllegalArgumentException("Fail schema is not supported.") + case Schema.Tuple2(left, right, _) => + val leftSchema = fromZSchemaMulti(left, refType, seenWithCurrent) + val rightSchema = fromZSchemaMulti(right, refType, seenWithCurrent) + JsonSchemas( + AllOfSchema(Chunk(leftSchema.root, rightSchema.root)), + ref, + leftSchema.children ++ rightSchema.children, + ) + case Schema.Either(left, right, _) => + val leftSchema = fromZSchemaMulti(left, refType, seenWithCurrent) + val rightSchema = fromZSchemaMulti(right, refType, seenWithCurrent) + JsonSchemas( + OneOfSchema(Chunk(leftSchema.root, rightSchema.root)), + ref, + leftSchema.children ++ rightSchema.children, + ) + case Schema.Fallback(left, right, fullDecode, _) => + val leftSchema = fromZSchemaMulti(left, refType, seenWithCurrent) + val rightSchema = fromZSchemaMulti(right, refType, seenWithCurrent) + val candidates = + if (fullDecode) + Chunk( + AllOfSchema(Chunk(leftSchema.root, rightSchema.root)), + leftSchema.root, + rightSchema.root, ) - } else { - JsonSchemas( - JsonSchema.Object( - Map.empty, - Right(nested.root), - Chunk.empty, - ), - ref, - nested.children + (nested.rootRef.get -> nested.root), + else + Chunk( + leftSchema.root, + rightSchema.root, ) - } - case Schema.Set(elementSchema, _) => - arraySchemaMulti(refType, ref, elementSchema) - } - case Schema.Transform(schema, _, _, _, _) => - fromZSchemaMulti(schema, refType) - case Schema.Primitive(_, _) => - JsonSchemas(fromZSchema(schema, SchemaStyle.Inline), ref, Map.empty) - case Schema.Optional(schema, _) => - fromZSchemaMulti(schema, refType) - case Schema.Fail(_, _) => - throw new IllegalArgumentException("Fail schema is not supported.") - case Schema.Tuple2(left, right, _) => - val leftSchema = fromZSchemaMulti(left, refType) - val rightSchema = fromZSchemaMulti(right, refType) - JsonSchemas( - AllOfSchema(Chunk(leftSchema.root, rightSchema.root)), - ref, - leftSchema.children ++ rightSchema.children, - ) - case Schema.Either(left, right, _) => - val leftSchema = fromZSchemaMulti(left, refType) - val rightSchema = fromZSchemaMulti(right, refType) - JsonSchemas( - OneOfSchema(Chunk(leftSchema.root, rightSchema.root)), - ref, - leftSchema.children ++ rightSchema.children, - ) - case Schema.Fallback(left, right, fullDecode, _) => - val leftSchema = fromZSchemaMulti(left, refType) - val rightSchema = fromZSchemaMulti(right, refType) - val candidates = - if (fullDecode) - Chunk( - AllOfSchema(Chunk(leftSchema.root, rightSchema.root)), - leftSchema.root, - rightSchema.root, - ) - else - Chunk( - leftSchema.root, - rightSchema.root, - ) - JsonSchemas( - OneOfSchema(candidates), - ref, - leftSchema.children ++ rightSchema.children, - ) - case Schema.Lazy(schema0) => - fromZSchemaMulti(schema0(), refType) - case Schema.Dynamic(_) => - JsonSchemas(AnyJson, None, Map.empty) + JsonSchemas( + OneOfSchema(candidates), + ref, + leftSchema.children ++ rightSchema.children, + ) + case Schema.Lazy(schema0) => + fromZSchemaMulti(schema0(), refType, seen) + case Schema.Dynamic(_) => + JsonSchemas(AnyJson, None, Map.empty) + } } } @@ -438,8 +450,9 @@ object JsonSchema { refType: SchemaStyle, ref: Option[java.lang.String], elementSchema: Schema[_], + seen: Set[java.lang.String], ): JsonSchemas = { - val nested = fromZSchemaMulti(elementSchema, refType) + val nested = fromZSchemaMulti(elementSchema, refType, seen) if (elementSchema.isInstanceOf[Schema.Primitive[_]]) { JsonSchemas( JsonSchema.ArrayType(Some(nested.root)), @@ -450,7 +463,7 @@ object JsonSchema { JsonSchemas( JsonSchema.ArrayType(Some(nested.root)), ref, - nested.children ++ (nested.rootRef.map(_ -> nested.root)), + nested.children ++ nested.rootRef.map(_ -> nested.root), ) } } @@ -634,6 +647,7 @@ object JsonSchema { schema.annotations.collectFirst { case fieldDefaultValue(value) => value } .map(toJsonAst(schema.schema, _)) + @tailrec private def nominal(schema: Schema[_], referenceType: SchemaStyle = SchemaStyle.Reference): Option[java.lang.String] = schema match { case enumSchema: Schema.Enum[_] => refForTypeId(enumSchema.id, referenceType) diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 3915605cf..28e57d278 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -705,14 +705,15 @@ object OpenAPIGen { def segmentToJson(codec: SegmentCodec[_], value: Any): Json = { codec match { - case SegmentCodec.Empty => throw new Exception("Empty segment not allowed") - case SegmentCodec.Literal(_) => throw new Exception("Literal segment not allowed") - case SegmentCodec.BoolSeg(_) => Json.Bool(value.asInstanceOf[Boolean]) - case SegmentCodec.IntSeg(_) => Json.Num(value.asInstanceOf[Int]) - case SegmentCodec.LongSeg(_) => Json.Num(value.asInstanceOf[Long]) - case SegmentCodec.Text(_) => Json.Str(value.asInstanceOf[String]) - case SegmentCodec.UUID(_) => Json.Str(value.asInstanceOf[UUID].toString) - case SegmentCodec.Trailing => throw new Exception("Trailing segment not allowed") + case SegmentCodec.Empty => throw new Exception("Empty segment not allowed") + case SegmentCodec.Literal(_) => throw new Exception("Literal segment not allowed") + case SegmentCodec.BoolSeg(_) => Json.Bool(value.asInstanceOf[Boolean]) + case SegmentCodec.IntSeg(_) => Json.Num(value.asInstanceOf[Int]) + case SegmentCodec.LongSeg(_) => Json.Num(value.asInstanceOf[Long]) + case SegmentCodec.Text(_) => Json.Str(value.asInstanceOf[String]) + case SegmentCodec.UUID(_) => Json.Str(value.asInstanceOf[UUID].toString) + case SegmentCodec.Trailing => throw new Exception("Trailing segment not allowed") + case SegmentCodec.Combined(_, _, _) => throw new Exception("Combined segment not allowed") } } diff --git a/zio-http/shared/src/main/scala/zio/http/package.scala b/zio-http/shared/src/main/scala/zio/http/package.scala index 6ea5db2be..80d516361 100644 --- a/zio-http/shared/src/main/scala/zio/http/package.scala +++ b/zio-http/shared/src/main/scala/zio/http/package.scala @@ -18,7 +18,7 @@ package zio import java.util.UUID -import zio.http.codec.PathCodec +import zio.http.codec.{PathCodec, SegmentCodec} package object http extends UrlInterpolator with MdInterpolator { @@ -36,12 +36,12 @@ package object http extends UrlInterpolator with MdInterpolator { def withContext[C](fn: => C)(implicit c: WithContext[C]): ZIO[c.Env, c.Err, c.Out] = c.toZIO(fn) - def boolean(name: String): PathCodec[Boolean] = PathCodec.bool(name) - def int(name: String): PathCodec[Int] = PathCodec.int(name) - def long(name: String): PathCodec[Long] = PathCodec.long(name) - def string(name: String): PathCodec[String] = PathCodec.string(name) - val trailing: PathCodec[Path] = PathCodec.trailing - def uuid(name: String): PathCodec[UUID] = PathCodec.uuid(name) + def boolean(name: String): SegmentCodec[Boolean] = SegmentCodec.bool(name) + def int(name: String): SegmentCodec[Int] = SegmentCodec.int(name) + def long(name: String): SegmentCodec[Long] = SegmentCodec.long(name) + def string(name: String): SegmentCodec[String] = SegmentCodec.string(name) + val trailing: SegmentCodec[Path] = SegmentCodec.trailing + def uuid(name: String): SegmentCodec[UUID] = SegmentCodec.uuid(name) def anyOf(name: String, names: String*): PathCodec[Unit] = if (names.isEmpty) PathCodec.literal(name) else names.foldLeft(PathCodec.literal(name))((acc, n) => acc.orElse(PathCodec.literal(n))) diff --git a/zio-http/shared/src/test/scala/zio/http/codec/SegmentCodecSpec.scala b/zio-http/shared/src/test/scala/zio/http/codec/SegmentCodecSpec.scala new file mode 100644 index 000000000..09f7d3e1a --- /dev/null +++ b/zio-http/shared/src/test/scala/zio/http/codec/SegmentCodecSpec.scala @@ -0,0 +1,95 @@ +package zio.http.codec + +import scala.util._ + +import zio._ +import zio.test._ + +import zio.http._ +import zio.http.codec.SegmentCodec.literal + +object SegmentCodecSpec extends ZIOSpecDefault { + override def spec: Spec[TestEnvironment with Scope, Any] = suite("SegmentCodec")( + test("combining literals is simplified to a single literal") { + val combineLitLit = Try( + "prefix" ~ "suffix", + ) + + val combineIntLitLit = Try( + int("anInt") ~ "prefix" ~ "suffix", + ) + + val expectedLit: Try[SegmentCodec[Unit]] = Success(SegmentCodec.Literal("prefixsuffix")) + val expectedIntLit: Try[SegmentCodec[Int]] = Success( + SegmentCodec.Combined( + SegmentCodec.IntSeg("anInt"), + SegmentCodec.Literal("prefixsuffix"), + Combiner.combine[Int, Unit].asInstanceOf[Combiner.WithOut[Int, Unit, Int]], + ), + ) + assertTrue( + combineLitLit == expectedLit, + combineIntLitLit == expectedIntLit, + ) + }, + test("Can't combine two string extracting segments") { + val combineStrStr = Try( + string("aString") ~ string("anotherString"), + ) + val expectedErrorMsg = "Cannot combine two string segments. Their names are aString and anotherString" + assertTrue(combineStrStr.failed.toOption.map(_.getMessage).contains(expectedErrorMsg)) + }, + test("Can't combine two int extracting segments") { + val combineIntInt = Try( + int("anInt") ~ int("anotherInt"), + ) + val combineUUIDIntInt = Try( + uuid("aUUID") ~ int("anInt") ~ int("anotherInt"), + ) + val expectedErrorMsg = "Cannot combine two numeric segments. Their names are anInt and anotherInt" + assertTrue( + combineIntInt.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + combineUUIDIntInt.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + ) + }, + test("Can't combine two long extracting segments") { + val combineLongLong = Try( + long("aLong") ~ long("anotherLong"), + ) + val uuidLongLong = Try( + uuid("aUUID") ~ long("aLong") ~ long("anotherLong"), + ) + val expectedErrorMsg = "Cannot combine two numeric segments. Their names are aLong and anotherLong" + assertTrue( + combineLongLong.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + uuidLongLong.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + ) + }, + test("Can't combine an int and a long extracting segment") { + val combineIntLong = Try( + int("anInt") ~ long("aLong"), + ) + val uuidIntLong = Try( + uuid("aUUID") ~ int("anInt") ~ long("aLong"), + ) + val expectedErrorMsg = "Cannot combine two numeric segments. Their names are anInt and aLong" + assertTrue( + combineIntLong.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + uuidIntLong.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + ) + }, + test("Can't combine a long and an int extracting segment") { + val combineLongInt = Try( + long("aLong") ~ int("anInt"), + ) + val uuidLongInt = Try( + uuid("aUUID") ~ long("aLong") ~ int("anInt"), + ) + val expectedErrorMsg = "Cannot combine two numeric segments. Their names are aLong and anInt" + assertTrue( + combineLongInt.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + uuidLongInt.failed.toOption.map(_.getMessage).contains(expectedErrorMsg), + ) + }, + ) +}