diff --git a/capnp-rpc-net/capTP_capnp.ml b/capnp-rpc-net/capTP_capnp.ml index 18b94316..dc49c54c 100644 --- a/capnp-rpc-net/capTP_capnp.ml +++ b/capnp-rpc-net/capTP_capnp.ml @@ -1,33 +1,5 @@ open Eio.Std -module Metrics = struct - open Prometheus - - let namespace = "capnp" - - let subsystem = "net" - - let connections = - let help = "Number of live capnp-rpc connections" in - Gauge.v ~help ~namespace ~subsystem "connections" - - let messages_inbound_received_total = - let help = "Total number of messages received" in - Counter.v ~help ~namespace ~subsystem "messages_inbound_received_total" - - let messages_outbound_enqueued_total = - let help = "Total number of messages enqueued to be transmitted" in - Counter.v ~help ~namespace ~subsystem "messages_outbound_enqueued_total" - - let messages_outbound_sent_total = - let help = "Total number of messages transmitted" in - Counter.v ~help ~namespace ~subsystem "messages_outbound_sent_total" - - let messages_outbound_dropped_total = - let help = "Total number of messages lost due to disconnections" in - Counter.v ~help ~namespace ~subsystem "messages_outbound_dropped_total" -end - module Log = Capnp_rpc.Debug.Log module Builder = Capnp_rpc.Private.Schema.Builder @@ -42,10 +14,8 @@ module Make (Network : S.NETWORK) = struct module Serialise = Serialise.Make(Endpoint_types) type t = { - sw : Switch.t; endpoint : Endpoint.t; conn : Conn.t; - xmit_queue : Capnp.Message.rw Capnp.BytesMessage.Message.t Queue.t; mutable disconnecting : bool; } @@ -60,94 +30,49 @@ module Make (Network : S.NETWORK) = struct let tags t = Conn.tags t.conn - let drop_queue q = - Prometheus.Counter.inc Metrics.messages_outbound_dropped_total (float_of_int (Queue.length q)); - Queue.clear q - - (* [flush ~xmit_queue endpoint] writes each message in the queue until it is empty. - Invariant: - Whenever Eio blocks or switches threads, a flush thread is running iff the - queue is non-empty. *) - let rec flush ~xmit_queue endpoint = - (* We keep the item on the queue until it is transmitted, as the queue state - tells us whether there is a [flush] currently running. *) - let next = Queue.peek xmit_queue in - match Endpoint.send endpoint next with - | Error `Closed -> - Endpoint.disconnect endpoint; (* We'll read a close soon *) - drop_queue xmit_queue - | Error (`Msg msg) -> - Log.warn (fun f -> f "Error sending messages: %s (will shutdown connection)" msg); - Endpoint.disconnect endpoint; - drop_queue xmit_queue - | Ok () -> - Prometheus.Counter.inc_one Metrics.messages_outbound_sent_total; - ignore (Queue.pop xmit_queue); - if not (Queue.is_empty xmit_queue) then - flush ~xmit_queue endpoint - (* else queue is empty and flush thread is done *) - | exception ex -> - drop_queue xmit_queue; - raise ex - - (* Enqueue [message] in [xmit_queue] and ensure the flush thread is running. *) - let queue_send ~sw ~xmit_queue endpoint message = - Log.debug (fun f -> - let module M = Capnp_rpc.Private.Schema.MessageWrapper.Message in - f "queue_send: %d/%d allocated bytes in %d segs" - (M.total_size message) - (M.total_alloc_size message) - (M.num_segments message)); - let was_idle = Queue.is_empty xmit_queue in - Queue.add message xmit_queue; - Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total; - if was_idle then Eio.Fiber.fork ~sw (fun () -> flush ~xmit_queue endpoint) - let return_not_implemented t x = Log.debug (fun f -> f ~tags:(tags t) "Returning Unimplemented"); let open Builder in let m = Message.init_root () in let _ : Builder.Message.t = Message.unimplemented_set_reader m x in - queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Message.to_message m) - - let listen t = - let rec loop () = - match Endpoint.recv t.endpoint with - | Error e -> e - | Ok msg -> - let open Reader.Message in - let msg = of_message msg in - Prometheus.Counter.inc_one Metrics.messages_inbound_received_total; - match Parse.message msg with - | #Endpoint_types.In.t as msg -> - Log.debug (fun f -> - let tags = Endpoint_types.In.with_qid_tag (Conn.tags t.conn) msg in - f ~tags "<- %a" (Endpoint_types.In.pp_recv pp_msg) msg); - begin match msg with - | `Abort _ -> - t.disconnecting <- true; - Conn.handle_msg t.conn msg; - Endpoint.disconnect t.endpoint; - `Aborted - | _ -> - Conn.handle_msg t.conn msg; - loop () - end - | `Unimplemented x as msg -> - Log.info (fun f -> - let tags = Endpoint_types.Out.with_qid_tag (Conn.tags t.conn) x in - f ~tags "<- Unimplemented(%a)" (Endpoint_types.Out.pp_recv pp_msg) x); - Conn.handle_msg t.conn msg; - loop () - | `Not_implemented -> - Log.info (fun f -> f "<- unsupported message type"); - return_not_implemented t msg; - loop () - in - loop () + Endpoint.send t.endpoint (Message.to_message m) + + let rec listen t = + match Endpoint.recv ~tags:(tags t) t.endpoint with + | Error e -> e + | Ok msg -> + let open Reader.Message in + let msg = of_message msg in + match Parse.message msg with + | #Endpoint_types.In.t as msg -> + Log.debug (fun f -> + let tags = Endpoint_types.In.with_qid_tag (Conn.tags t.conn) msg in + f ~tags "<- %a" (Endpoint_types.In.pp_recv pp_msg) msg); + begin match msg with + | `Abort _ -> + t.disconnecting <- true; + Conn.handle_msg t.conn msg; + Endpoint.disconnect t.endpoint; + Conn.disconnect t.conn (Capnp_rpc_proto.Exception.v "Received Abort from peer"); + `Aborted + | _ -> + Conn.handle_msg t.conn msg; + listen t + end + | `Unimplemented x as msg -> + Log.info (fun f -> + let tags = Endpoint_types.Out.with_qid_tag (Conn.tags t.conn) x in + f ~tags "<- Unimplemented(%a)" (Endpoint_types.Out.pp_recv pp_msg) x); + Conn.handle_msg t.conn msg; + listen t + | `Not_implemented -> + Log.info (fun f -> f "<- unsupported message type"); + return_not_implemented t msg; + listen t let send_abort t ex = - queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Serialise.message (`Abort ex)) + Endpoint.send t.endpoint (Serialise.message (`Abort ex)); + Endpoint.flush t.endpoint (* We're probably about to disconnect *) let disconnect t ex = if not t.disconnecting then ( @@ -160,21 +85,17 @@ module Make (Network : S.NETWORK) = struct let disconnecting t = t.disconnecting let connect ~sw ~restore ?(tags=Logs.Tag.empty) endpoint = - let xmit_queue = Queue.create () in - let queue_send msg = queue_send ~sw ~xmit_queue endpoint (Serialise.message msg) in + let queue_send msg = Endpoint.send endpoint (Serialise.message msg) in let restore = Restorer.fn restore in let fork = Fiber.fork ~sw in let conn = Conn.create ~restore ~tags ~fork ~queue_send in { - sw; conn; endpoint; - xmit_queue; disconnecting = false; } let listen t = - Prometheus.Gauge.inc_one Metrics.connections; let tags = Conn.tags t.conn in begin match listen t with @@ -187,8 +108,6 @@ module Make (Network : S.NETWORK) = struct ); send_abort t (Capnp_rpc.Exception.v ~ty:`Failed (Printexc.to_string ex)) end; - Log.info (fun f -> f ~tags "Connection closed"); - Prometheus.Gauge.dec_one Metrics.connections; Eio.Cancel.protect (fun () -> disconnect t (Capnp_rpc.Exception.v ~ty:`Disconnected "Connection closed") ); diff --git a/capnp-rpc-net/endpoint.ml b/capnp-rpc-net/endpoint.ml index 140f42e3..96c4bc1c 100644 --- a/capnp-rpc-net/endpoint.ml +++ b/capnp-rpc-net/endpoint.ml @@ -1,5 +1,27 @@ open Eio.Std +module Metrics = struct + open Prometheus + + let namespace = "capnp" + + let subsystem = "net" + + let connections = + let help = "Number of live capnp-rpc connections" in + Gauge.v ~help ~namespace ~subsystem "connections" + + let messages_inbound_received_total = + let help = "Total number of messages received" in + Counter.v ~help ~namespace ~subsystem "messages_inbound_received_total" + + let messages_outbound_enqueued_total = + let help = "Total number of messages enqueued to be transmitted" in + Counter.v ~help ~namespace ~subsystem "messages_outbound_enqueued_total" +end + +module Write = Eio.Buf_write + let src = Logs.Src.create "endpoint" ~doc:"Send and receive Cap'n'Proto messages" module Log = (val Logs.src_log src: Logs.LOG) @@ -11,8 +33,10 @@ type flow = Eio.Flow.two_way_ty r type t = { flow : flow; + writer : Write.t; decoder : Capnp.Codecs.FramedStream.t; peer_id : Auth.Digest.t; + recv_buf : Cstruct.t; } let peer_id t = t.peer_id @@ -20,7 +44,9 @@ let peer_id t = t.peer_id let of_flow ~peer_id flow = let decoder = Capnp.Codecs.FramedStream.empty compression in let flow = (flow :> flow) in - { flow; decoder; peer_id } + let writer = Write.create 4096 in + let recv_buf = Cstruct.create 4096 in + { flow; writer; decoder; peer_id; recv_buf } let dump_msg = let next = ref 0 in @@ -33,42 +59,78 @@ let dump_msg = close_out ch let send t msg = - let data = Capnp.Codecs.serialize ~compression msg in - if record_sent_messages then dump_msg data; - match Eio.Flow.copy_string data t.flow with - | () - | exception End_of_file -> Ok () - | exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) -> - Log.info (fun f -> f "%a" Eio.Exn.pp ex); - Error `Closed - | exception ex -> - Eio.Fiber.check (); - Error (`Msg (Printexc.to_string ex)) + Log.debug (fun f -> + let module M = Capnp_rpc.Private.Schema.MessageWrapper.Message in + f "queue_send: %d/%d allocated bytes in %d segs" + (M.total_size msg) + (M.total_alloc_size msg) + (M.num_segments msg)); + Capnp.Codecs.serialize_iter_copyless ~compression msg ~f:(fun x len -> Write.string t.writer x ~len); + Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total; + if record_sent_messages then dump_msg (Capnp.Codecs.serialize ~compression msg) -let rec recv t = +let rec recv ~tags t = match Capnp.Codecs.FramedStream.get_next_frame t.decoder with - | Ok msg -> Ok (Capnp.BytesMessage.Message.readonly msg) + | Ok msg -> + Prometheus.Counter.inc_one Metrics.messages_inbound_received_total; + (* We often want to send multiple response messages while processing a batch of requests, + so pause the writer to collect them. We'll unpause on the next [single_read]. *) + Write.pause t.writer; + Ok (Capnp.BytesMessage.Message.readonly msg) | Error Capnp.Codecs.FramingError.Unsupported -> failwith "Unsupported Cap'n'Proto frame received" | Error Capnp.Codecs.FramingError.Incomplete -> - Log.debug (fun f -> f "Incomplete; waiting for more data..."); - let buf = Cstruct.create 4096 in (* TODO: make this efficient *) - match Eio.Flow.single_read t.flow buf with + Log.debug (fun f -> f ~tags "Incomplete; waiting for more data..."); + (* We probably scheduled one or more application fibers to run while handling the last + batch of messages. Give them a chance to run now while the writer is paused, because + they might want to send more messages immediately. *) + Fiber.yield (); + Write.unpause t.writer; + match Eio.Flow.single_read t.flow t.recv_buf with | got -> - Log.debug (fun f -> f "Read %d bytes" got); - Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string buf ~len:got); - recv t + Log.debug (fun f -> f ~tags "Read %d bytes" got); + Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string t.recv_buf ~len:got); + recv ~tags t | exception End_of_file -> - Log.info (fun f -> f "Connection closed"); + Log.info (fun f -> f ~tags "Received end-of-stream"); Error `Closed | exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) -> - Log.info (fun f -> f "%a" Eio.Exn.pp ex); + Log.info (fun f -> f ~tags "Receive failed: %a" Eio.Exn.pp ex); Error `Closed let disconnect t = try Eio.Flow.shutdown t.flow `All - with - | Invalid_argument _ - | Eio.Io (Eio.Net.E Connection_reset _, _) -> + with Eio.Io (Eio.Net.E Connection_reset _, _) -> (* TCP connection already shut down, so TLS shutdown failed. Ignore. *) () + +let flush t = + Write.unpause t.writer; + (* Give the writer a chance to send the last of the data. + We could use [Write.flush] to be sure the data got sent, but this code is + only used to send aborts, which isn't very important and it's probably + better to drop the buffered messages if one yield isn't enough. *) + Fiber.yield () + +let rec run_writer ~tags t = + let bufs = Write.await_batch t.writer in + match Eio.Flow.single_write t.flow bufs with + | n -> Write.shift t.writer n; run_writer ~tags t + | exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) -> + Log.info (fun f -> f ~tags "Send failed: %a" Eio.Exn.pp ex) + | exception ex -> + Eio.Fiber.check (); + Log.warn (fun f -> f ~tags "Error sending messages: %a (will shutdown connection)" Fmt.exn ex) + +let run_writer ~tags t = + let cleanup () = + Prometheus.Gauge.dec_one Metrics.connections; + disconnect t (* The listen fiber will read end-of-stream soon *) + in + Prometheus.Gauge.inc_one Metrics.connections; + match run_writer ~tags t with + | () -> cleanup () + | exception ex -> + let bt = Printexc.get_raw_backtrace () in + cleanup (); + Printexc.raise_with_backtrace ex bt diff --git a/capnp-rpc-net/endpoint.mli b/capnp-rpc-net/endpoint.mli index 674a6a29..95316d58 100644 --- a/capnp-rpc-net/endpoint.mli +++ b/capnp-rpc-net/endpoint.mli @@ -6,11 +6,15 @@ val src : Logs.src type t (** A wrapper for a byte-stream (flow). *) -val send : t -> 'a Capnp.BytesMessage.Message.t -> (unit, [`Closed | `Msg of string]) result -(** [send t msg] transmits [msg]. *) +val send : t -> 'a Capnp.BytesMessage.Message.t -> unit +(** [send t msg] enqueues [msg]. *) -val recv : t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result -(** [recv t] reads the next message from the remote peer. +val run_writer : tags:Logs.Tag.set -> t -> unit +(** [run_writer ~tags t] runs a loop that transmits batches of messages from [t]. + It returns when the flow is closed. *) + +val recv : tags:Logs.Tag.set -> t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result +(** [recv ~tags t] reads the next message from the remote peer. It returns [Error `Closed] if the connection to the peer is lost. *) val of_flow : peer_id:Auth.Digest.t -> _ Eio.Flow.two_way -> t @@ -19,6 +23,10 @@ val of_flow : peer_id:Auth.Digest.t -> _ Eio.Flow.two_way -> t val peer_id : t -> Auth.Digest.t (** [peer_id t] is the fingerprint of the peer's public key, or [Auth.Digest.insecure] if TLS isn't being used. *) + +val flush : t -> unit +(** [flush t] is useful to try to send any buffered data before disconnecting. + Otherwise, the final abort message is likely to get lost. *) val disconnect : t -> unit (** [disconnect t] shuts down the underlying flow. *) diff --git a/capnp-rpc-net/vat.ml b/capnp-rpc-net/vat.ml index 64827d8e..008d3c49 100644 --- a/capnp-rpc-net/vat.ml +++ b/capnp-rpc-net/vat.ml @@ -43,7 +43,11 @@ module Make (Network : S.NETWORK) = struct let run_connection_generic t ~add ~remove endpoint = let conn = CapTP.connect ~sw:t.sw ~tags:t.tags ~restore:t.restore endpoint in add conn; - Fun.protect (fun () -> CapTP.listen conn) + Fun.protect (fun () -> + Fiber.both + (fun () -> Endpoint.run_writer ~tags:t.tags endpoint) + (fun () -> CapTP.listen conn) + ) ~finally:(fun () -> remove conn; Eio.Condition.broadcast t.connection_removed diff --git a/test-bin/calc_direct.ml b/test-bin/calc_direct.ml index acf1a432..13002f23 100644 --- a/test-bin/calc_direct.ml +++ b/test-bin/calc_direct.ml @@ -29,18 +29,24 @@ module Logging = struct Logs.set_reporter (reporter id) end +let run_connection conn endpoint = + Fiber.all [ + (* Normally the vat runs a leak handler to free resources that get GC'd + with a non-zero reference count. We're not using a vat, so run it ourselves. *) + Capnp_rpc.Leak_handler.run; + (fun () -> Capnp_rpc_unix.CapTP.listen conn); + (fun () -> Capnp_rpc_net.Endpoint.run_writer ~tags:Logs.Tag.empty endpoint); + ] + module Parent = struct let run socket = Logging.init "parent"; Switch.run @@ fun sw -> - (* Normally the vat runs a leak handler to free resources that get GC'd - with a non-zero reference count. We're not using a vat, so run it ourselves. *) - Fiber.fork_daemon ~sw Capnp_rpc.Leak_handler.run; (* Run Cap'n Proto RPC protocol on [socket]: *) let p = Capnp_rpc_net.Endpoint.of_flow socket ~peer_id:Capnp_rpc_net.Auth.Digest.insecure in Logs.info (fun f -> f "Connecting to child process..."); let conn = Capnp_rpc_unix.CapTP.connect ~sw ~restore:Capnp_rpc_net.Restorer.none p in - Fiber.fork_daemon ~sw (fun () -> Capnp_rpc_unix.CapTP.listen conn; `Stop_daemon); + Fiber.fork_daemon ~sw (fun () -> run_connection conn p; `Stop_daemon); (* Get the child's service object: *) let calc = Capnp_rpc_unix.CapTP.bootstrap conn service_name in (* Use the service: *) @@ -56,17 +62,14 @@ module Child = struct let run socket = Logging.init "child"; Switch.run @@ fun sw -> - Fiber.fork_daemon ~sw Capnp_rpc.Leak_handler.run; let socket = Eio_unix.Net.import_socket_stream ~sw ~close_unix:false socket in let service = Calc.local ~sw in let restore = Capnp_rpc_net.Restorer.single service_name service in (* Run Cap'n Proto RPC protocol on [socket]: *) - let endpoint = Capnp_rpc_net.Endpoint.of_flow socket - ~peer_id:Capnp_rpc_net.Auth.Digest.insecure - in + let endpoint = Capnp_rpc_net.Endpoint.of_flow socket ~peer_id:Capnp_rpc_net.Auth.Digest.insecure in let conn = Capnp_rpc_unix.CapTP.connect ~sw ~restore endpoint in Logs.info (fun f -> f "Serving requests..."); - Capnp_rpc_unix.CapTP.listen conn + run_connection conn endpoint end let () = diff --git a/test-bin/echo/echo_bench.ml b/test-bin/echo/echo_bench.ml index 47ccac9a..0ee7daea 100755 --- a/test-bin/echo/echo_bench.ml +++ b/test-bin/echo/echo_bench.ml @@ -7,12 +7,11 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let run_client service = - (* let n = 100000 in *) (* XXX: improve speed *) - let n = 1000 in + let n = 100000 in let ops = List.init n (fun i -> let payload = Int.to_string i in let desired_result = "echo:" ^ payload in - fun () -> + fun () -> let res = Echo.ping service payload in assert (res = desired_result) ) in @@ -20,7 +19,7 @@ let run_client service = ops |> Fiber.List.iter ~max_fibers:12 (fun v -> v ()); let ed = Unix.gettimeofday () in let rate = (Int.to_float n) /. (ed -. st) in - Logs.info (fun m -> m "rate = %f" rate ) + Logs.info (fun m -> m "rate = %f" rate) let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) @@ -35,9 +34,10 @@ let start_server ~sw net = let () = Eio_main.run @@ fun env -> Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> - Switch.run @@ fun sw -> + Switch.run ~name:"main" @@ fun sw -> let uri = start_server ~sw env#net in Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; + Switch.run ~name:"client" @@ fun sw -> let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in Sturdy_ref.with_cap_exn sr run_client diff --git a/unix/network.mli b/unix/network.mli index 17cd939f..b38ff192 100644 --- a/unix/network.mli +++ b/unix/network.mli @@ -36,6 +36,6 @@ val accept_connection : secret_key:Capnp_rpc_net.Auth.Secret_key.t option -> [> Eio.Flow.two_way_ty | Eio.Resource.close_ty] r -> (Capnp_rpc_net.Endpoint.t, [> `Msg of string]) result -(** [accept_connection ~switch ~secret_key flow] is a new endpoint for [flow]. +(** [accept_connection ~secret_key flow] is a new endpoint for [flow]. If [secret_key] is not [None], it is used to perform a TLS server-side handshake. Otherwise, the connection is not encrypted. *)