diff --git a/CHANGES.md b/CHANGES.md index b16d84a84f7..b3b85a56253 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -29,6 +29,9 @@ Unreleased - Revert #7415 and #7450 (Resolve `ppx_runtime_libraries` in the target context when cross compiling) (#7887, fixes #7875, @emillon) +- Fix RPC buffer corruption issues due to multi threading. This issue was only + reproducible with large RPC payloads (#7418) + 3.8.0 (2023-05-23) ------------------ diff --git a/boot/libs.ml b/boot/libs.ml index 738066d831f..32bc4e1593c 100644 --- a/boot/libs.ml +++ b/boot/libs.ml @@ -33,6 +33,7 @@ let local_libraries = ; ("otherlibs/dune-private-libs/section", Some "Dune_section", false, None) ; ("src/dune_lang", Some "Dune_lang", false, None) ; ("vendor/opam-file-format", None, false, None) + ; ("src/dune_async_io", Some "Dune_async_io", false, None) ; ("src/fiber_util", Some "Fiber_util", false, None) ; ("src/dune_cache_storage", Some "Dune_cache_storage", false, None) ; ("src/dune_cache", Some "Dune_cache", false, None) diff --git a/src/csexp_rpc/csexp_rpc.ml b/src/csexp_rpc/csexp_rpc.ml index 01589bd8628..19bf3c8f9e0 100644 --- a/src/csexp_rpc/csexp_rpc.ml +++ b/src/csexp_rpc/csexp_rpc.ml @@ -1,23 +1,8 @@ open Stdune open Fiber.O +open Dune_async_io module Log = Dune_util.Log -module type Scheduler = sig - val async : (unit -> 'a) -> ('a, Exn_with_backtrace.t) result Fiber.t -end - -let scheduler = Fdecl.create Dyn.opaque - -let async f = - let module Scheduler = (val Fdecl.get scheduler : Scheduler) in - Scheduler.async f - -let async_exn f = - let+ res = async f in - match res with - | Ok s -> s - | Error e -> Exn_with_backtrace.reraise e - module Session_id = Id.Make () module Socket = struct @@ -97,19 +82,21 @@ end let debug = Option.is_some (Env.get Env.initial "DUNE_RPC_DEBUG") module Session = struct + let fail = function + | `Cancelled -> raise Dune_util.Report_error.Already_reported + | `Exn exn -> raise exn + module Id = Session_id type state = | Closed | Open of { out_buf : Io_buffer.t + ; in_buf : Io_buffer.t ; fd : Unix.file_descr - ; (* A mutex for modifying [out_buf]. - - Needed as long as we use threads for async IO. Once we switch to - event based IO, we won't need this mutex anymore *) - write_mutex : Mutex.t - ; in_channel : in_channel + ; mutable read_eof : bool + ; write_mutex : Fiber.Mutex.t + ; read_mutex : Fiber.Mutex.t } type t = @@ -117,39 +104,47 @@ module Session = struct ; mutable state : state } - let create fd in_channel = + let create fd = + Unix.set_nonblock fd; let id = Id.gen () in if debug then Log.info [ Pp.textf "RPC created new session %d" (Id.to_int id) ]; let state = + let size = 8192 in Open { fd - ; in_channel - ; out_buf = Io_buffer.create ~size:8192 - ; write_mutex = Mutex.create () + ; in_buf = Io_buffer.create ~size + ; out_buf = Io_buffer.create ~size + ; read_eof = false + ; write_mutex = Fiber.Mutex.create () + ; read_mutex = Fiber.Mutex.create () } in { id; state } let string_of_packet = function | None -> "EOF" - | Some csexp -> Sexp.to_string csexp + | Some csexp -> Dyn.to_string (Sexp.to_dyn csexp) let string_of_packets = function | None -> "EOF" | Some sexps -> String.concat ~sep:" " (List.map ~f:Sexp.to_string sexps) let close t = + let* () = Fiber.return () in match t.state with - | Closed -> () - | Open { write_mutex = _; fd = _; in_channel; out_buf = _ } -> - (* with a socket, there's only one fd. We make sure to close it only once. - with dune rpc init, we have two separate fd's (stdin/stdout) so we must - close both. *) - close_in_noerr in_channel; + | Closed -> Fiber.return () + | Open { fd; _ } -> + let+ () = Async_io.close fd in t.state <- Closed + module Lexer = Csexp.Parser.Lexer + module Stack = Csexp.Parser.Stack + + let min_read = 8192 + let read t = + let* () = Fiber.return () in let debug res = if debug then Log.info @@ -162,26 +157,79 @@ module Session = struct | Closed -> debug None; Fiber.return None - | Open { in_channel; _ } -> - let rec read () = - match Csexp.input_opt in_channel with - | exception Unix.Unix_error (_, _, _) -> None - | exception Sys_error _ -> None - | exception Sys_blocked_io -> read () - | Ok None -> None - | Ok (Some csexp) -> Some csexp - | Error _ -> None + | Open ({ fd; in_buf; read_mutex; _ } as open_) -> + let lexer = Lexer.create () in + let buf = Buffer.create 16 in + let rec refill () = + if Io_buffer.length in_buf > 0 then Fiber.return (Ok `Continue) + else if open_.read_eof then Fiber.return (Ok `Eof) + else + let* task = + Async_io.ready fd `Read ~f:(fun () -> + let () = Io_buffer.maybe_resize_to_fit in_buf min_read in + let pos = Io_buffer.write_pos in_buf in + let len = Io_buffer.max_write_len in_buf in + match Unix.read fd (Io_buffer.bytes in_buf) pos len with + | exception + Unix.Unix_error ((EAGAIN | EINTR | EWOULDBLOCK), _, _) -> + `Refill + | 0 -> + open_.read_eof <- true; + `Eof + | len -> + Io_buffer.commit_write in_buf ~len; + `Continue) + in + Async_io.Task.await task >>= function + | Error (`Exn e) -> Fiber.return (Error e) + | Error `Cancelled | Ok `Eof -> Fiber.return @@ Ok `Eof + | Ok `Continue -> Fiber.return @@ Ok `Continue + | Ok `Refill -> refill () + and read parser = + let* res = refill () in + match res with + | Error _ as e -> Fiber.return e + | Ok `Eof -> Fiber.return (Ok None) + | Ok `Continue -> ( + let char = Io_buffer.read_char_exn in_buf in + let token = Lexer.feed lexer char in + match token with + | Atom n -> + Buffer.clear buf; + atom parser n + | (Lparen | Rparen | Await) as token -> ( + let parser = Stack.add_token token parser in + match parser with + | Sexp (sexp, Empty) -> Fiber.return (Ok (Some sexp)) + | parser -> read parser)) + and atom parser n = + if n = 0 then + let atom = Buffer.contents buf in + match Stack.add_atom atom parser with + | Sexp (sexp, Empty) -> Fiber.return (Ok (Some sexp)) + | parser -> read parser + else + refill () >>= function + | Error _ as e -> Fiber.return e + | Ok `Eof -> Fiber.return (Ok None) + | Ok `Continue -> + let n' = Io_buffer.read_into_buffer in_buf buf ~max_len:n in + atom parser (n - n') in - let+ res = async read in - let res = + let+ res = + let* res = + Fiber.Mutex.with_lock read_mutex ~f:(fun () -> read Stack.Empty) + in match res with - | Error _ -> - close t; - None + | Error exn -> + Log.info + [ Pp.textf "Unable to read (%d)" (Id.to_int t.id); Exn.pp exn ]; + let+ () = close t in + reraise exn | Ok None -> - close t; + let+ () = close t in None - | Ok (Some sexp) -> Some sexp + | Ok (Some sexp) -> Fiber.return @@ Some sexp in debug res; res @@ -193,27 +241,37 @@ module Session = struct | Linux -> send | _ -> Unix.single_write - let rec csexp_write_loop fd out_buf token write_mutex = - Mutex.lock write_mutex; - if Io_buffer.flushed out_buf token then Mutex.unlock write_mutex + let rec csexp_write_loop fd out_buf token = + if Io_buffer.flushed out_buf token then Fiber.return (Ok ()) else (* We always make sure to try and write the entire buffer. This should minimize the amount of [write] calls we need to do *) - let written = - let bytes = Io_buffer.bytes out_buf in - let pos = Io_buffer.pos out_buf in - let len = Io_buffer.length out_buf in - try write fd bytes pos len - with exn -> - Mutex.unlock write_mutex; - reraise exn + let* task = + let* task = + Async_io.ready fd `Write ~f:(fun () -> + let bytes = Io_buffer.bytes out_buf in + let pos = Io_buffer.pos out_buf in + let len = Io_buffer.length out_buf in + match write fd bytes pos len with + | exception Unix.Unix_error ((EAGAIN | EINTR | EWOULDBLOCK), _, _) + -> `Continue + | exception Unix.Unix_error (EPIPE, _, _) -> `Cancelled + | exception exn -> `Exn exn + | written -> + Io_buffer.read out_buf written; + `Continue) + in + Async_io.Task.await task in - Io_buffer.read out_buf written; - Mutex.unlock write_mutex; - csexp_write_loop fd out_buf token write_mutex + match task with + | Error _ as e -> Fiber.return e + | Ok (`Exn exn) -> Fiber.return (Error (`Exn exn)) + | Ok `Cancelled -> Fiber.return (Error `Cancelled) + | Ok `Continue -> csexp_write_loop fd out_buf token let write t sexps = + let* () = Fiber.return () in if debug then Log.info [ Pp.verbatim @@ -235,63 +293,69 @@ module Session = struct (* TODO this hack is temporary until we get rid of dune rpc init *) Unix.shutdown fd Unix.SHUTDOWN_ALL with Unix.Unix_error (_, _, _) -> ()); - close t; - Fiber.return () + close t | Some sexps -> ( - let+ res = - Mutex.lock write_mutex; - Io_buffer.write_csexps out_buf sexps; - let flush_token = Io_buffer.flush_token out_buf in - Mutex.unlock write_mutex; - async (fun () -> csexp_write_loop fd out_buf flush_token write_mutex) + let* res = + Fiber.Mutex.with_lock write_mutex ~f:(fun () -> + Io_buffer.write_csexps out_buf sexps; + let flush_token = Io_buffer.flush_token out_buf in + csexp_write_loop fd out_buf flush_token) in match res with - | Ok () -> () - | Error e -> - close t; - Exn_with_backtrace.reraise e)) + | Ok () -> Fiber.return () + | Error error -> + let+ () = close t in + fail error)) end -let close_fd_no_error fd = try Unix.close fd with _ -> () - module Server = struct module Transport = struct type t = { fd : Unix.file_descr ; sockaddr : Unix.sockaddr - ; r_interrupt_accept : Unix.file_descr - ; w_interrupt_accept : Unix.file_descr - ; buf : Bytes.t + ; mutable task : (Unix.file_descr * Unix.sockaddr) Async_io.Task.t option + ; mutable running : bool } let create fd sockaddr ~backlog = Unix.listen fd backlog; - let r_interrupt_accept, w_interrupt_accept = Unix.pipe ~cloexec:true () in - let buf = Bytes.make 1 '0' in - { fd; sockaddr; r_interrupt_accept; w_interrupt_accept; buf } - - let accept t = - match Unix.select [ t.r_interrupt_accept; t.fd ] [] [] (-1.0) with - | r, [], [] -> - let inter, accept = - List.fold_left r ~init:(false, false) ~f:(fun (i, a) fd -> - if fd = t.fd then (i, true) - else if fd = t.r_interrupt_accept then (true, a) - else assert false) + Unix.set_nonblock fd; + { fd; sockaddr; task = None; running = true } + + let close t = + let+ () = Async_io.close t.fd in + Ok None + + let rec accept t = + let* () = Fiber.return () in + match t.running with + | false -> close t + | true -> ( + let* task = + Async_io.ready t.fd `Read ~f:(fun () -> + Unix.accept ~cloexec:true t.fd) in - if inter then None - else if accept then ( - let fd, _ = Unix.accept ~cloexec:true t.fd in + t.task <- Some task; + let* res = Async_io.Task.await task in + match res with + | Error (`Exn (Unix.Unix_error (Unix.EAGAIN, _, _))) -> accept t + | Error (`Exn exn) -> + let+ _ = close t in + Error (Exn_with_backtrace.capture exn) + | Error `Cancelled -> close t + | Ok (fd, _) -> Socket.maybe_set_nosigpipe fd; - Unix.clear_nonblock fd; - Some fd) - else assert false - | _, _, _ -> assert false - | exception Unix.Unix_error (Unix.EBADF, _, _) -> None + Unix.set_nonblock fd; + Fiber.return @@ Ok (Some fd)) let stop t = - let _ = Unix.write t.w_interrupt_accept t.buf 0 1 in - close_fd_no_error t.fd; + let* () = Fiber.return () in + t.running <- false; + let+ () = + match t.task with + | None -> Fiber.return () + | Some task -> Async_io.Task.cancel task + in match t.sockaddr with | ADDR_UNIX p -> Fpath.unlink_no_err p | _ -> () @@ -325,20 +389,11 @@ module Server = struct | `Closed -> Code_error.raise "already closed" [] | `Running _ -> Code_error.raise "already running" [] | `Init fd -> - let* transport = - async_exn (fun () -> Transport.create fd t.sockaddr ~backlog:t.backlog) - in + let transport = Transport.create fd t.sockaddr ~backlog:t.backlog in t.state <- `Running transport; let+ () = Fiber.Ivar.fill t.ready () in - let accept () = - async (fun () -> - Transport.accept transport - |> Option.map ~f:(fun client -> - let in_ = Unix.in_channel_of_descr client in - (client, in_))) - in let loop () = - let+ accept = accept () in + let+ accept = Transport.accept transport in match accept with | Error exn -> Log.info @@ -353,18 +408,21 @@ module Server = struct accepted." ]; None - | Ok (Some (fd, in_)) -> - let session = Session.create fd in_ in + | Ok (Some fd) -> + let session = Session.create fd in Some session in Fiber.Stream.In.create loop let stop t = - let () = + let* () = Fiber.return () in + let+ () = match t.state with - | `Closed -> () + | `Closed -> Fiber.return () | `Running t -> Transport.stop t - | `Init fd -> Unix.close fd + | `Init fd -> + Unix.close fd; + Fiber.return () in t.state <- `Closed @@ -376,12 +434,9 @@ end module Client = struct module Transport = struct - type t = - { fd : Unix.file_descr - ; sockaddr : Unix.sockaddr - } + type t = { fd : Unix.file_descr } - let close t = close_fd_no_error t.fd + let close t = Unix.close t.fd let create sockaddr = let fd = @@ -389,11 +444,8 @@ module Client = struct (Unix.domain_of_sockaddr sockaddr) Unix.SOCK_STREAM 0 in - { sockaddr; fd } - - let connect t = - let () = Socket.connect t.fd t.sockaddr in - t.fd + Unix.set_nonblock fd; + { fd } end type t = @@ -404,17 +456,17 @@ module Client = struct let create sockaddr = { sockaddr; transport = None } let connect t = - let+ task = - async (fun () -> - let transport = Transport.create t.sockaddr in - t.transport <- Some transport; - let client = Transport.connect transport in - let in_ = Unix.in_channel_of_descr client in - (client, in_)) - in - match task with - | Error exn -> Error exn - | Ok (fd, in_) -> Ok (Session.create fd in_) + let* () = Fiber.return () in + let backtrace = Printexc.get_callstack 10 in + let transport = Transport.create t.sockaddr in + let fd = transport.fd in + t.transport <- Some transport; + Async_io.connect Socket.connect fd t.sockaddr >>| function + | Ok () -> Ok (Session.create fd) + | Error `Cancelled -> + let exn = Failure "connect cancelled" in + Error { Exn_with_backtrace.exn; backtrace } + | Error (`Exn exn) -> Error { Exn_with_backtrace.exn; backtrace } let connect_exn t = let+ res = connect t in diff --git a/src/csexp_rpc/csexp_rpc.mli b/src/csexp_rpc/csexp_rpc.mli index a881e8bcd26..2b33b805e61 100644 --- a/src/csexp_rpc/csexp_rpc.mli +++ b/src/csexp_rpc/csexp_rpc.mli @@ -15,15 +15,6 @@ open Stdune -module type Scheduler = sig - (** [async f] enqueue task [f] *) - - val async : (unit -> 'a) -> ('a, Exn_with_backtrace.t) result Fiber.t -end - -(** Hack until we move [Dune_engine.Scheduler] into own library *) -val scheduler : (module Scheduler) Fdecl.t - module Session : sig (** Rpc session backed by two threads. One thread for reading, and another for writing *) @@ -61,7 +52,7 @@ module Server : sig to the server *) val ready : t -> unit Fiber.t - val stop : t -> unit + val stop : t -> unit Fiber.t val serve : t -> Session.t Fiber.Stream.In.t Fiber.t diff --git a/src/csexp_rpc/dune b/src/csexp_rpc/dune index 3a04d57270f..6201fa9a72c 100644 --- a/src/csexp_rpc/dune +++ b/src/csexp_rpc/dune @@ -7,6 +7,7 @@ dune_util csexp fiber + dune_async_io threads.posix (re_export unix)) (foreign_stubs diff --git a/src/csexp_rpc/io_buffer.ml b/src/csexp_rpc/io_buffer.ml index 519306d88b7..c861e81c9e7 100644 --- a/src/csexp_rpc/io_buffer.ml +++ b/src/csexp_rpc/io_buffer.ml @@ -99,3 +99,26 @@ let to_dyn ({ bytes; pos_r; pos_w; total_written } as t) = ; ("pos_w", int pos_w) ; ("pos_r", int pos_r) ] + +let write_pos t = t.pos_w + +let max_write_len t = Bytes.length t.bytes - t.pos_w + +let commit_write t ~len = + let pos_w = t.pos_w + len in + if pos_w > Bytes.length t.bytes then + Code_error.raise "not enough space to commit write" + [ ("len", Dyn.int len); ("t", to_dyn t) ]; + t.pos_w <- pos_w + +let read_char_exn t = + assert (t.pos_r < t.pos_w); + let c = Bytes.get t.bytes t.pos_r in + read t 1; + c + +let read_into_buffer t buf ~max_len = + let len = min max_len (t.pos_w - t.pos_r) in + Buffer.add_subbytes buf t.bytes t.pos_r len; + read t len; + len diff --git a/src/csexp_rpc/io_buffer.mli b/src/csexp_rpc/io_buffer.mli index 4ea8d1c38d8..69028ab8666 100644 --- a/src/csexp_rpc/io_buffer.mli +++ b/src/csexp_rpc/io_buffer.mli @@ -31,3 +31,27 @@ val pos : t -> int (** [length t] the number of bytes to read [bytes t] *) val length : t -> int + +(** [write_pos t] returns the write position inside the buffer we're allowed to + write on *) +val write_pos : t -> int + +(** [commit_write t ~len] tell the buffer [t] that we've written [len] writes. + [len] must be smaller or equal to [max_write_len t] *) +val commit_write : t -> len:int -> unit + +(** [max_write_len t] returns the maximum contiguous write size the buffer can + fit *) +val max_write_len : t -> int + +(** [read_char_exn t] reads and returns the next available byte. [t] must not be + empty *) +val read_char_exn : t -> char + +(** [read_into_buffer t buf ~max_len] reads at most [max_len] from [t] and + appends what was read to [buf] *) +val read_into_buffer : t -> Buffer.t -> max_len:int -> int + +(** [maybe_resize_to_fit t size] resizes [t] to fit a write of size [size] if + necessary *) +val maybe_resize_to_fit : t -> int -> unit diff --git a/src/dune_async_io/async_io.ml b/src/dune_async_io/async_io.ml new file mode 100644 index 00000000000..3a0c4eea6a7 --- /dev/null +++ b/src/dune_async_io/async_io.ml @@ -0,0 +1,302 @@ +open Stdune +open Fiber.O + +module Fd = struct + type t = Unix.file_descr + + let equal = Poly.equal + + let hash = Poly.hash + + let to_dyn = Dyn.opaque +end + +module type Scheduler = sig + val fill_jobs : Fiber.fill list -> unit + + val register_job_started : unit -> unit + + val cancel_job_started : unit -> unit + + val spawn_thread : (unit -> unit) -> unit +end + +let byte = Bytes.make 1 '0' + +module Task_id = Id.Make () + +type t = + { readers : (Unix.file_descr, packed_task Queue.t) Table.t + ; writers : (Unix.file_descr, packed_task Queue.t) Table.t + ; mutable to_close : Unix.file_descr list + ; pipe_read : Unix.file_descr + ; (* write a byte here to interrupt the select loop *) + pipe_write : Unix.file_descr + ; mutex : Mutex.t + ; scheduler : (module Scheduler) + ; mutable running : bool + ; (* this flag is to save a write to the pipe we used to interrupt select *) + mutable interrupting : bool + ; pipe_buf : Bytes.t + } + +and 'a task = + { job : unit -> 'a + ; ivar : ('a, [ `Cancelled | `Exn of exn ]) result Fiber.Ivar.t + ; select : t + ; what : [ `Read | `Write ] + ; fd : Unix.file_descr + ; id : Task_id.t + } + +and packed_task = Task : _ task -> packed_task + +let interrupt t = + if not t.interrupting then ( + assert (Unix.single_write t.pipe_write byte 0 1 = 1); + t.interrupting <- true) + +module Task = struct + type 'a t = 'a task + + let await task = Fiber.Ivar.read task.ivar + + let cancel t = + let* () = Fiber.return () in + Mutex.lock t.select.mutex; + let+ () = + Fiber.Ivar.peek t.ivar >>= function + | Some _ -> Fiber.return () + | None -> + let module Scheduler = (val t.select.scheduler) in + let table = + match t.what with + | `Read -> t.select.readers + | `Write -> t.select.writers + in + let should_interrupt = + match Table.find table t.fd with + | None -> false + | Some q -> + let new_q = Queue.create () in + Queue.iter q ~f:(fun (Task t' as task) -> + if Task_id.equal t.id t'.id then Scheduler.cancel_job_started () + else Queue.push new_q task); + if Queue.is_empty new_q then ( + Table.remove table t.fd; + Queue.is_empty q) + else ( + Table.add_exn table t.fd new_q; + Queue.length new_q <> Queue.length q) + in + if should_interrupt then interrupt t.select; + Fiber.Ivar.fill t.ivar (Error `Cancelled) + in + Mutex.unlock t.select.mutex +end + +let drain_until_ready queue acc = + match Queue.pop queue with + | None -> acc + | Some (Task task) -> + let result = try Ok (task.job ()) with exn -> Error (`Exn exn) in + Fiber.Fill (task.ivar, result) :: acc + +let make_fills fds pipe_fd waiters init = + List.fold_left fds ~init:(false, init) ~f:(fun (pipe, acc) fd -> + if Fd.equal fd pipe_fd then (true, acc) + else + let acc = + match Table.find waiters fd with + | None -> acc + | Some w -> + let acc = drain_until_ready w acc in + if Queue.is_empty w then Table.remove waiters fd; + acc + in + (pipe, acc)) + +let drain_pipe pipe buf = + match Unix.read pipe buf 0 (Bytes.length buf) with + | _ -> () + | exception Unix.Unix_error (Unix.EAGAIN, _, _) -> () + +let rec drain_cancel q acc = + match Queue.pop q with + | None -> acc + | Some (Task task) -> + drain_cancel q (Fiber.Fill (task.ivar, Error `Cancelled) :: acc) + +let maybe_cancel table fd acc = + match Table.find table fd with + | None -> acc + | Some q -> + Table.remove table fd; + drain_cancel q acc + +let rec select_loop t = + (match t.to_close with + | [] -> () + | to_close -> ( + let fills = + List.fold_left to_close ~init:[] ~f:(fun acc fd -> + Unix.close fd; + let acc = maybe_cancel t.readers fd acc in + maybe_cancel t.writers fd acc) + in + t.to_close <- []; + match fills with + | [] -> () + | _ :: _ -> + let module Scheduler = (val t.scheduler) in + Scheduler.fill_jobs fills)); + match t.running with + | false -> + Unix.close t.pipe_write; + Unix.close t.pipe_read + | true -> + let readers, writers, ex = + let read = t.pipe_read :: Table.keys t.readers in + let write = Table.keys t.writers in + Mutex.unlock t.mutex; + (* At this point, if any [ready] acquires the lock, they need to check if + [read] or [write] contain their fd. If it doesn't, the write + [t.pipe_write] will interrupt this select *) + Unix.select read write [] (-1.0) + in + assert (ex = []); + (* Before we acquire the lock, it's possible that new tasks were added. + This is fine. *) + Mutex.lock t.mutex; + let seen_pipe, fills = make_fills readers t.pipe_read t.readers [] in + (* we will never see [t.pipe_read] in the next list, but there's no harm in + this *) + let _, fills = make_fills writers t.pipe_read t.writers fills in + if seen_pipe then ( + drain_pipe t.pipe_read t.pipe_buf; + t.interrupting <- false); + (match fills with + | [] -> () + | _ :: _ -> + let module Scheduler = (val t.scheduler) in + Scheduler.fill_jobs fills); + select_loop t + +let t_var = Fiber.Var.create () + +let with_io scheduler f = + let module Scheduler = (val scheduler : Scheduler) in + let t = + let pipe_read, pipe_write = Unix.pipe ~cloexec:true () in + if not Sys.win32 then ( + Unix.set_nonblock pipe_read; + Unix.set_nonblock pipe_write); + { readers = Table.create (module Fd) 64 + ; writers = Table.create (module Fd) 64 + ; mutex = Mutex.create () + ; scheduler + ; running = true + ; pipe_read + ; pipe_write + ; pipe_buf = Bytes.create 512 + ; interrupting = false + ; to_close = [] + } + in + let () = + Scheduler.spawn_thread (fun () -> + Mutex.lock t.mutex; + Exn.protect + ~f:(fun () -> select_loop t) + ~finally:(fun () -> Mutex.unlock t.mutex)) + in + Fiber.Var.set t_var t (fun () -> + Fiber.finalize f ~finally:(fun () -> + Mutex.lock t.mutex; + t.running <- false; + interrupt t; + Mutex.unlock t.mutex; + Fiber.return ())) + +let with_ f = + let+ t = Fiber.Var.get_exn t_var in + Mutex.lock t.mutex; + Exn.protect ~f:(fun () -> f t) ~finally:(fun () -> Mutex.unlock t.mutex) + +let cancel_fd scheduler table fd = + match Table.find table fd with + | None -> Fiber.return () + | Some tasks -> + Table.remove table fd; + let module Scheduler = (val scheduler : Scheduler) in + Queue.to_list tasks + |> Fiber.parallel_iter ~f:(fun (Task t) -> + Scheduler.cancel_job_started (); + Fiber.Ivar.fill t.ivar (Error `Cancelled)) + +let close fd = + let* t = Fiber.Var.get_exn t_var in + Mutex.lock t.mutex; + (* everything below is guaranteed not to raise so the mutex will be unlocked + in the end. There's no need to use [protect] to make sure we don't deadlock *) + t.to_close <- fd :: t.to_close; + let+ () = + Fiber.fork_and_join_unit + (fun () -> cancel_fd t.scheduler t.readers fd) + (fun () -> cancel_fd t.scheduler t.writers fd) + in + interrupt t; + Mutex.unlock t.mutex + +let ready fd what ~f:job = + with_ @@ fun t -> + let module Scheduler = (val t.scheduler) in + Scheduler.register_job_started (); + let ivar = Fiber.Ivar.create () in + let q, interrupt_needed = + let table = + match what with + | `Read -> t.readers + | `Write -> t.writers + in + match Table.find table fd with + | Some q -> (q, false) + | None -> + let q = Queue.create () in + Table.add_exn table fd q; + (q, true) + in + let task = { ivar; select = t; job; what; id = Task_id.gen (); fd } in + Queue.push q (Task task); + if interrupt_needed then interrupt t; + task + +let rec with_retry f fd = + match f () with + | () -> Fiber.return (Ok ()) + | exception Unix.Unix_error (EWOULDBLOCK, x, y) when Sys.win32 -> + Fiber.return (Error (`Unix (Unix.EINPROGRESS, x, y))) + | exception Unix.Unix_error ((EAGAIN | EWOULDBLOCK | EINTR), _, _) -> ( + let* task = ready fd `Write ~f:Fun.id in + Task.await task >>= function + | Ok () -> with_retry f fd + | Error `Cancelled as e -> Fiber.return e + | Error (`Exn _) -> assert false) + | exception Unix.Unix_error (err, x, y) -> + Fiber.return (Error (`Unix (err, x, y))) + +let connect f fd socket = + let* () = Fiber.return () in + with_retry (fun () -> f fd socket) fd >>= function + | Ok () -> Fiber.return (Ok ()) + | Error (`Unix (Unix.EISCONN, _, _)) when Sys.win32 -> Fiber.return (Ok ()) + | Error (`Unix (EINPROGRESS, _, _)) -> ( + let* task = ready fd `Write ~f:(fun () -> Unix.getsockopt_error fd) in + Task.await task >>| function + | Error _ as e -> e + | Ok None -> Ok () + | Ok (Some err) -> Error (`Exn (Unix.Unix_error (err, "connect", "")))) + | Error (`Unix (e, x, y)) -> + Fiber.return @@ Error (`Exn (Unix.Unix_error (e, x, y))) + | Error (`Exn _) as e -> Fiber.return e + | Error `Cancelled as e -> Fiber.return e diff --git a/src/dune_async_io/async_io.mli b/src/dune_async_io/async_io.mli new file mode 100644 index 00000000000..34954652d5f --- /dev/null +++ b/src/dune_async_io/async_io.mli @@ -0,0 +1,65 @@ +(** Poor man's asynchronous IO on sockets (and pipes on Unix) + + Problematic in three ways: + + - Needs to run in a separate thread because our scheduler loop does not + allow polling for fd's and custom events. This requires unnecessary + locking. + + - Uses the rather slow select primitive. There's much better options on + every operating system. + + - Relies on the "pipe trick" to be interruptible. This is the best we can do + with select. *) + +(* TODO one day switch to lev and integrate all of this directly into the + scheduler. This should solve all the problems above. *) + +module type Scheduler = sig + val fill_jobs : Fiber.fill list -> unit + + val register_job_started : unit -> unit + + val cancel_job_started : unit -> unit + + val spawn_thread : (unit -> unit) -> unit +end + +(** [with_io scheduler f] runs [f] with [scheduler]. All operations in this + module must be executed inside [f]. *) +val with_io : (module Scheduler) -> (unit -> 'a Fiber.t) -> 'a Fiber.t + +(** [close fd] must be used to close any file descriptor which has been watched + at some point. This is needed to make sure we never close a file descriptor + that is being selected. Any associated operations with [fd] will be + cancelled. *) +val close : Unix.file_descr -> unit Fiber.t + +module Task : sig + (** A cancellable task *) + type 'a t + + (** Cancel a running task *) + val cancel : _ t -> unit Fiber.t + + (** Wait for a task to complete *) + val await : 'a t -> ('a, [ `Cancelled | `Exn of exn ]) result Fiber.t +end + +(** [ready fd what ~f] wait until [what] can be done on [fd] in a non-blocking + way and then call [f]. Note that [f] will be called in a different thread, + so it should only be used for atomic or synchronized operations. *) +val ready : + Unix.file_descr -> [ `Read | `Write ] -> f:(unit -> 'a) -> 'a Task.t Fiber.t + +(** [connect fd sock] will do the equivalent of [Unix.connect fd sock] but + without blocking. As in the other functions, you must call + [Unix.set_nonblock fd] before calling this function. + + It's possible to implement this function using the other functions in this + module. But since it's a bit non trivial, the implementation is done here. *) +val connect : + (Unix.file_descr -> Unix.sockaddr -> unit) + -> Unix.file_descr + -> Unix.sockaddr + -> (unit, [ `Cancelled | `Exn of exn ]) result Fiber.t diff --git a/src/dune_async_io/dune b/src/dune_async_io/dune new file mode 100644 index 00000000000..7b968d793a0 --- /dev/null +++ b/src/dune_async_io/dune @@ -0,0 +1,3 @@ +(library + (name dune_async_io) + (libraries stdune threads.posix unix fiber)) diff --git a/src/dune_async_io/dune_async_io.ml b/src/dune_async_io/dune_async_io.ml new file mode 100644 index 00000000000..657ea75a601 --- /dev/null +++ b/src/dune_async_io/dune_async_io.ml @@ -0,0 +1 @@ +module Async_io = Async_io diff --git a/src/dune_engine/dune b/src/dune_engine/dune index 857f66a8146..d3893a77a1f 100644 --- a/src/dune_engine/dune +++ b/src/dune_engine/dune @@ -11,6 +11,7 @@ dyn fiber memo + dune_async_io threads.posix predicate_lang dune_sexp diff --git a/src/dune_engine/scheduler.ml b/src/dune_engine/scheduler.ml index 7e4e4ffb78f..487c1a7fb34 100644 --- a/src/dune_engine/scheduler.ml +++ b/src/dune_engine/scheduler.ml @@ -1,6 +1,7 @@ open Import open Fiber.O open Dune_thread_pool +open Dune_async_io module Config = struct type t = @@ -137,8 +138,12 @@ module Event : sig val send_worker_task_completed : t -> Fiber.fill -> unit + val send_worker_tasks_completed : t -> Fiber.fill list -> unit + val register_worker_task_started : t -> unit + val cancel_work_task_started : t -> unit + val send_file_watcher_task : t -> (unit -> Dune_file_watcher.Event.t list) -> unit @@ -226,6 +231,9 @@ end = struct let register_worker_task_started q = q.pending_worker_tasks <- q.pending_worker_tasks + 1 + let cancel_work_task_started q = + q.pending_worker_tasks <- q.pending_worker_tasks - 1 + let add_event q f = Mutex.lock q.mutex; f q; @@ -384,6 +392,10 @@ end = struct let send_worker_task_completed q event = add_event q (fun q -> Queue.push q.worker_tasks_completed event) + let send_worker_tasks_completed q events = + add_event q (fun q -> + List.iter events ~f:(Queue.push q.worker_tasks_completed)) + let send_invalidation_events q events = add_event q (fun q -> q.invalidation_events <- q.invalidation_events @ events) @@ -1035,6 +1047,20 @@ end = struct let run t f : _ result = let fiber = set t (fun () -> + let module Scheduler = struct + let spawn_thread = spawn_thread + + let register_job_started () = + Event.Queue.register_worker_task_started t.events + + let fill_jobs jobs = + Event.Queue.send_worker_tasks_completed t.events jobs + + let cancel_job_started () = + Event.Queue.cancel_work_task_started t.events + end in + Async_io.with_io (module (Scheduler : Async_io.Scheduler)) + @@ fun () -> Fiber.map_reduce_errors (module Monoid.Unit) f @@ -1070,12 +1096,6 @@ let async f = Event.Queue.register_worker_task_started t.events; Fiber.Ivar.read ivar -let () = - Fdecl.set Csexp_rpc.scheduler - (module struct - let async f = async f - end) - let async_exn f = async f >>| function | Error exn -> Exn_with_backtrace.reraise exn diff --git a/src/dune_rpc_impl/server.ml b/src/dune_rpc_impl/server.ml index f8caa33b76a..53acb9c7630 100644 --- a/src/dune_rpc_impl/server.ml +++ b/src/dune_rpc_impl/server.ml @@ -205,9 +205,9 @@ let stop (t : _ t) = Fiber.fork_and_join_unit (fun () -> Action_runner.Rpc_server.stop t.config.action_runner) (fun () -> - let+ server = Fiber.Ivar.peek t.config.server_ivar in + let* server = Fiber.Ivar.peek t.config.server_ivar in match server with - | None -> () + | None -> Fiber.return () | Some server -> Csexp_rpc.Server.stop server) let handler (t : _ t Fdecl.t) action_runner_server handle : @@ -317,8 +317,7 @@ let handler (t : _ t Fdecl.t) action_runner_server handle : in let shutdown () = Fiber.fork_and_join_unit Scheduler.shutdown (fun () -> - Csexp_rpc.Server.stop (Lazy.force t.config.server); - Fiber.return ()) + Csexp_rpc.Server.stop (Lazy.force t.config.server)) in Fiber.fork_and_join_unit terminate_sessions shutdown in diff --git a/test/expect-tests/csexp_rpc/csexp_rpc_tests.ml b/test/expect-tests/csexp_rpc/csexp_rpc_tests.ml index 1126a460ea1..4bbe24fceee 100644 --- a/test/expect-tests/csexp_rpc/csexp_rpc_tests.ml +++ b/test/expect-tests/csexp_rpc/csexp_rpc_tests.ml @@ -60,7 +60,7 @@ let%expect_test "csexp server life cycle" = (match response with | None -> log "no response" | Some sexp -> log "received %s" (Csexp.to_string sexp)); - let+ () = Session.write client None in + let* () = Session.write client None in log "closed"; Server.stop server) (fun () -> diff --git a/test/expect-tests/dune_action_runner/dune_action_runner.ml b/test/expect-tests/dune_action_runner/dune_action_runner.ml index 1a84a86184a..8af6ddecffe 100755 --- a/test/expect-tests/dune_action_runner/dune_action_runner.ml +++ b/test/expect-tests/dune_action_runner/dune_action_runner.ml @@ -25,7 +25,7 @@ let run () = Dune_rpc_server.Handler.create ~version:(3, 7) ~on_init:(fun _ _ -> print_endline "server: client connected"; - Csexp_rpc.Server.stop csexp_server |> Fiber.return) + Csexp_rpc.Server.stop csexp_server) () in Action_runner.Rpc_server.implement_handler action_runner_server handler; diff --git a/test/expect-tests/dune_async_io/async_io_tests.ml b/test/expect-tests/dune_async_io/async_io_tests.ml new file mode 100644 index 00000000000..bbb94fb5816 --- /dev/null +++ b/test/expect-tests/dune_async_io/async_io_tests.ml @@ -0,0 +1,59 @@ +open Stdune +open Fiber.O +module Scheduler = Dune_engine.Scheduler +open Dune_async_io + +let config = + { Scheduler.Config.concurrency = 1 + ; stats = None + ; insignificant_changes = `Ignore + ; signal_watcher = `No + ; watch_exclusions = [] + } + +let%expect_test "read readiness" = + ( Scheduler.Run.go config ~on_event:(fun _ _ -> ()) @@ fun () -> + let r, w = Unix.pipe ~cloexec:true () in + if not Sys.win32 then Unix.set_nonblock r; + let* task = Async_io.ready r `Read ~f:ignore in + assert (Unix.write w (Bytes.of_string "0") 0 1 = 1); + Async_io.Task.await task >>= function + | Error _ -> assert false + | Ok () -> + let bytes = Bytes.of_string "1" in + assert (Unix.read r bytes 0 1 = 1); + assert (Bytes.to_string bytes = "0"); + Unix.close w; + let+ () = Async_io.close r in + print_endline "succesful read" ); + [%expect {| succesful read |}] + +let%expect_test "write readiness" = + ( Scheduler.Run.go config ~on_event:(fun _ _ -> ()) @@ fun () -> + let r, w = Unix.pipe ~cloexec:true () in + if not Sys.win32 then Unix.set_nonblock w; + let* task = Async_io.ready w `Write ~f:ignore in + Async_io.Task.await task >>= function + | Error _ -> assert false + | Ok () -> + assert (Unix.write w (Bytes.of_string "0") 0 1 = 1); + Unix.close r; + let+ () = Async_io.close w in + print_endline "succesful write" ); + [%expect {| succesful write |}] + +let%expect_test "cancel task" = + ( Scheduler.Run.go config ~on_event:(fun _ _ -> ()) @@ fun () -> + let r, w = Unix.pipe ~cloexec:true () in + if not Sys.win32 then Unix.set_nonblock r; + let* task = Async_io.ready r `Read ~f:ignore in + Fiber.fork_and_join_unit + (fun () -> + Async_io.Task.await task >>= function + | Ok () | Error (`Exn _) -> assert false + | Error `Cancelled -> + Unix.close w; + let+ () = Async_io.close r in + print_endline "successfully cancelled") + (fun () -> Async_io.Task.cancel task) ); + [%expect {| successfully cancelled |}] diff --git a/test/expect-tests/dune_async_io/async_io_tests.mli b/test/expect-tests/dune_async_io/async_io_tests.mli new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/expect-tests/dune_async_io/dune b/test/expect-tests/dune_async_io/dune new file mode 100644 index 00000000000..b51f5aad039 --- /dev/null +++ b/test/expect-tests/dune_async_io/dune @@ -0,0 +1,20 @@ +(library + (name dune_async_io_tests) + (inline_tests) + (preprocess + (pps ppx_expect)) + (libraries + stdune + dune_engine + unix + threads.posix + fiber + dune_async_io + dune_tests_common + ;; This is because of the (implicit_transitive_deps false) + ;; in dune-project + ppx_expect.config + ppx_expect.config_types + ppx_expect.common + base + ppx_inline_test.config))