diff --git a/src/repr/binary_codec.ml b/src/repr/binary_codec.ml
new file mode 100644
index 00000000..f64cb690
--- /dev/null
+++ b/src/repr/binary_codec.ml
@@ -0,0 +1,274 @@
+include Binary_codec_intf
+open Staging
+
+let unsafe_add_bytes b k = k (Bytes.unsafe_to_string b)
+let str = Bytes.unsafe_of_string
+
+let charstring_of_code : int -> string =
+  let tbl =
+    Array.init 256 (fun i -> Bytes.unsafe_to_string (Bytes.make 1 (Char.chr i)))
+  in
+  fun [@inline always] i ->
+    (* [unsafe_get] skips bounds checks, so reject negative codes as well *)
+    assert (i >= 0 && i < 256);
+    Array.unsafe_get tbl i
+
+module Unit = struct
+  let encode () _k = ()
+  let decode _ ofs = (ofs, ()) [@@inline always]
+end
+
+module Char = struct
+  let encode c k = k (charstring_of_code (Char.code c))
+  let decode buf ofs = (ofs + 1, buf.[ofs]) [@@inline always]
+end
+
+module Bool = struct
+  let encode b = Char.encode (if b then '\255' else '\000')
+
+  let decode buf ofs =
+    let ofs, c = Char.decode buf ofs in
+    match c with '\000' -> (ofs, false) | _ -> (ofs, true)
+end
+
+module Int8 = struct
+  let encode i k = k (charstring_of_code i)
+
+  let decode buf ofs =
+    let ofs, c = Char.decode buf ofs in
+    (ofs, Stdlib.Char.code c)
+    [@@inline always]
+end
+
+module Int16 = struct
+  let encode i =
+    let b = Bytes.create 2 in
+    Bytes.set_uint16_be b 0 i;
+    unsafe_add_bytes b
+
+  let decode buf ofs = (ofs + 2, Bytes.get_uint16_be (str buf) ofs)
+end
+
+module Int32 = struct
+  let encode i =
+    let b = Bytes.create 4 in
+    Bytes.set_int32_be b 0 i;
+    unsafe_add_bytes b
+
+  let decode buf ofs = (ofs + 4, Bytes.get_int32_be (str buf) ofs)
+end
+
+module Int64 = struct
+  let encode i =
+    let b = Bytes.create 8 in
+    Bytes.set_int64_be b 0 i;
+    unsafe_add_bytes b
+
+  let decode buf ofs = (ofs + 8, Bytes.get_int64_be (str buf) ofs)
+end
+
+module Float = struct
+  let encode f = Int64.encode (Stdlib.Int64.bits_of_float f)
+
+  let decode buf ofs =
+    let ofs, f = Int64.decode buf ofs in
+    (ofs, Stdlib.Int64.float_of_bits f)
+end
+
+module Int = struct
+  let encode i k =
+    let rec aux n k =
+      if n >= 0 && n < 128 then k (charstring_of_code n)
+      else
+        let out = 128 lor (n land 127) in
+        k (charstring_of_code out);
+        aux (n lsr 7) k
+    in
+    aux i k
+
+  let decode buf ofs =
+    let rec aux buf n p ofs =
+      let ofs, i = Int8.decode buf ofs in
+      let n = n + ((i land 127) lsl p) in
+      if i >= 0 && i < 128 then (ofs, n) else aux buf n (p + 7) ofs
+    in
+    aux buf 0 0 ofs
+end
+
+module Len = struct
+  let encode n i =
+    match n with
+    | `Int -> Int.encode i
+    | `Int8 -> Int8.encode i
+    | `Int16 -> Int16.encode i
+    | `Int32 -> Int32.encode (Stdlib.Int32.of_int i)
+    | `Int64 -> Int64.encode (Stdlib.Int64.of_int i)
+    | `Fixed _ -> Unit.encode ()
+    | `Unboxed -> Unit.encode ()
+
+  let decode n buf ofs =
+    match n with
+    | `Int -> Int.decode buf ofs
+    | `Int8 -> Int8.decode buf ofs
+    | `Int16 -> Int16.decode buf ofs
+    | `Int32 ->
+        let ofs, i = Int32.decode buf ofs in
+        (ofs, Stdlib.Int32.to_int i)
+    | `Int64 ->
+        let ofs, i = Int64.decode buf ofs in
+        (ofs, Stdlib.Int64.to_int i)
+    | `Fixed n -> (ofs, n)
+    | `Unboxed -> (ofs, String.length buf - ofs)
+end
+
+(* Helper functions generalising over [string] / [bytes]. *)
+module Mono_container = struct
+  let decode_unboxed of_string of_bytes =
+    stage @@ fun buf ofs ->
+    let len = String.length buf - ofs in
+    if ofs = 0 then (len, of_string buf)
+    else
+      let str = Bytes.create len in
+      String.blit buf ofs str 0 len;
+      (ofs + len, of_bytes str)
+
+  let decode of_string of_bytes =
+    let sub len buf ofs =
+      if ofs = 0 && len = String.length buf then (len, of_string buf)
+      else
+        let str = Bytes.create len in
+        String.blit buf ofs str 0 len;
+        (ofs + len, of_bytes str)
+    in
+    function
+    | `Fixed n ->
+        (* fixed-size strings are never boxed *)
+        stage @@ fun buf ofs -> sub n buf ofs
+    | n ->
+        stage @@ fun buf ofs ->
+        let ofs, len = Len.decode n buf ofs in
+        sub len buf ofs
+end
+
+module String_unboxed = struct
+  let encode _ = stage (fun s k -> k s)
+
+  let decode _ =
+    Mono_container.decode_unboxed (fun x -> x) Bytes.unsafe_to_string
+end
+
+module Bytes_unboxed = struct
+  (* NOTE: makes a copy of [b], since [k] may retain the string it's given *)
+  let encode _ = stage (fun b k -> k (Bytes.to_string b))
+
+  (* NOTE: [Bytes.of_string] copies, so the result never aliases [buf] *)
+  let decode _ =
+    Mono_container.decode_unboxed Bytes.of_string (fun x -> x)
+end
+
+module String = struct
+  let encode len =
+    stage (fun s k ->
+        let i = String.length s in
+        Len.encode len i k;
+        k s)
+
+  let decode len = Mono_container.decode (fun x -> x) Bytes.unsafe_to_string len
+end
+
+module Bytes = struct
+  let encode len =
+    stage (fun s k ->
+        let i = Bytes.length s in
+        Len.encode len i k;
+        unsafe_add_bytes s k)
+
+  (* NOTE: [Bytes.of_string] copies, so the result never aliases [buf] *)
+  let decode len = Mono_container.decode Bytes.of_string (fun x -> x) len
+end
+
+module Option = struct
+  let encode encode_elt v k =
+    match v with
+    | None -> Char.encode '\000' k
+    | Some x ->
+        Char.encode '\255' k;
+        encode_elt x k
+
+  let decode decode_elt buf ofs =
+    let ofs, c = Char.decode buf ofs in
+    match c with
+    | '\000' -> (ofs, None)
+    | _ ->
+        let ofs, x = decode_elt buf ofs in
+        (ofs, Some x)
+end
+
+module List = struct
+  let encode =
+    let rec encode_elements encode_elt k = function
+      | [] -> ()
+      | x :: xs ->
+          encode_elt x k;
+          (encode_elements [@tailcall]) encode_elt k xs
+    in
+    fun len encode_elt ->
+      stage (fun x k ->
+          Len.encode len (List.length x) k;
+          encode_elements encode_elt k x)
+
+  let decode =
+    let rec decode_elements decode_elt acc buf off = function
+      | 0 -> (off, List.rev acc)
+      | n ->
+          let off, x = decode_elt buf off in
+          decode_elements decode_elt (x :: acc) buf off (n - 1)
+    in
+    fun len decode_elt ->
+      stage (fun buf ofs ->
+          let ofs, len = Len.decode len buf ofs in
+          decode_elements decode_elt [] buf ofs len)
+end
+
+module Array = struct
+  let encode =
+    let encode_elements encode_elt k arr =
+      for i = 0 to Array.length arr - 1 do
+        encode_elt (Array.unsafe_get arr i) k
+      done
+    in
+    fun n l ->
+      stage (fun x k ->
+          Len.encode n (Array.length x) k;
+          encode_elements l k x)
+
+  let decode len decode_elt =
+    let list_decode = unstage (List.decode len decode_elt) in
+    stage (fun buf off ->
+        let ofs, l = list_decode buf off in
+        (ofs, Array.of_list l))
+end
+
+module Pair = struct
+  let encode a b (x, y) k =
+    a x k;
+    b y k
+
+  let decode a b buf off =
+    let off, a = a buf off in
+    let off, b = b buf off in
+    (off, (a, b))
+end
+
+module Triple = struct
+  let encode a b c (x, y, z) k =
+    a x k;
+    b y k;
+    c z k
+
+  let decode a b c buf off =
+    let off, a = a buf off in
+    let off, b = b buf off in
+    let off, c = c buf off in
+    (off, (a, b, c))
+end
diff --git a/src/repr/binary_codec.mli b/src/repr/binary_codec.mli
new file mode 100644
index 00000000..6d02b966
--- /dev/null
+++ b/src/repr/binary_codec.mli
@@ -0,0 +1,2 @@
+include Binary_codec_intf.Intf
+(** @inline *)
diff --git a/src/repr/binary_codec_intf.ml b/src/repr/binary_codec_intf.ml
new file mode 100644
index 00000000..fc1de134
--- /dev/null
+++ b/src/repr/binary_codec_intf.ml
@@ -0,0 +1,72 @@
+open Type_core
+open Staging
+
+type 'a encoder = 'a -> (string -> unit) -> unit
+type 'a decoder = string -> int -> int * 'a
+
+module type S = sig
+  type t
+
+  val encode : t encoder
+  val decode : t decoder
+end
+
+module type S_with_length = sig
+  type t
+
+  val encode : len -> t encoder staged
+  val decode : len -> t decoder staged
+end
+
+module type S1 = sig
+  type 'a t
+
+  val encode : 'a encoder -> 'a t encoder
+  val decode : 'a decoder -> 'a t decoder
+end
+
+module type S1_with_length = sig
+  type 'a t
+
+  val encode : len -> 'a encoder -> 'a t encoder staged
+  val decode : len -> 'a decoder -> 'a t decoder staged
+end
+
+module type S2 = sig
+  type ('a, 'b) t
+
+  val encode : 'a encoder -> 'b encoder -> ('a, 'b) t encoder
+  val decode : 'a decoder -> 'b decoder -> ('a, 'b) t decoder
+end
+
+module type S3 = sig
+  type ('a, 'b, 'c) t
+
+  val encode : 'a encoder -> 'b encoder -> 'c encoder -> ('a, 'b, 'c) t encoder
+  val decode : 'a decoder -> 'b decoder -> 'c decoder -> ('a, 'b, 'c) t decoder
+end
+
+module type Intf = sig
+  module type S = S
+  module type S1 = S1
+  module type S2 = S2
+  module type S3 = S3
+
+  module Unit : S with type t := unit
+  module Bool : S with type t := bool
+  module Char : S with type t := char
+  module Int : S with type t := int
+  module Int16 : S with type t := int
+  module Int32 : S with type t := int32
+  module Int64 : S with type t := int64
+  module Float : S with type t := float
+  module String : S_with_length with type t := string
+  module String_unboxed : S_with_length with type t := string
+  module Bytes : S_with_length with type t := bytes
+  module Bytes_unboxed : S_with_length with type t := bytes
+  module List : S1_with_length with type 'a t := 'a list
+  module Array : S1_with_length with type 'a t := 'a array
+  module Option : S1 with type 'a t := 'a option
+  module Pair : S2 with type ('a, 'b) t := 'a * 'b
+  module Triple : S3 with type ('a, 'b, 'c) t := 'a * 'b * 'c
+end
diff --git a/src/repr/type_binary.ml b/src/repr/type_binary.ml
index 97224c90..0ea3af17 100644
--- a/src/repr/type_binary.ml
+++ b/src/repr/type_binary.ml
@@ -19,123 +19,33 @@ open Staging
 open Utils
 
 module Encode = struct
-  let chars =
-    Array.init 256 (fun i -> Bytes.unsafe_to_string (Bytes.make 1 (Char.chr i)))
-
-  let unit () _k = ()
-  let unsafe_add_bytes b k = k (Bytes.unsafe_to_string b)
-  let add_string s k = k s
-  let char c k = k chars.(Char.code c)
-
-  let int8 i k =
-    assert (i < 256);
-    k chars.(i)
-
-  let int16 i =
-    let b = Bytes.create 2 in
-    Bytes.set_uint16_be b 0 i;
-    unsafe_add_bytes b
-
-  let int32 i =
-    let b = Bytes.create 4 in
-    Bytes.set_int32_be b 0 i;
-    unsafe_add_bytes b
-
-  let int64 i =
-    let b = Bytes.create 8 in
-    Bytes.set_int64_be b 0 i;
-    unsafe_add_bytes b
-
-  let float f = int64 (Int64.bits_of_float f)
-  let bool b = char (if b then '\255' else '\000')
-
-  let int i k =
-    let rec aux n k =
-      if n >= 0 && n < 128 then k chars.(n)
-      else
-        let out = 128 lor (n land 127) in
-        k chars.(out);
-        aux (n lsr 7) k
-    in
-    aux i k
-
-  let len n i =
-    match n with
-    | `Int -> int i
-    | `Int8 -> int8 i
-    | `Int16 -> int16 i
-    | `Int32 -> int32 (Int32.of_int i)
-    | `Int64 -> int64 (Int64.of_int i)
-    | `Fixed _ -> unit ()
-    | `Unboxed -> unit ()
-
-  let unboxed_string _ = stage add_string
-
-  let boxed_string n =
-    let len = len n in
-    stage @@ fun s k ->
-    let i = String.length s in
-    len i k;
-    add_string s k
-
-  let string boxed = if boxed then boxed_string else unboxed_string
-  let unboxed_bytes _ = stage @@ fun b k -> add_string (Bytes.to_string b) k
-
-  let boxed_bytes n =
-    let len = len n in
-    stage @@ fun s k ->
-    let i = Bytes.length s in
-    len i k;
-    unsafe_add_bytes s k
-
-  let bytes boxed = if boxed then boxed_bytes else unboxed_bytes
-
-  let list =
-    let rec encode_elements encode_elt k = function
-      | [] -> ()
-      | x :: xs ->
-          encode_elt x k;
-          (encode_elements [@tailcall]) encode_elt k xs
-    in
-    fun l n ->
-      let l = unstage l in
-      stage (fun x k ->
-          len n (List.length x) k;
-          encode_elements l k x)
-
-  let array =
-    let encode_elements encode_elt k arr =
-      for i = 0 to Array.length arr - 1 do
-        encode_elt (Array.unsafe_get arr i) k
-      done
-    in
-    fun l n ->
-      let l = unstage l in
-      stage (fun x k ->
-          len n (Array.length x) k;
-          encode_elements l k x)
+  module Bin = Binary_codec
+
+  let string boxed n =
+    if boxed then Bin.String.encode n else Bin.String_unboxed.encode n
+
+  let bytes boxed n =
+    if boxed then Bin.Bytes.encode n else Bin.Bytes_unboxed.encode n
+
+  let list l n =
+    let l = unstage l in
+    Bin.List.encode n l
+
+  let array l n =
+    let l = unstage l in
+    Bin.Array.encode n l
 
   let pair a b =
     let a = unstage a and b = unstage b in
