diff --git a/test_wycheproof/test.ml b/test_wycheproof/test.ml index 53971be..4034a8b 100644 --- a/test_wycheproof/test.ml +++ b/test_wycheproof/test.ml @@ -12,68 +12,85 @@ let strip_prefix s ~prefix = Some (Str.string_after s prefix_len) else None -let rec strip_prefixes s ~prefixes = - match prefixes with - | [] -> - None - | prefix :: prefixes -> ( - match strip_prefix s ~prefix with - | Some _ as r -> - r - | None -> - strip_prefixes s ~prefixes ) - -let strip_asn1 = - let prefixes = - List.map Hex.to_string - [`Hex "3059301306072a8648ce3d020106082a8648ce3d030107034200"] +let result_of_option ~msg = function + | None -> + Error msg + | Some x -> + Ok x + +let strip_asn1 s = + let prefix = + Hex.to_string + (`Hex "3059301306072a8648ce3d020106082a8648ce3d030107034200") in - strip_prefixes ~prefixes + result_of_option ~msg:"unknown ASN1 prefix" (strip_prefix ~prefix s) + +let ( >>= ) xr f = + match xr with + | Error _ as e -> + e + | Ok x -> + f x let parse_point s = - match strip_asn1 s with - | Some payload -> - Fiat_p256.point_of_hex (Hex.of_string payload) - | None -> - None - -let parse_scalar s = Fiat_p256.scalar_of_hex (Hex.of_string s) - -let test_valid ~private_ ~public ~expected () = - match (parse_point public, parse_scalar private_) with - | Some point, Some scalar -> - let got = Cstruct.to_string (Fiat_p256.dh ~scalar ~point) in - Alcotest.check hex "should be equal" expected got - | _ -> - failwith "cannot parse test case" - -(* -let test_invalid ~private_ ~public () = - Alcotest.check_raises "should raise" Unverified_api.Error (fun () -> - let scalar = Unverified_api.parse_scalar private_ in - let point = Unverified_api.parse_point public in - ignore (Unverified_api.dh scalar point) ) - *) - -let make_test {tcId; comment; private_; public; shared; result; flags; _} = - let name = Printf.sprintf "%d - %s" tcId comment in - let ignored_flags = [invalid_asn; compressed_point; unnamed_curve] in - match result with + strip_asn1 s + >>= fun payload -> + result_of_option ~msg:"cannot parse point" + (Fiat_p256.point_of_hex (Hex.of_string payload)) + +let parse_scalar s = + result_of_option ~msg:"cannot parse scalar" + (Fiat_p256.scalar_of_hex (Hex.of_string s)) + +type test = + { name : string + ; point : Fiat_p256.point + ; scalar : Fiat_p256.scalar + ; expected : string } + +let interpret_test {name; point; scalar; expected} = + let run () = + let got = Cstruct.to_string (Fiat_p256.dh ~scalar ~point) in + Alcotest.check hex __LOC__ expected got + in + (name, `Quick, run) + +type strategy = + | Test of test + | Skip + +let test_name test = Printf.sprintf "%d - %s" test.tcId test.comment + +let make_test test = + let ignored_flags = ["InvalidAsn"; "CompressedPoint"; "UnnamedCurve"] in + match test.result with | _ - when List.exists - (fun ignored_flag -> List.mem ignored_flag flags) - ignored_flags -> - [] + when has_ignored_flag test ~ignored_flags -> + Ok Skip + | Invalid -> + Ok Skip | Valid |Acceptable -> - [(name, `Quick, test_valid ~private_ ~public ~expected:shared)] - | Invalid -> + parse_point test.public + >>= fun point -> + parse_scalar test.private_ + >>= fun scalar -> + let name = test_name test in + Ok (Test {name; point; scalar; expected = test.shared}) + +let concat_map f l = List.map f l |> List.concat + +let to_tests x = + match make_test x with + | Ok (Test t) -> + [interpret_test t] + | Ok Skip -> [] + | Error e -> + failwith e let tests = let data = load_file_exn "ecdh_secp256r1_test.json" in - data.testGroups - |> List.map (fun group -> List.map make_test group.tests |> List.concat) - |> List.concat + concat_map (fun group -> concat_map to_tests group.tests) data.testGroups let () = Alcotest.run "Wycheproof-hacl-p256" [("test vectors", tests)] diff --git a/wycheproof/wycheproof.ml b/wycheproof/wycheproof.ml index 30e6621..66edcf1 100644 --- a/wycheproof/wycheproof.ml +++ b/wycheproof/wycheproof.ml @@ -34,57 +34,6 @@ let test_result_of_yojson = function | _ -> Error "test_result" -type flag = - | Twist - | LowOrderPublic - | SmallPublicKey - | CompressedPoint - | AddSubChain - | InvalidPublic - | WrongOrder - | UnnamedCurve - | UnusedParam - | ModifiedPrime - | WeakPublicKey - | InvalidAsn -[@@deriving show] - -let invalid_asn = InvalidAsn - -let compressed_point = CompressedPoint - -let unnamed_curve = UnnamedCurve - -let flag_of_yojson = function - | `String "LowOrderPublic" -> - Ok LowOrderPublic - | `String "Twist" -> - Ok Twist - | `String "Small public key" -> - Ok SmallPublicKey - | `String "CompressedPoint" -> - Ok CompressedPoint - | `String "AddSubChain" -> - Ok AddSubChain - | `String "InvalidPublic" -> - Ok InvalidPublic - | `String "WrongOrder" -> - Ok WrongOrder - | `String "UnnamedCurve" -> - Ok UnnamedCurve - | `String "UnusedParam" -> - Ok UnusedParam - | `String "ModifiedPrime" -> - Ok ModifiedPrime - | `String "WeakPublicKey" -> - Ok WeakPublicKey - | `String "InvalidAsn" -> - Ok InvalidAsn - | `String s -> - Error ("Unknown flag: " ^ s) - | _ -> - Error "flag_of_yojson" - type test = { tcId : int ; comment : string @@ -93,9 +42,14 @@ type test = ; private_ : hex [@yojson.key "private"] ; shared : hex ; result : test_result - ; flags : flag list } + ; flags : string list } [@@deriving of_yojson, show] +let has_ignored_flag test ~ignored_flags = + List.exists + (fun ignored_flag -> List.mem ignored_flag test.flags) + ignored_flags + type test_group = { curve : json ; tests : test list diff --git a/wycheproof/wycheproof.mli b/wycheproof/wycheproof.mli index 9964538..c98b646 100644 --- a/wycheproof/wycheproof.mli +++ b/wycheproof/wycheproof.mli @@ -10,14 +10,6 @@ type test_result = | Invalid [@@deriving show] -type flag [@@deriving show] - -val invalid_asn : flag - -val compressed_point : flag - -val unnamed_curve : flag - type test = { tcId : int ; comment : string @@ -26,9 +18,11 @@ type test = ; private_ : hex ; shared : hex ; result : test_result - ; flags : flag list } + ; flags : string list } [@@deriving show] +val has_ignored_flag : test -> ignored_flags:string list -> bool + type test_group = { curve : json ; tests : test list