Skip to content

Commit

Permalink
Implementation and tests - non seekable multipart form files
Browse files Browse the repository at this point in the history
  • Loading branch information
ordinaryorange committed Nov 23, 2021
1 parent bbb1c8c commit f644892
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 12 deletions.
34 changes: 22 additions & 12 deletions src/Net/Http.fs
Original file line number Diff line number Diff line change
Expand Up @@ -1313,11 +1313,11 @@ module internal HttpHelpers =
let mutable streams = streams |> Seq.cache

let rec readFromStream buffer offset count =
if Seq.isEmpty streams
then 0
else
let stream = Seq.head streams
let read = stream.Read(buffer, offset, min count (int stream.Length))
match streams |> Seq.tryHead with
| None -> 0
| Some stream ->
let qty = if stream.CanSeek then min count (int stream.Length) else count
let read = stream.Read(buffer, offset, qty)
if read < count
then
stream.Dispose()
Expand All @@ -1327,9 +1327,9 @@ module internal HttpHelpers =
else read

override x.CanRead = true
override x.CanSeek = false
override x.CanSeek = match length with | None -> false | Some _ -> true
override x.CanWrite = false
override x.Length with get () = length
override x.Length with get () = length |> Option.defaultWith (fun () -> NotSupportedException() |> raise)
override x.Position with get () = v and set(_) = failwith "no position setting"
override x.Flush() = ()
override x.CanTimeout = false
Expand All @@ -1351,7 +1351,16 @@ module internal HttpHelpers =
let writeMultipart (boundary: string) (parts: seq<MultipartItem>) (e : Encoding) =
let newlineStream () = new MemoryStream(e.GetBytes "\r\n") :> Stream
let prefixedBoundary = sprintf "--%s" boundary
let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, fileStream)) ->
let trySumLength streams = //allows seq to be blocking & non seekable
let mutable seekable = true
let mutable length = 0L
let takeIfSeekable (str: Stream) =
seekable <- str.CanSeek
if str.CanSeek then length <- length + str.Length
str.CanSeek
streams |> Seq.takeWhile takeIfSeekable |> List.ofSeq |> ignore
if seekable then Some length else None
let segments = parts |> Seq.map (fun (MultipartItem(formField, fileName, contentStream)) ->
let fileExt = Path.GetExtension fileName
let contentType = defaultArg (MimeTypes.tryFind fileExt) "application/octet-stream"
let printHeader (header, value) = sprintf "%s: %s" header value
Expand All @@ -1367,9 +1376,9 @@ module internal HttpHelpers =
[ headerStream
newlineStream()
newlineStream()
fileStream
contentStream
newlineStream()]
let partLength = partSubstreams |> Seq.sumBy (fun s -> s.Length)
let partLength = partSubstreams |> trySumLength
new CombinedStream(partLength, partSubstreams) :> Stream
)

Expand All @@ -1380,7 +1389,7 @@ module internal HttpHelpers =
new MemoryStream(bytes) :> Stream

let wholePayload = Seq.append segments [newlineStream(); endBoundaryStream; ]
let wholePayloadLength = wholePayload |> Seq.sumBy (fun s -> s.Length)
let wholePayloadLength = wholePayload |> trySumLength
new CombinedStream(wholePayloadLength, wholePayload) :> Stream

let asyncCopy (source: Stream) (dest: Stream) =
Expand All @@ -1393,7 +1402,8 @@ module internal HttpHelpers =

let writeBody (req:HttpWebRequest) (data: Stream) =
async {
req.ContentLength <- data.Length
if data.CanSeek then
req.ContentLength <- data.Length
use! output = req.GetRequestStreamAsync () |> Async.AwaitTask
do! asyncCopy data output
output.Flush()
Expand Down
59 changes: 59 additions & 0 deletions tests/FSharp.Data.Tests/Http.fs
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,62 @@ let ``correct multipart content format`` () =
let singleMultipartFormat file = sprintf "--%s\r\nContent-Disposition: form-data; name=\"%i\"; filename=\"%i\"\r\nContent-Type: application/octet-stream\r\n\r\n%s\r\n" boundary file file content
let finalFormat = [sprintf "\r\n--%s--" boundary] |> Seq.append (seq {for i in [0..numFiles] -> singleMultipartFormat i }) |> String.concat ""
str |> should equal finalFormat

[<Test>]
let ``CombinedStream has length with Some length`` () =
use combinedStream = new HttpHelpers.CombinedStream(Some 10L, [])
combinedStream.Length |> should equal 10L

[<Test>]
let ``CombinedStream can seek with Some length`` () =
use combinedStream = new HttpHelpers.CombinedStream(Some 10L, [])
combinedStream.CanSeek |> should equal true

[<Test>]
let ``CombinedStream length throws with None length`` () =
use combinedStream = new HttpHelpers.CombinedStream(None, [])
(fun () -> combinedStream.Length |> ignore) |> should throw typeof<System.NotSupportedException>

[<Test>]
let ``CombinedStream cannot seek with None length`` () =
use combinedStream = new HttpHelpers.CombinedStream(None, [])
combinedStream.CanSeek |> should equal false

type nonSeekableStream (b: byte[]) =
inherit IO.MemoryStream(b)
override _.Length with get():Int64 = failwith "Im not seekable"
override _.CanSeek with get() = false

[<Test>]
let ``Non-seekable streams create non-seekable CombinedStream`` () =
use nonSeekms = new nonSeekableStream(Array.zeroCreate 10)
let multiparts = [MultipartItem("","", nonSeekms)]
let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8
(fun () -> combinedStream.Length |> ignore) |> should throw typeof<System.NotSupportedException>
combinedStream.CanSeek |> should equal false

[<Test>]
let ``Seekable streams create Seekable CombinedStream`` () =
let byteLen = 10L
let result = byteLen + 110L //110 is headers
use ms = new IO.MemoryStream(Array.zeroCreate (int byteLen))
let multiparts = [MultipartItem("","", ms)]
let combinedStream = HttpHelpers.writeMultipart "-" multiparts Encoding.UTF8
combinedStream.Length |> should equal result
combinedStream.CanSeek |> should equal true

[<Test>]
let ``HttpWebRequest length is set with seekable streams`` () =
use ms = new IO.MemoryStream(Array.zeroCreate 10)
let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest
wr.Method <- "POST"
HttpHelpers.writeBody wr ms |> Async.RunSynchronously
wr.ContentLength |> should equal 10

[<Test>]
let ``HttpWebRequest length is not set with non-seekable streams`` () =
use nonSeekms = new nonSeekableStream(Array.zeroCreate 10)
let wr = Net.HttpWebRequest.Create("http://x") :?> Net.HttpWebRequest
wr.Method <- "POST"
HttpHelpers.writeBody wr nonSeekms |> Async.RunSynchronously
wr.ContentLength |> should equal 0

0 comments on commit f644892

Please sign in to comment.