-    stage (fun (x, y) k ->
-        a x k;
-        b y k)
+    stage (Bin.Pair.encode a b)
 
   let triple a b c =
     let a = unstage a and b = unstage b and c = unstage c in
-    stage (fun (x, y, z) k ->
-        a x k;
-        b y k;
-        c z k)
+    stage (Bin.Triple.encode a b c)
 
   let option o =
     let o = unstage o in
-    stage (fun v k ->
-        match v with
-        | None -> char '\000' k
-        | Some x ->
-            char '\255' k;
-            o x k)
+    stage (Bin.Option.encode o)
 
   let rec t : type a. a t -> a encode_bin = function
     | Self s -> fst (self s)
@@ -184,159 +94,63 @@ module Encode = struct
 
   and prim : type a. boxed:bool -> a prim -> a encode_bin =
    fun ~boxed -> function
-    | Unit -> stage unit
-    | Bool -> stage bool
-    | Char -> stage char
-    | Int -> stage int
-    | Int32 -> stage int32
-    | Int64 -> stage int64
-    | Float -> stage float
+    | Unit -> stage Bin.Unit.encode
+    | Bool -> stage Bin.Bool.encode
+    | Char -> stage Bin.Char.encode
+    | Int -> stage Bin.Int.encode
+    | Int32 -> stage Bin.Int32.encode
+    | Int64 -> stage Bin.Int64.encode
+    | Float -> stage Bin.Float.encode
     | String n -> string boxed n
     | Bytes n -> bytes boxed n
 
   and record : type a. a record -> a encode_bin =
    fun r ->
     let field_encoders : (a -> (string -> unit) -> unit) list =
