Skip to content

Commit

Permalink
Use stream type at higher levels of the API
Browse files Browse the repository at this point in the history
  • Loading branch information
aantron committed Feb 8, 2022
1 parent c7fb937 commit f82a51b
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 85 deletions.
33 changes: 23 additions & 10 deletions src/dream.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ type client = Message.client
type server = Message.server
type 'a promise = 'a Message.promise

type stream = Stream.stream



(* Methods *)
Expand Down Expand Up @@ -149,12 +151,16 @@ let all_cookies = Cookie.all_cookies

let body = Message.body
let set_body = Message.set_body
let read = Helpers.read
let write = Helpers.write
let flush = Helpers.flush
let close = Helpers.close



(* Streaming I/O *)

let read = Message.read
let write = Message.write
let flush = Message.flush
let close = Message.close
type buffer = Stream.buffer
type stream = Stream.stream
let client_stream = Message.client_stream
let server_stream = Message.server_stream
let set_client_stream = Message.set_client_stream
Expand Down Expand Up @@ -414,12 +420,19 @@ let write_buffer ?(offset = 0) ?length message chunk =
| None -> Bigstringaf.length chunk - offset
in
let string = Bigstringaf.substring chunk ~off:offset ~len:length in
write ~kind:`Binary message string
write ~kind:`Binary (Message.server_stream message) string

type websocket =
Message.response

let send ?kind response chunk =
write ?kind (Message.server_stream response) chunk

let receive response =
read (Message.server_stream response)

type websocket = Message.response
let send = write
let receive = read
let close_websocket = close
let close_websocket ?code response =
close ?code (Message.server_stream response)

type 'a local = 'a Message.field
let new_local = Message.new_field
Expand Down
17 changes: 10 additions & 7 deletions src/dream.mli
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ and 'a promise = 'a Lwt.t
exception backtrace — though, in most cases, you should still extend it with
[raise] and [let%lwt], instead. *)

type stream
(* TODO Document. *)



