Skip to content

Commit

Permalink
eio(client): 'run' to tak stack and crypto params
Browse files Browse the repository at this point in the history
Make `run` accept crypto rng generator parameters.
The 'f' function is injected with stack parameter.

Address misc. reviewer comments.
  • Loading branch information
bikallem committed Oct 26, 2022
1 parent 0b24a89 commit 6b59d56
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 50 deletions.
11 changes: 7 additions & 4 deletions dns-client-eio.opam
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
opam-version: "2.0"
maintainer: "team AT robur dot io"
authors: ["Bikal Gurung <gbikal@gmail.com>"]
maintainer: "Bikal Gurung <gbikal AT gmail DOT com>"
authors: ["Bikal Gurung <gbikal AT gmail DOT com>"]
homepage: "https://github.com/mirage/ocaml-dns"
bug-reports: "https://github.com/mirage/ocaml-dns/issues"
dev-repo: "git+https://github.com/mirage/ocaml-dns.git"
Expand All @@ -20,8 +20,11 @@ depends: [
"dns-client" {>= version}
"mirage-clock" {>= "3.0.0"}
"mtime" {>= "1.2.0"}
"mirage-crypto-rng" {>= version}
"mirage-crypto-rng-eio" {>= version}
"mirage-crypto-rng-eio" {>= "0.10.7"}
"domain-name" {>= "0.4.0"}
"mtime" {>= "1.2.0"}
"fmt" {>= "0.8.8"}
"eio_main" {>= "0.5"}
]
synopsis: "DNS client for eio"
description: """
Expand Down
90 changes: 53 additions & 37 deletions eio/client/dns_client_eio.ml
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
module E = Eio

type env = <
clock : E.Time.clock ;
net : E.Net.t;
fs : E.Dir.t;
secure_random : E.Flow.source;
>
type 'a env = <
clock : Eio.Time.clock ;
net : Eio.Net.t ;
fs : Eio.Fs.dir Eio.Path.t ;
secure_random : Eio.Flow.source;
..
> as 'a

type io_addr = Ipaddr.t * int
type stack = env * E.Switch.t
type stack = {
sw : Eio.Switch.t;
clock : Eio.Time.clock;
net : Eio.Net.t;
resolv_conf : Eio.Fs.dir Eio.Path.t
}

module Transport : Dns_client.S
with type io_addr = io_addr
Expand All @@ -18,7 +22,7 @@ module Transport : Dns_client.S
type nonrec io_addr = io_addr
type nonrec stack = stack
type +'a io = 'a
type context = E.Net.stream_socket
type context = <Eio.Net.stream_socket; Eio.Flow.close>

type nameservers =
| Static of io_addr Queue.t
Expand All @@ -30,17 +34,15 @@ module Transport : Dns_client.S
type t = {
nameservers : nameservers ;
timeout_ns : int64 ; (* Timeout in nano seconds *)
env : env;
sw : E.Switch.t ;
mutex : E.Mutex.t ;
stack : stack;
mutex : Eio.Mutex.t ;
}

let read_file env file =
match E.Dir.load env#fs file with
let read_file file =
match Eio.Path.load file with
| content -> Ok content
| exception e ->
let err = "Error while reading file: " ^ file ^ ". " ^ (Printexc.to_string e) in
Error (`Msg err)
| exception e ->
Fmt.error_msg "Error while reading file %a: %a" Eio.Path.pp file Fmt.exn e

(* Prioritises IPv6 nameservers before IPv4 nameservers so that we
are more conformant with the happy eyballs RFC when implementing it.
Expand Down Expand Up @@ -71,7 +73,7 @@ module Transport : Dns_client.S
|> List.map (fun ip -> ip, 53)
|> ipv6_first_queue

let create ?nameservers ~timeout (env, sw) =
let create ?nameservers ~timeout stack =
let nameservers =
match nameservers with
| Some (proto, ns) -> begin
Expand All @@ -87,7 +89,7 @@ module Transport : Dns_client.S
let nameservers, digest =
match
let ( let* ) = Result.bind in
let* data = read_file env "/etc/resolv.conf" in
let* data = read_file stack.resolv_conf in
let* ips = decode_resolv_conf data in
Ok (ips, Digest.string data)
with
Expand All @@ -96,8 +98,8 @@ module Transport : Dns_client.S
in
(Resolv_conf { nameservers; digest })
in
let mutex = E.Mutex.create () in
{ nameservers; timeout_ns = timeout; env; sw; mutex }
let mutex = Eio.Mutex.create () in
{ nameservers; timeout_ns = timeout; stack; mutex }

let nameservers0
{ nameservers =
Expand Down Expand Up @@ -130,7 +132,7 @@ module Transport : Dns_client.S
resolv_conf.digest <- None;
resolv_conf.nameservers <- default_resolvers ()
in
match read_file t.env "/etc/resolv.conf", resolv_conf.digest with
match read_file t.stack.resolv_conf, resolv_conf.digest with
| Ok data, Some d ->
let digest = Digest.string data in
if Digest.equal digest d then () else decode_update data digest
Expand All @@ -154,17 +156,17 @@ module Transport : Dns_client.S
if n >= Queue.length ns_q then
Error (`Msg "Unable to connect to specified nameservers")
else
let (ip, port) = E.Mutex.use_ro t.mutex @@ fun () -> Queue.peek ns_q in
let ip = ipaddr_octects ip |> E.Net.Ipaddr.of_raw in
let (ip, port) = Eio.Mutex.use_ro t.mutex @@ fun () -> Queue.peek ns_q in
let ip = ipaddr_octects ip |> Eio.Net.Ipaddr.of_raw in
let stream = `Tcp (ip, port) in
try
let timeout = Duration.to_f t.timeout_ns in
E.Time.with_timeout_exn t.env#clock timeout @@ fun () ->
let flow = E.Net.connect ~sw:t.sw t.env#net stream in
Eio.Time.with_timeout_exn t.stack.clock timeout @@ fun () ->
let flow = Eio.Net.connect ~sw:t.stack.sw t.stack.net stream in
Ok flow
with E.Time.Timeout ->
with Eio.Time.Timeout ->
(* Push the non responsive nameserver to the back of the queue. *)
let ns = E.Mutex.use_rw ~protect:true t.mutex @@ fun () -> Queue.pop ns_q in
let ns = Eio.Mutex.use_rw ~protect:true t.mutex @@ fun () -> Queue.pop ns_q in
Queue.push ns ns_q;
try_ns_connection t (n + 1) ns_q

Expand All @@ -176,23 +178,37 @@ module Transport : Dns_client.S
let send_recv ctx dns_query =
if Cstruct.length dns_query > 4 then
try
let src = E.Flow.cstruct_source [dns_query] in
E.Flow.copy src ctx;
let src = Eio.Flow.cstruct_source [dns_query] in
Eio.Flow.copy src ctx;
let dns_response = Cstruct.create 2048 in
let got = E.Flow.read ctx dns_response in
let got = Eio.Flow.read ctx dns_response in
Ok (Cstruct.sub dns_response 0 got)
with e -> Error (`Msg (Printexc.to_string e))
else
Error (`Msg "Invalid DNS query packet (data length <= 4)")

let close flow = try E.Flow.close flow with _ -> ()
let close flow = Eio.Flow.close flow
let bind a f = f a
let lift v = v
end

module Client = Dns_client.Make(Transport)
module type DNS_CLIENT = module type of Dns_client.Make(Transport)

let run env (f:(module DNS_CLIENT) -> 'a) =
Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () ->
f (module Client)
module Client = Dns_client.Make(Transport)

let run (type a)
?(resolv_conf = "/etc/resolv.conf")
?g
(crypto_generator: a Mirage_crypto_rng.generator)
env (f: Transport.stack -> (module DNS_CLIENT) -> 'b)
=
let module M = (val crypto_generator) in
Mirage_crypto_rng_eio.run ?g (module M) env @@ fun () ->
Eio.Switch.run @@ fun sw ->
let stack = {
sw;
clock = env#clock;
net = env#net;
resolv_conf = Eio.Path.(env#fs / resolv_conf) }
in
f stack (module Client)
19 changes: 14 additions & 5 deletions eio/client/dns_client_eio.mli
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
type env = <
type 'a env = <
clock : Eio.Time.clock ;
net : Eio.Net.t ;
fs : Eio.Dir.t ;
fs : Eio.Fs.dir Eio.Path.t ;
secure_random : Eio.Flow.source ;
>
..
> as 'a

module Transport : Dns_client.S
with type io_addr = Ipaddr.t * int
and type stack = env * Eio.Switch.t
and type +'a io = 'a

module type DNS_CLIENT = module type of Dns_client.Make(Transport)

val run : < env; ..> -> ((module DNS_CLIENT) -> 'a) -> 'a
val run :
?resolv_conf:string
-> ?g: 'a
-> 'a Mirage_crypto_rng.generator
-> _ env
-> (Transport.stack -> (module DNS_CLIENT) -> 'b)
-> 'b
(** [run crypto_rng_generator env f] starts the [crypto_rng_generator] required by [ocaml-dns].
It then creates [stack] and [Client] module and calls [f stack (module Client)]. module
[Client] can be used to execute Dns client functions. *)
6 changes: 2 additions & 4 deletions eio/client/ohost.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ let (let+) r f = Result.map f r

let display_host_ips h_name =
Eio_main.run @@ fun env ->
Eio.Switch.run @@ fun sw ->
Dns_client_eio.run env @@ fun (module Client) ->
let env = (env :> Dns_client_eio.env) in
let c = Client.create (env, sw) in
Dns_client_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun stack (module Client) ->
let c = Client.create stack in
let domain = Domain_name.(host_exn (of_string_exn h_name)) in
let ipv4 =
let+ addr = Client.gethostbyname c domain in
Expand Down

0 comments on commit 6b59d56

Please sign in to comment.