-      fields r
-      |> List.map @@ fun (Field f) ->
-         let field_encode = unstage (t f.ftype) in
-         fun x -> field_encode (f.fget x)
+      ListLabels.map (fields r) ~f:(fun (Field f) ->
+          let field_encode = unstage (t f.ftype) in
+          fun x -> field_encode (f.fget x))
     in
-    stage (fun x k -> List.iter (fun f -> f x k) field_encoders)
+    stage (fun x k -> Stdlib.List.iter (fun f -> f x k) field_encoders)
 
   and variant : type a. a variant -> a encode_bin =
-    let c0 { ctag0; _ } = stage (int ctag0) in
+    let c0 { ctag0; _ } = stage (Bin.Int.encode ctag0) in
     let c1 c =
       let encode_arg = unstage (t c.ctype1) in
       stage (fun v k ->
-          int c.ctag1 k;
+          Bin.Int.encode c.ctag1 k;
           encode_arg v k)
     in
     fun v -> fold_variant { c0; c1 } v
 end
 
 module Decode = struct
-  type 'a res = int * 'a
-
-  let unit _ ofs = (ofs, ()) [@@inline always]
-  let char buf ofs = (ofs + 1, buf.[ofs]) [@@inline always]
-
-  let int8 buf ofs =
-    let ofs, c = char buf ofs in
-    (ofs, Char.code c)
-    [@@inline always]
+  module Bin = Binary_codec
 
