diff --git a/zio-http/src/main/scala/zio/http/Body.scala b/zio-http/src/main/scala/zio/http/Body.scala index b302e00532..aa93a9d072 100644 --- a/zio-http/src/main/scala/zio/http/Body.scala +++ b/zio-http/src/main/scala/zio/http/Body.scala @@ -58,9 +58,29 @@ trait Body { self => def asMultipartForm(implicit trace: Trace): Task[Form] = for { bytes <- asChunk - form <- Form.fromMultipartBytes(bytes, Charsets.Http).mapError(_.asException) + form <- Form.fromMultipartBytes(bytes, Charsets.Http) } yield form + /** + * Returns an effect that decodes the streaming body as a multipart form. + * + * The result is a stream of FormData objects, where each FormData may be a + * StreamingBinary or a Text object. The StreamingBinary object contains a + * stream of bytes, which has to be consumed asynchronously by the user to get + * the next FormData from the stream. + */ + def asMultipartFormStream(implicit trace: Trace): Task[StreamingForm] = + boundary match { + case Some(boundary) => + ZIO.succeed( + StreamingForm(asStream, Boundary(boundary.toString), Charsets.Http), + ) + case None => + ZIO.fail( + new IllegalStateException("Cannot decode body as streaming multipart/form-data without a known boundary"), + ) + } + /** * Returns a stream that contains the bytes of the body. This method is safe * to use with large bodies, because the elements of the returned stream are @@ -138,7 +158,7 @@ object Body { case Some(value) => form.encodeAsMultipartBytes(charset, value) case None => form.encodeAsMultipartBytes(charset) } - ChunkBody(bytes).withContentType(MediaType.multipart.`form-data`, Some(boundary)) + StreamBody(bytes, Some(MediaType.multipart.`form-data`), Some(boundary)) } /** diff --git a/zio-http/src/main/scala/zio/http/forms/Form.scala b/zio-http/src/main/scala/zio/http/forms/Form.scala index 53152fbc6e..7af64e8926 100644 --- a/zio-http/src/main/scala/zio/http/forms/Form.scala +++ b/zio-http/src/main/scala/zio/http/forms/Form.scala @@ -40,6 +40,20 @@ final case class Form(formData: Chunk[FormData]) { def append(field: FormData): Form = Form(formData :+ field) + /** + * Runs all streaming form data and stores them in memory, returning a Form + * that has no streaming parts + */ + def collectAll: ZIO[Any, Throwable, Form] = + ZIO + .foreach(formData) { + case streamingBinary: StreamingBinary => + streamingBinary.collect + case other => + ZIO.succeed(other) + } + .map(Form(_)) + def get(name: String): Option[FormData] = formData.find(_.name == name) def encodeAsURLEncoded(charset: Charset = StandardCharsets.UTF_8): String = { @@ -61,64 +75,82 @@ final case class Form(formData: Chunk[FormData]) { def encodeAsMultipartBytes( charset: Charset = StandardCharsets.UTF_8, - rng: () => String = () => new SecureRandom().nextLong().toString(), - ): (CharSequence, Chunk[Byte]) = + rng: () => String = () => new SecureRandom().nextLong().toString, + ): (CharSequence, ZStream[Any, Nothing, Byte]) = encodeAsMultipartBytes(charset, Boundary.generate(rng)) def encodeAsMultipartBytes( charset: Charset, boundary: Boundary, - ): (CharSequence, Chunk[Byte]) = { + ): (CharSequence, ZStream[Any, Nothing, Byte]) = { val encapsulatingBoundary = EncapsulatingBoundary(boundary) val closingBoundary = ClosingBoundary(boundary) - val ast = formData.flatMap { + val astStreams = formData.map { case fd @ Simple(name, value) => - Chunk( - encapsulatingBoundary, - EoL, - Header.contentDisposition(name), - EoL, - Header.contentType(fd.contentType), - EoL, - EoL, - Content(Chunk.fromArray(value.getBytes(charset))), - EoL, + ZStream.fromChunk( + Chunk( + encapsulatingBoundary, + EoL, + Header.contentDisposition(name), + EoL, + Header.contentType(fd.contentType), + EoL, + EoL, + Content(Chunk.fromArray(value.getBytes(charset))), + EoL, + ), ) case Text(name, value, contentType, filename) => - Chunk( - encapsulatingBoundary, - EoL, - Header.contentDisposition(name, filename), - EoL, - Header.contentType(contentType), - EoL, - EoL, - Content(Chunk.fromArray(value.getBytes(charset))), - EoL, + ZStream.fromChunk( + Chunk( + encapsulatingBoundary, + EoL, + Header.contentDisposition(name, filename), + EoL, + Header.contentType(contentType), + EoL, + EoL, + Content(Chunk.fromArray(value.getBytes(charset))), + EoL, + ), ) case Binary(name, data, contentType, transferEncoding, filename) => val xferEncoding = transferEncoding.map(enc => Chunk(Header.contentTransferEncoding(enc), EoL)).getOrElse(Chunk.empty) - Chunk( - encapsulatingBoundary, - EoL, - Header.contentDisposition(name, filename), - EoL, - Header.contentType(contentType), - EoL, - ) ++ xferEncoding ++ + ZStream.fromChunk( + Chunk( + encapsulatingBoundary, + EoL, + Header.contentDisposition(name, filename), + EoL, + Header.contentType(contentType), + EoL, + ) ++ xferEncoding ++ Chunk(EoL, Content(data), EoL), + ) + + case StreamingBinary(name, contentType, transferEncoding, filename, data) => + val xferEncoding = + transferEncoding.map(enc => Chunk(Header.contentTransferEncoding(enc), EoL)).getOrElse(Chunk.empty) + + ZStream.fromChunk( Chunk( + encapsulatingBoundary, EoL, - Content(data), + Header.contentDisposition(name, filename), EoL, - ) - } ++ Chunk(closingBoundary, EoL) + Header.contentType(contentType), + EoL, + ) ++ xferEncoding :+ EoL, + ) ++ data.chunks.map(Content(_)) ++ ZStream(EoL) + } - boundary.id -> ast.flatMap(_.bytes) + val stream = ZStream.fromChunk(astStreams).flatten ++ ZStream.fromChunk(Chunk(closingBoundary, EoL)) + + boundary.id -> stream.map(_.bytes).flattenChunks } } @@ -133,35 +165,13 @@ object Form { def fromMultipartBytes( bytes: Chunk[Byte], charset: Charset = StandardCharsets.UTF_8, - ): ZIO[Any, FormDecodingError, Form] = { - def process(boundary: Boundary) = ZStream - .fromChunk(bytes) - .mapAccum(FormState.fromBoundary(boundary)) { (state, byte) => - state match { - case BoundaryClosed(tree) => (FormState.fromBoundary(boundary), tree) - case BoundaryEncapsulated(tree) => (FormState.fromBoundary(boundary, Some(byte)), tree) - case buffer: FormStateBuffer => - val state = buffer.append(byte) - state match { - case BoundaryClosed(prevContent) => (state, prevContent) - case _ => (state, Chunk.empty[FormAST]) - } - } - } - .collectZIO { - case chunk if chunk.nonEmpty => - FormData.fromFormAST(chunk, charset) - } - .runCollect - .map(apply) - + ): ZIO[Any, Throwable, Form] = for { boundary <- ZIO .fromOption(Boundary.fromContent(bytes, charset)) - .mapError(_ => FormDecodingError.BoundaryNotFoundInContent) - form <- process(boundary) + .orElseFail(FormDecodingError.BoundaryNotFoundInContent.asException) + form <- StreamingForm(ZStream.fromChunk(bytes), boundary, charset).collectAll } yield form - } def fromURLEncoded(encoded: String, encoding: Charset): ZIO[Any, FormDecodingError, Form] = { val fields = ZIO.foreach(encoded.split("&")) { pair => diff --git a/zio-http/src/main/scala/zio/http/forms/FormAST.scala b/zio-http/src/main/scala/zio/http/forms/FormAST.scala index 09eed266c8..fdd562d706 100644 --- a/zio-http/src/main/scala/zio/http/forms/FormAST.scala +++ b/zio-http/src/main/scala/zio/http/forms/FormAST.scala @@ -23,7 +23,14 @@ import zio._ import zio.http.model.Header.ContentTransferEncoding import zio.http.model._ -private[forms] sealed trait FormAST { def bytes: Chunk[Byte] } +private[forms] sealed trait FormAST { + def bytes: Chunk[Byte] + + def isContent: Boolean = this match { + case FormAST.Content(_) => true + case _ => false + } +} private[forms] object FormAST { diff --git a/zio-http/src/main/scala/zio/http/forms/FormData.scala b/zio-http/src/main/scala/zio/http/forms/FormData.scala index 729a8c0bb8..3d593adfcf 100644 --- a/zio-http/src/main/scala/zio/http/forms/FormData.scala +++ b/zio-http/src/main/scala/zio/http/forms/FormData.scala @@ -20,6 +20,8 @@ import java.nio.charset._ import zio._ +import zio.stream.{Take, ZStream} + import zio.http.forms.FormAST._ import zio.http.forms.FormDecodingError._ import zio.http.model.Header.ContentTransferEncoding @@ -66,6 +68,20 @@ object FormData { filename: Option[String] = None, ) extends FormData + final case class StreamingBinary( + name: String, + contentType: MediaType, + transferEncoding: Option[ContentTransferEncoding] = None, + filename: Option[String] = None, + data: ZStream[Any, Nothing, Byte], + ) extends FormData { + def collect: ZIO[Any, Nothing, Binary] = { + data.runCollect.map { bytes => + Binary(name, bytes, contentType, transferEncoding, filename) + } + } + } + final case class Text( name: String, value: String, @@ -96,9 +112,9 @@ object FormData { } for { - disposition <- ZIO.fromOption(extract._1).mapError(_ => FormDataMissingContentDisposition) - name <- ZIO.fromOption(extract._1.flatMap(_.fields.get("name"))).mapError(_ => ContentDispositionMissingName) - charset <- ZIO + disposition <- ZIO.fromOption(extract._1).orElseFail(FormDataMissingContentDisposition) + name <- ZIO.fromOption(extract._1.flatMap(_.fields.get("name"))).orElseFail(ContentDispositionMissingName) + charset <- ZIO .attempt(extract._2.flatMap(x => x.fields.get("charset").map(Charset.forName)).getOrElse(defaultCharset)) .mapError(e => InvalidCharset(e.getMessage)) contentParts = extract._4.tail // Skip the first empty line @@ -115,6 +131,45 @@ object FormData { else Binary(name, content, contentType, transferEncoding, disposition.fields.get("filename")) } + private[http] def getContentType(ast: Chunk[FormAST]): MediaType = + ast.collectFirst { + case header: Header if header.name == "Content-Type" => + MediaType.forContentType(header.preposition) + }.flatten.getOrElse(MediaType.text.plain) + + private[http] def incomingStreamingBinary( + ast: Chunk[FormAST], + queue: Queue[Take[Nothing, Byte]], + ): ZIO[Any, FormDecodingError, FormData] = { + val extract = + ast.foldLeft((Option.empty[Header], Option.empty[Header], Option.empty[Header])) { + case (accum, header: Header) if header.name == "Content-Disposition" => + (Some(header), accum._2, accum._3) + case (accum, header: Header) if header.name == "Content-Type" => + (accum._1, Some(header), accum._3) + case (accum, header: Header) if header.name == "Content-Transfer-Encoding" => + (accum._1, accum._2, Some(header)) + case (accum, _) => accum + } + + for { + disposition <- ZIO.fromOption(extract._1).orElseFail(FormDataMissingContentDisposition) + name <- ZIO.fromOption(extract._1.flatMap(_.fields.get("name"))).orElseFail(ContentDispositionMissingName) + contentType = extract._2 + .flatMap(x => MediaType.forContentType(x.preposition)) + .getOrElse(MediaType.text.plain) + transferEncoding = extract._3 + .flatMap(x => ContentTransferEncoding.parse(x.preposition).toOption) + + } yield StreamingBinary( + name, + contentType, + transferEncoding, + disposition.fields.get("filename"), + ZStream.fromQueue(queue).flattenTake, + ) + } + def textField(name: String, value: String, mediaType: MediaType = MediaType.text.plain): FormData = Text(name, value, mediaType, None) @@ -127,4 +182,12 @@ object FormData { transferEncoding: Option[ContentTransferEncoding] = None, filename: Option[String] = None, ): FormData = Binary(name, data, mediaType, transferEncoding, filename) + + def streamingBinaryField( + name: String, + data: ZStream[Any, Nothing, Byte], + mediaType: MediaType, + transferEncoding: Option[ContentTransferEncoding] = None, + filename: Option[String] = None, + ): FormData = StreamingBinary(name, mediaType, transferEncoding, filename, data) } diff --git a/zio-http/src/main/scala/zio/http/forms/FormState.scala b/zio-http/src/main/scala/zio/http/forms/FormState.scala index c6fd3aab18..100a9a0002 100644 --- a/zio-http/src/main/scala/zio/http/forms/FormState.scala +++ b/zio-http/src/main/scala/zio/http/forms/FormState.scala @@ -30,6 +30,7 @@ private[forms] object FormState { buffer: Chunk[Byte], lastByte: Option[Byte], boundary: Boundary, + dropContents: Boolean, ) extends FormState { self => def append(byte: Byte): FormState = { @@ -42,7 +43,7 @@ private[forms] object FormState { def flush(ast: FormAST): FormStateBuffer = self.copy( - tree = tree :+ ast, + tree = if (ast.isContent && dropContents) tree else tree :+ ast, buffer = Chunk.empty, lastByte = None, phase = phase0, @@ -76,6 +77,8 @@ private[forms] object FormState { } } + + def startIgnoringContents: FormStateBuffer = self.copy(dropContents = true) } final case class BoundaryEncapsulated(buffer: Chunk[FormAST]) extends FormState @@ -83,7 +86,7 @@ private[forms] object FormState { final case class BoundaryClosed(buffer: Chunk[FormAST]) extends FormState def fromBoundary(boundary: Boundary, lastByte: Option[Byte] = None): FormState = - FormStateBuffer(Chunk.empty, Phase.Part1, Chunk.empty, lastByte, boundary) + FormStateBuffer(Chunk.empty, Phase.Part1, Chunk.empty, lastByte, boundary, dropContents = false) sealed trait Phase diff --git a/zio-http/src/main/scala/zio/http/forms/StreamingForm.scala b/zio-http/src/main/scala/zio/http/forms/StreamingForm.scala new file mode 100644 index 0000000000..8a5c707f32 --- /dev/null +++ b/zio-http/src/main/scala/zio/http/forms/StreamingForm.scala @@ -0,0 +1,153 @@ +package zio.http.forms + +import java.nio.charset.Charset + +import scala.collection.immutable + +import zio.{Chunk, Queue, ZIO} + +import zio.stream.{Take, ZStream} + +final case class StreamingForm(source: ZStream[Any, Throwable, Byte], boundary: Boundary, charset: Charset) { + + /** + * Runs the streaming form and collects all parts in memory, returning a Form + */ + def collectAll: ZIO[Any, Throwable, Form] = + data + .mapZIOPar(1) { + case sb: FormData.StreamingBinary => + sb.collect + case other: FormData => + ZIO.succeed(other) + } + .runCollect + .map { formData => + Form(formData) + } + + def data: ZStream[Any, Throwable, FormData] = + source + .mapAccumZIO(initialState) { (state, byte) => + state.formState match { + case formState: FormState.FormStateBuffer => + val nextFormState = formState.append(byte) + (state.currentQueue match { + case Some(queue) => + val (newBuffer, maybeTake) = addByteToBuffer(state.buffer, byte) + maybeTake match { + case Some(take) => queue.offer(take).as(newBuffer) + case None => ZIO.succeed(newBuffer) + } + case None => + ZIO.succeed(state.buffer) + }).flatMap { newBuffer => + nextFormState match { + case newFormState: FormState.FormStateBuffer => + if ( + state.currentQueue.isEmpty && + newFormState.phase == FormState.Phase.Part2 && + !state.inNonStreamingPart + ) { + val contentType = FormData.getContentType(newFormState.tree) + if (contentType.binary) { + for { + newQueue <- Queue.unbounded[Take[Nothing, Byte]] + _ <- newQueue.offer(Take.chunk(newFormState.tree.collect { case FormAST.Content(bytes) => + bytes + }.flatten)) + streamingFormData <- FormData + .incomingStreamingBinary(newFormState.tree, newQueue) + .mapError(_.asException) + nextState = state.copy( + formState = newFormState, + currentQueue = Some(newQueue), + buffer = newBuffer, + ) + } yield (nextState, Some(streamingFormData)) + } else { + val nextState = state.copy(formState = newFormState, inNonStreamingPart = true) + ZIO.succeed((nextState, None)) + } + } else { + val nextState = state.copy(formState = newFormState, buffer = newBuffer) + ZIO.succeed((nextState, None)) + } + case FormState.BoundaryEncapsulated(ast) => + if (state.inNonStreamingPart) { + FormData + .fromFormAST(ast, charset) + .mapBoth( + _.asException, + { formData => + (state.reset, Some(formData)) + }, + ) + } else { + ZIO.succeed((state.reset, None)) + } + case FormState.BoundaryClosed(ast) => + if (state.inNonStreamingPart) { + FormData + .fromFormAST(ast, charset) + .mapBoth( + _.asException, + { formData => + (state.reset, Some(formData)) + }, + ) + } else { + ZIO.succeed((state.reset, None)) + } + } + } + case _ => + ZIO.succeed(state, None) + } + } + .collect { case Some(formData) => + formData + } + + private def initialState: StreamingForm.State = + StreamingForm.initialState(boundary) + + private val crlfBoundary: Chunk[Byte] = Chunk[Byte](13, 10) ++ boundary.encapsulationBoundaryBytes + + private def addByteToBuffer( + buffer: immutable.Queue[Byte], + byte: Byte, + ): (immutable.Queue[Byte], Option[Take[Nothing, Byte]]) = + if (buffer.length < (crlfBoundary.length - 1)) { + // Not enough bytes to check if we have the boundary + (buffer.enqueue(byte), None) + } else { + val newBuffer = buffer.enqueue(byte) + val newBufferChunk = Chunk.fromIterable(newBuffer) + if (newBufferChunk == crlfBoundary) { + // We have found the boundary + (immutable.Queue.empty, Some(Take.end)) + } else { + // We don't have the boundary + val (out, remaining) = newBuffer.dequeue + (remaining, Some(Take.single(out))) + } + } +} + +object StreamingForm { + private final case class State( + boundary: Boundary, + formState: FormState, + currentQueue: Option[Queue[Take[Nothing, Byte]]], + buffer: immutable.Queue[Byte], + inNonStreamingPart: Boolean, + ) { + + def reset: State = + State(boundary, FormState.fromBoundary(boundary), None, immutable.Queue.empty, inNonStreamingPart = false) + } + + private def initialState(boundary: Boundary): State = + State(boundary, FormState.fromBoundary(boundary), None, immutable.Queue.empty, inNonStreamingPart = false) +} diff --git a/zio-http/src/main/scala/zio/http/netty/NettyBody.scala b/zio-http/src/main/scala/zio/http/netty/NettyBody.scala index 566b53ab1f..3d5f075d19 100644 --- a/zio-http/src/main/scala/zio/http/netty/NettyBody.scala +++ b/zio-http/src/main/scala/zio/http/netty/NettyBody.scala @@ -25,7 +25,7 @@ import zio.stream.ZStream import zio.http.Body import zio.http.Body.{UnsafeBytes, UnsafeWriteable} import zio.http.internal.BodyEncoding -import zio.http.model.{Headers, MediaType} +import zio.http.model.{Header, Headers, MediaType} import io.netty.buffer.{ByteBuf, ByteBufUtil} import io.netty.channel.{Channel => JChannel} @@ -38,12 +38,16 @@ object NettyBody extends BodyEncoding { */ def fromAsciiString(asciiString: AsciiString): Body = AsciiStringBody(asciiString) - private[zio] def fromAsync(unsafeAsync: UnsafeAsync => Unit): Body = AsyncBody(unsafeAsync) + private[zio] def fromAsync( + unsafeAsync: UnsafeAsync => Unit, + contentTypeHeader: Option[Header.ContentType] = None, + ): Body = AsyncBody(unsafeAsync, contentTypeHeader.map(_.mediaType), contentTypeHeader.flatMap(_.boundary)) /** * Helper to create Body from ByteBuf */ - def fromByteBuf(byteBuf: ByteBuf): Body = new ByteBufBody(byteBuf) + def fromByteBuf(byteBuf: ByteBuf, contentTypeHeader: Option[Header.ContentType] = None): Body = + ByteBufBody(byteBuf, contentTypeHeader.map(_.mediaType), contentTypeHeader.flatMap(_.boundary)) override def fromCharSequence(charSequence: CharSequence, charset: Charset): Body = fromAsciiString(new AsciiString(charSequence, charset)) diff --git a/zio-http/src/main/scala/zio/http/netty/NettyResponse.scala b/zio-http/src/main/scala/zio/http/netty/NettyResponse.scala index 1e4c6f0dd1..2557809242 100644 --- a/zio-http/src/main/scala/zio/http/netty/NettyResponse.scala +++ b/zio-http/src/main/scala/zio/http/netty/NettyResponse.scala @@ -19,7 +19,7 @@ package zio.http.netty import zio.{Promise, Trace, Unsafe} import zio.http.Response.NativeResponse -import zio.http.model.{Headers, Status} +import zio.http.model.{Header, Headers, Status} import zio.http.netty.client.{ChannelState, ClientResponseStreamHandler} import zio.http.netty.model.Conversions import zio.http.{Body, Response} @@ -36,7 +36,7 @@ object NettyResponse { val status = Conversions.statusFromNetty(jRes.status()) val headers = Conversions.headersFromNetty(jRes.headers()) val copiedBuffer = Unpooled.copiedBuffer(jRes.content()) - val data = NettyBody.fromByteBuf(copiedBuffer) + val data = NettyBody.fromByteBuf(copiedBuffer, headers.header(Header.ContentType)) new NativeResponse(data, headers, status, () => NettyFutureExecutor.executed(ctx.close())) } diff --git a/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 6d9b3cd911..9d1c67f502 100644 --- a/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -217,24 +217,33 @@ private[zio] final case class ServerInboundHandler( case _ => None } + val headers = Conversions.headersFromNetty(nettyReq.headers()) + val contentType = headers.header(Header.ContentType) + nettyReq match { case nettyReq: FullHttpRequest => Request( - NettyBody.fromByteBuf(nettyReq.content()), - Conversions.headersFromNetty(nettyReq.headers()), + NettyBody.fromByteBuf( + nettyReq.content(), + contentType, + ), + headers, Conversions.methodFromNetty(nettyReq.method()), URL.fromString(nettyReq.uri()).getOrElse(URL.empty), protocolVersion, remoteAddress, ) case nettyReq: HttpRequest => - val body = NettyBody.fromAsync { async => - addAsyncBodyHandler(ctx, async) - } + val body = NettyBody.fromAsync( + { async => + addAsyncBodyHandler(ctx, async) + }, + contentType, + ) Request( body, - Conversions.headersFromNetty(nettyReq.headers()), + headers, Conversions.methodFromNetty(nettyReq.method()), URL.fromString(nettyReq.uri()).getOrElse(URL.empty), protocolVersion, diff --git a/zio-http/src/test/scala/zio/http/forms/FormSpec.scala b/zio-http/src/test/scala/zio/http/forms/FormSpec.scala index 595d5b2377..919d96c1fa 100644 --- a/zio-http/src/test/scala/zio/http/forms/FormSpec.scala +++ b/zio-http/src/test/scala/zio/http/forms/FormSpec.scala @@ -23,7 +23,8 @@ import scala.annotation.nowarn import zio._ import zio.test._ -import zio.http.forms +import zio.stream.{ZStream, ZStreamAspect} + import zio.http.forms.Fixtures._ import zio.http.model.Header.ContentTransferEncoding import zio.http.model.MediaType @@ -56,13 +57,13 @@ object FormSpec extends ZIOSpecDefault { test("encoding") { val form = Form( - FormData.Text("csv-data", "foo,bar,baz", MediaType.text.csv), - FormData.Binary( + FormData.textField("csv-data", "foo,bar,baz", MediaType.text.csv), + FormData.binaryField( "file", Chunk[Byte](0x50, 0x4e, 0x47), MediaType.image.png, ), - FormData.Binary( + FormData.binaryField( "corgi", Chunk.fromArray(base64Corgi.getBytes()), MediaType.image.png, @@ -71,57 +72,135 @@ object FormSpec extends ZIOSpecDefault { ), ) - val (_, actualBytes) = form.encodeAsMultipartBytes(rng = () => "AaB03x") - val form2 = Form.fromMultipartBytes(multipartFormBytes2) - - form2.map { form2 => - assertTrue( - actualBytes == multipartFormBytes2, - form2 == form, - ).??(new String(actualBytes.toArray, StandardCharsets.UTF_8)) - } + val (_, actualByteStream) = form.encodeAsMultipartBytes(rng = () => "AaB03x") + for { + form2 <- Form.fromMultipartBytes(multipartFormBytes2) + actualBytes <- actualByteStream.runCollect + } yield assertTrue( + actualBytes == multipartFormBytes2, + form2 == form, + ) }, test("decoding") { val boundary = Boundary("AaB03x") - val form = Form.fromMultipartBytes(multipartFormBytes1) - - form.map { form => - val bytes = form.encodeAsMultipartBytes(StandardCharsets.UTF_8, boundary) - - val (text: FormData.Text) :: (image1: FormData.Binary) :: (image2: FormData.Binary) :: Nil = - form.formData.toList - assertTrue( - bytes._2 == multipartFormBytes1, - form.formData.size == 3, - text.name == "submit-name", - text.value == "Larry", - text.contentType == MediaType.text.plain, - text.filename.isEmpty, - image1.name == "files", - image1.data == Chunk[Byte](0x50, 0x4e, 0x47), - image1.contentType == MediaType.image.png, - image1.transferEncoding.isEmpty, - image1.filename.get == "file1.txt", - image2.name == "corgi", - image2.contentType == MediaType.image.png, - image2.transferEncoding.get == ContentTransferEncoding.Base64, - image2.data == Chunk.fromArray(base64Corgi.getBytes()), - ) - } + for { + form <- Form.fromMultipartBytes(multipartFormBytes1) + encoding = form.encodeAsMultipartBytes(StandardCharsets.UTF_8, boundary) + bytes <- encoding._2.runCollect + (text: FormData.Text) :: (image1: FormData.Binary) :: (image2: FormData.Binary) :: Nil = form.formData.toList + } yield assertTrue( + bytes == multipartFormBytes1, + form.formData.size == 3, + text.name == "submit-name", + text.value == "Larry", + text.contentType == MediaType.text.plain, + text.filename.isEmpty, + image1.name == "files", + image1.data == Chunk[Byte](0x50, 0x4e, 0x47), + image1.contentType == MediaType.image.png, + image1.transferEncoding.isEmpty, + image1.filename.get == "file1.txt", + image2.name == "corgi", + image2.contentType == MediaType.image.png, + image2.transferEncoding.get == ContentTransferEncoding.Base64, + image2.data == Chunk.fromArray(base64Corgi.getBytes()), + ) }, test("decoding 2") { Form.fromMultipartBytes(multipartFormBytes3).map { form => assertTrue( form.get("file").get.filename.get == "test.jsonl", form.get("file").get.valueAsString.isEmpty, - form.get("file").get.asInstanceOf[FormData.Binary].data.size == 67, + form.get("file").get.asInstanceOf[FormData.Binary].data.size == 69, ) } }, ) - def spec = suite("FormSpec")(urlEncodedSuite, multiFormSuite) + val multiFormStreamingSuite: Spec[Any, Throwable] = + suite("multipart/form-data streaming")( + test("encoding") { + + val form = Form( + FormData.textField("csv-data", "foo,bar,baz", MediaType.text.csv), + FormData.streamingBinaryField( + "file", + ZStream.fromChunk(Chunk[Byte](0x50, 0x4e, 0x47)) @@ ZStreamAspect.rechunk(3), + MediaType.image.png, + ), + FormData.streamingBinaryField( + "corgi", + ZStream.fromChunk(Chunk.fromArray(base64Corgi.getBytes())) @@ ZStreamAspect.rechunk(8), + MediaType.image.png, + Some(ContentTransferEncoding.Base64), + Some("corgi.png"), + ), + ) + + val (_, actualByteStream) = form.encodeAsMultipartBytes(rng = () => "AaB03x") + + for { + form2 <- Form.fromMultipartBytes(multipartFormBytes2) + actualBytes <- actualByteStream.runCollect + collectedForm <- form.collectAll + } yield assertTrue( + actualBytes == multipartFormBytes2, + form2 == collectedForm, + ) + }, + test("decoding") { + val boundary = Boundary("AaB03x") + + val stream = ZStream.fromChunk(multipartFormBytes1) @@ ZStreamAspect.rechunk(4) + val form = StreamingForm(stream, boundary, StandardCharsets.UTF_8) + + form.data + .mapZIOPar(1) { + case sb: FormData.StreamingBinary => + sb.collect + case other: FormData => + ZIO.succeed(other) + } + .runCollect + .map { formData => + val (text: FormData.Text) :: (image1: FormData.Binary) :: (image2: FormData.Binary) :: Nil = formData.toList + assertTrue( + formData.size == 3, + text.name == "submit-name", + text.value == "Larry", + text.contentType == MediaType.text.plain, + text.filename.isEmpty, + image1.name == "files", + image1.data == Chunk[Byte](0x50, 0x4e, 0x47), + image1.contentType == MediaType.image.png, + image1.transferEncoding.isEmpty, + image1.filename.get == "file1.txt", + image2.name == "corgi", + image2.contentType == MediaType.image.png, + image2.transferEncoding.get == ContentTransferEncoding.Base64, + image2.data == Chunk.fromArray(base64Corgi.getBytes()), + ) + } + }, + test("decoding 2") { + val boundary = Boundary("X-INSOMNIA-BOUNDARY") + val stream = ZStream.fromChunk(multipartFormBytes3) @@ ZStreamAspect.rechunk(16) + val streamingForm = StreamingForm(stream, boundary, StandardCharsets.UTF_8) + streamingForm.collectAll.map { form => + val contents = + new String(form.get("file").get.asInstanceOf[FormData.Binary].data.toArray, StandardCharsets.UTF_8) + assertTrue( + form.get("file").get.filename.get == "test.jsonl", + form.get("file").get.valueAsString.isEmpty, + form.get("file").get.asInstanceOf[FormData.Binary].data.size == 69, + contents == """{"prompt": "", "completion": ""}""" + "\r\n", + ) + } + }, + ) + + def spec = suite("FormSpec")(urlEncodedSuite, multiFormSuite, multiFormStreamingSuite) }