(** {1 Methods} *)
Expand Down Expand Up @@ -522,7 +525,7 @@ val stream :
?status:[< status ] ->
?code:int ->
?headers:(string * string) list ->
(response -> unit promise) -> response promise
(stream -> unit promise) -> response promise
(** Same as {!Dream.val-respond}, but calls {!Dream.set_stream} internally to
prepare the response for stream writing, and then runs the callback
asynchronously to do it. See example
Expand All @@ -538,7 +541,7 @@ val stream :

val websocket :
?headers:(string * string) list ->
(response -> unit promise) -> response promise
(stream -> unit promise) -> response promise
(** Creates a fresh [101 Switching Protocols] response. Once this response is
returned to Dream's HTTP layer, the callback is passed a new
{!type-websocket}, and the application can begin using it. See example
Expand Down Expand Up @@ -763,8 +766,9 @@ https://aantron.github.io/dream/#val-set_body
(**/**)

(** {2 Streaming} *)
(* TODO Should probably be promoted to its own section. *)

val read : 'a message -> string option promise
val read : stream -> string option promise
(** Retrieves a body chunk. The chunk is not buffered, thus it can only be read
once. See example
{{:https://github.com/aantron/dream/tree/master/example/j-stream#files}
Expand All @@ -780,15 +784,15 @@ https://aantron.github.io/dream/#val-set_stream
"]
(**/**)

val write : ?kind:[< `Text | `Binary ] -> response -> string -> unit promise
val write : ?kind:[< `Text | `Binary ] -> stream -> string -> unit promise
(** Streams out the string. The promise is fulfilled when the response can
accept more writes. *)
(* TODO Document clearly which of the writing functions can raise exceptions. *)

val flush : response -> unit promise
val flush : stream -> unit promise
(** Flushes write buffers. Data is sent to the client. *)

val close : ?code:int -> 'a message -> unit promise
val close : ?code:int -> stream -> unit promise
(** Finishes the response stream. *)
(* TODO Fix comment. *)

Expand All @@ -811,7 +815,6 @@ type buffer =
(* TODO Remove old functions from signature. *)
(* TODO Should there be a section for this somewhere? Probably "low-level
streaming" should be promoted to a top-level section, Streaming. *)
type stream

val client_stream : 'a message -> stream
val server_stream : 'a message -> stream
Expand Down
48 changes: 24 additions & 24 deletions src/graphql/graphql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ let run_query make_context schema request json =
let operation_id json =
Yojson.Basic.Util.(json |> member "id" |> to_string_option)

let close_and_clean ?code subscriptions response =
let%lwt () = Message.close ?code response in
let close_and_clean ?code subscriptions websocket =
let%lwt () = Message.close ?code websocket in
Hashtbl.iter (fun _ close -> close ()) subscriptions;
Lwt.return_unit

Expand Down Expand Up @@ -114,12 +114,12 @@ let complete_message id =

(* TODO Take care to pass around the request Lwt.key in async, etc. *)
(* TODO Test client complete racing against a stream. *)
let handle_over_websocket make_context schema subscriptions request response =
let handle_over_websocket make_context schema subscriptions request (websocket : Stream.stream) =
let rec loop inited =
match%lwt Helpers.read response with
match%lwt Message.read websocket with
| None ->
log.info (fun log -> log ~request "GraphQL WebSocket closed by client");
close_and_clean subscriptions response
close_and_clean subscriptions websocket
| Some message ->

log.debug (fun log -> log ~request "Message '%s'" message);
Expand All @@ -128,38 +128,38 @@ let handle_over_websocket make_context schema subscriptions request response =
match Yojson.Basic.from_string message with
| exception _ ->
log.warning (fun log -> log ~request "GraphQL message is not JSON");
close_and_clean subscriptions response ~code:4400
close_and_clean subscriptions websocket ~code:4400
| json ->

match Yojson.Basic.Util.(json |> member "type" |> to_string_option) with
| None ->
log.warning (fun log -> log ~request "GraphQL message lacks a type");
close_and_clean subscriptions response ~code:4400
close_and_clean subscriptions websocket ~code:4400
| Some message_type ->

match message_type with

| "connection_init" ->
if inited then begin
log.warning (fun log -> log ~request "Duplicate connection_init");
close_and_clean subscriptions response ~code:4429
close_and_clean subscriptions websocket ~code:4429
end
else begin
let%lwt () = Helpers.write response ack_message in
let%lwt () = Message.write websocket ack_message in
loop true
end

| "complete" ->
if not inited then begin
log.warning (fun log -> log ~request "complete before connection_init");
close_and_clean subscriptions response ~code:4401
close_and_clean subscriptions websocket ~code:4401
end
else begin
match operation_id json with
| None ->
log.warning (fun log ->
log ~request "client complete: operation id missing");
close_and_clean subscriptions response ~code:4400
close_and_clean subscriptions websocket ~code:4400
| Some id ->
begin match Hashtbl.find_opt subscriptions id with
| None -> ()
Expand All @@ -172,14 +172,14 @@ let handle_over_websocket make_context schema subscriptions request response =
if not inited then begin
log.warning (fun log ->
log ~request "subscribe before connection_init");
close_and_clean subscriptions response ~code:4401
close_and_clean subscriptions websocket ~code:4401
end
else begin
match operation_id json with
| None ->
log.warning (fun log ->
log ~request "subscribe: operation id missing");
close_and_clean subscriptions response ~code:4400
close_and_clean subscriptions websocket ~code:4400
| Some id ->

let payload = json |> Yojson.Basic.Util.member "payload" in
Expand All @@ -193,21 +193,21 @@ let handle_over_websocket make_context schema subscriptions request response =
log.warning (fun log ->
log ~request
"subscribe: error %s" (Yojson.Basic.to_string json));
Helpers.write response (error_message id json)
Message.write websocket (error_message id json)

(* It's not clear that this case ever occurs, because graphql-ws is
only used for subscriptions, at the protocol level. *)
| Ok (`Response json) ->
let%lwt () = Helpers.write response (data_message id json) in
let%lwt () = Helpers.write response (complete_message id) in
let%lwt () = Message.write websocket (data_message id json) in
let%lwt () = Message.write websocket (complete_message id) in
Lwt.return_unit

| Ok (`Stream (stream, close)) ->
match Hashtbl.mem subscriptions id with
| true ->
log.warning (fun log ->
log ~request "subscribe: duplicate operation id");
close_and_clean subscriptions response ~code:4409
close_and_clean subscriptions websocket ~code:4409
| false ->

Hashtbl.replace subscriptions id close;
Expand All @@ -216,15 +216,15 @@ let handle_over_websocket make_context schema subscriptions request response =
let%lwt () =
stream |> Lwt_stream.iter_s (function
| Ok json ->
Helpers.write response (data_message id json)
Message.write websocket (data_message id json)
| Error json ->
log.warning (fun log ->
log ~request
"Subscription: error %s" (Yojson.Basic.to_string json));
Helpers.write response (error_message id json))
Message.write websocket (error_message id json))
in

let%lwt () = Helpers.write response (complete_message id) in
let%lwt () = Message.write websocket (complete_message id) in
Hashtbl.remove subscriptions id;
Lwt.return_unit

Expand All @@ -240,12 +240,12 @@ let handle_over_websocket make_context schema subscriptions request response =

try%lwt
let%lwt () =
Helpers.write
response
Message.write
websocket
(error_message id (make_error "Internal Server Error"))
in
if !subscribed then
Helpers.write response (complete_message id)
Message.write websocket (complete_message id)
else
Lwt.return_unit
with _ ->
Expand All @@ -258,7 +258,7 @@ let handle_over_websocket make_context schema subscriptions request response =
| message_type ->
log.warning (fun log ->
log ~request "Unknown WebSocket message type '%s'" message_type);
close_and_clean subscriptions response ~code:4400
close_and_clean subscriptions websocket ~code:4400
in

loop false
Expand Down
4 changes: 2 additions & 2 deletions src/pure/message.ml
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ let flush stream =
promise

(* TODO Should close even be promise-valued? *)
let close ?(code = 1000) message =
Stream.close message.server_stream code;
let close ?(code = 1000) stream =
Stream.close stream code;
Lwt.return_unit

let client_stream message =
Expand Down
2 changes: 1 addition & 1 deletion src/pure/message.mli
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ val read : Stream.stream -> string option promise
val write :
?kind:[< `Text | `Binary ] -> Stream.stream -> string -> unit promise
val flush : Stream.stream -> unit promise
val close : ?code:int -> 'a message -> unit promise
val close : ?code:int -> Stream.stream -> unit promise
val client_stream : 'a message -> Stream.stream
val server_stream : 'a message -> Stream.stream
val set_client_stream : 'a message -> Stream.stream -> unit
Expand Down
41 changes: 2 additions & 39 deletions src/server/helpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ let stream ?status ?code ?headers callback =
Message.response ?status ?code ?headers client_stream server_stream in
(* TODO Should set up an error handler for this. YES. *)
(* TODO Make sure the request id is propagated to the callback. *)
Lwt.async (fun () -> callback response);
Lwt.async (fun () -> callback server_stream);
Lwt.return response

(* TODO Mark the request as a WebSocket request for HTTP. *)
Expand All @@ -109,50 +109,13 @@ let websocket ?headers callback =
Message.response
~status:`Switching_Protocols ?headers Stream.empty Stream.null in
let server_stream = Message.create_websocket response in
(* TODO Figure out what should actually be returned to the client and/or
provided to the callback. Probably the server stream. The surface API for
WebSockets also needs to be designed. *)
ignore server_stream;
(* TODO Make sure the request id is propagated to the callback. *)
(* TODO Close the WwbSocket on leaked exceptions, etc. *)
Lwt.async (fun () -> callback response);
Lwt.async (fun () -> callback server_stream);
Lwt.return response

let empty ?headers status =
respond ?headers ~status ""

let not_found _ =
respond ~status:`Not_Found ""



(* TODO Once the WebSocket API exists, these functions should not check whether
the message has a WebSocket. *)
let read message =
match Message.get_websocket message with
| None ->
Message.read (Message.server_stream message)
| Some (_client_stream, server_stream) ->
Message.read server_stream

let write ?kind message chunk =
match Message.get_websocket message with
| None ->
Message.write ?kind (Message.server_stream message) chunk
| Some (_client_stream, server_stream) ->
Message.write ?kind server_stream chunk

let flush message =
match Message.get_websocket message with
| None ->
Message.flush (Message.server_stream message)
| Some (_client_stream, server_stream) ->
Message.flush server_stream

let close ?(code = 1000) message =
match Message.get_websocket message with
| None ->
Message.close message
| Some (_client_stream, server_stream) ->
Stream.close server_stream code;
Lwt.return_unit
8 changes: 6 additions & 2 deletions src/server/upload.ml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ and upload (request : Message.request) =
failwith message

| Some content_type ->
let body = Lwt_stream.from (fun () -> Helpers.read request) in
let body =
Lwt_stream.from (fun () ->
Message.read (Message.server_stream request)) in
let `Parse th, stream =
Multipart_form_lwt.stream ~identify body content_type in
Lwt.async (fun () -> let%lwt _ = th in Lwt.return_unit);
Expand All @@ -135,7 +137,9 @@ let multipart ?(csrf=true) ~now request =
match content_type with
| None -> Lwt.return `Wrong_content_type
| Some content_type ->
let body = Lwt_stream.from (fun () -> Helpers.read request) in
let body =
Lwt_stream.from (fun () ->
Message.read (Message.server_stream request)) in
match%lwt Multipart_form_lwt.of_stream_to_list body content_type with
| Error (`Msg _err) ->
Lwt.return `Wrong_content_type (* XXX(dinosaure): better error? *)
Expand Down

0 comments on commit f82a51b

Please sign in to comment.