-  let str = Bytes.unsafe_of_string
-  let int16 buf ofs = (ofs + 2, Bytes.get_uint16_be (str buf) ofs)
-  let int32 buf ofs = (ofs + 4, Bytes.get_int32_be (str buf) ofs)
-  let int64 buf ofs = (ofs + 8, Bytes.get_int64_be (str buf) ofs)
-
-  let bool buf ofs =
-    let ofs, c = char buf ofs in
-    match c with '\000' -> (ofs, false) | _ -> (ofs, true)
+  type 'a res = int * 'a
 
-  let float buf ofs =
-    let ofs, f = int64 buf ofs in
-    (ofs, Int64.float_of_bits f)
-
-  let int buf ofs =
-    let rec aux buf n p ofs =
-      let ofs, i = int8 buf ofs in
-      let n = n + ((i land 127) lsl p) in
-      if i >= 0 && i < 128 then (ofs, n) else aux buf n (p + 7) ofs
-    in
-    aux buf 0 0 ofs
-
-  let len buf ofs = function
-    | `Int -> int buf ofs
-    | `Int8 -> int8 buf ofs
-    | `Int16 -> int16 buf ofs
-    | `Int32 ->
-        let ofs, i = int32 buf ofs in
-        (ofs, Int32.to_int i)
-    | `Int64 ->
-        let ofs, i = int64 buf ofs in
-        (ofs, Int64.to_int i)
-    | `Fixed n -> (ofs, n)
-    | `Unboxed -> (ofs, String.length buf - ofs)
-
-  let mk_unboxed of_string of_bytes _ =
-    stage @@ fun buf ofs ->
-    let len = String.length buf - ofs in
-    if ofs = 0 then (len, of_string buf)
-    else
-      let str = Bytes.create len in
-      String.blit buf ofs str 0 len;
-      (ofs + len, of_bytes str)
-
-  let mk_boxed of_string of_bytes =
-    let sub len buf ofs =
-      if ofs = 0 && len = String.length buf then (len, of_string buf)
-      else
-        let str = Bytes.create len in
-        String.blit buf ofs str 0 len;
-        (ofs + len, of_bytes str)
-    in
-    function
-    | `Fixed n ->
-        (* fixed-size strings are never boxed *)
-        stage @@ fun buf ofs -> sub n buf ofs
-    | n ->
-        stage @@ fun buf ofs ->
-        let ofs, len = len buf ofs n in
-        sub len buf ofs
-
-  let mk of_string of_bytes =
-    let f_boxed = mk_boxed of_string of_bytes in
-    let f_unboxed = mk_unboxed of_string of_bytes in
-    fun boxed -> if boxed then f_boxed else f_unboxed
-
-  let string = mk (fun x -> x) Bytes.unsafe_to_string
-  let bytes = mk Bytes.of_string (fun x -> x)
+  let string box = if box then Bin.String.decode else Bin.String_unboxed.decode
+  let bytes box = if box then Bin.Bytes.decode else Bin.Bytes_unboxed.decode
 
   let list l n =
     let l = unstage l in
-    stage (fun buf ofs ->
-        let ofs, len = len buf ofs n in
-        let rec aux acc ofs = function
-          | 0 -> (ofs, List.rev acc)
-          | n ->
-              let ofs, x = l buf ofs in
-              aux (x :: acc) ofs (n - 1)
-        in
-        aux [] ofs len)
-
-  let array l len =
-    let decode_list = unstage (list l len) in
-    stage (fun buf ofs ->
-        let ofs, l = decode_list buf ofs in
-        (ofs, Array.of_list l))
+    Bin.List.decode n l
+
+  let array l n =
+    let l = unstage l in
+    Bin.Array.decode n l
 
   let pair a b =
     let a = unstage a and b = unstage b in
-    stage (fun buf ofs ->
-        let ofs, a = a buf ofs in
-        let ofs, b = b buf ofs in
-        (ofs, (a, b)))
+    stage (Bin.Pair.decode a b)
 
   let triple a b c =
     let a = unstage a and b = unstage b and c = unstage c in
-    stage (fun buf ofs ->
-        let ofs, a = a buf ofs in
-        let ofs, b = b buf ofs in
-        let ofs, c = c buf ofs in
-        (ofs, (a, b, c)))
+    stage (Bin.Triple.decode a b c)
 
-  let option : type a. a decode_bin -> a option decode_bin =
-   fun o ->
+  let option o =
     let o = unstage o in
-    stage (fun buf ofs ->
-        let ofs, c = char buf ofs in
-        match c with
-        | '\000' -> (ofs, None)
-        | _ ->
-            let ofs, x = o buf ofs in
-            (ofs, Some x))
+    stage (Bin.Option.decode o)
 
   module Record_decoder = Fields_folder (struct
     type ('a, 'b) t = string -> int -> 'b -> 'a res
@@ -391,13 +205,13 @@ module Decode = struct
 
   and prim : type a. boxed:bool -> a prim -> a decode_bin =
    fun ~boxed -> function
-    | Unit -> stage unit
-    | Bool -> stage bool
-    | Char -> stage char
-    | Int -> stage int
-    | Int32 -> stage int32
-    | Int64 -> stage int64
-    | Float -> stage float
+    | Unit -> stage Bin.Unit.decode
+    | Bool -> stage Bin.Bool.decode
+    | Char -> stage Bin.Char.decode
+    | Int -> stage Bin.Int.decode
+    | Int32 -> stage Bin.Int32.decode
+    | Int64 -> stage Bin.Int64.decode
+    | Float -> stage Bin.Float.decode
     | String n -> string boxed n
     | Bytes n -> bytes boxed n
 
@@ -417,17 +231,16 @@ module Decode = struct
   and variant : type a. a variant -> a decode_bin =
    fun v ->
     let decoders : a decode_bin array =
-      v.vcases
-      |> Array.map @@ function
-         | C0 c -> stage (fun _ ofs -> (ofs, c.c0))
-         | C1 c ->
-             let decode_arg = unstage (t c.ctype1) in
-             stage (fun buf ofs ->
-                 let ofs, x = decode_arg buf ofs in
-                 (ofs, c.c1 x))
+      ArrayLabels.map v.vcases ~f:(function
+        | C0 c -> stage (fun _ ofs -> (ofs, c.c0))
+        | C1 c ->
+            let decode_arg = unstage (t c.ctype1) in
+            stage (fun buf ofs ->
+                let ofs, x = decode_arg buf ofs in
+                (ofs, c.c1 x)))
     in
     stage (fun buf ofs ->
-        let ofs, i = int buf ofs in
+        let ofs, i = Bin.Int.decode buf ofs in
         unstage decoders.(i) buf ofs)
 end
 