diff --git a/Cargo.lock b/Cargo.lock index e6f92245e..1a4eac8e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + [[package]] name = "aes" version = "0.7.5" @@ -30,7 +40,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241" dependencies = [ "cfg-if", - "cipher 0.4.3", + "cipher 0.4.4", "cpufeatures", ] @@ -40,11 +50,25 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df5f85a83a7d8b0442b6aa7b504b8212c1733da07b98aae43d4bc21b2cb3cdf6" dependencies = [ - "aead", + "aead 0.4.3", "aes 0.7.5", "cipher 0.3.0", "ctr 0.8.0", - "ghash", + "ghash 0.4.4", + "subtle", +] + +[[package]] +name = "aes-gcm" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82e1366e0c69c9f927b1fa5ce2c7bf9eafc8f9268c0b9800729e8b267612447c" +dependencies = [ + "aead 0.5.2", + "aes 0.8.2", + "cipher 0.4.4", + "ctr 0.9.2", + "ghash 0.5.0", "subtle", ] @@ -57,11 +81,51 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-wincon", + "concolor-override", + "concolor-query", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2" + +[[package]] +name = "anstyle-parse" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-wincon" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa" +dependencies = [ + "anstyle", + "windows-sys 0.45.0", +] + [[package]] name = "anyhow" -version = "1.0.69" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" +checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" [[package]] name = "assert_matches" @@ -71,13 +135,13 @@ checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" [[package]] name = "async-trait" -version = "0.1.66" +version = "0.1.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84f9ebcc6c1f5b8cb160f6990096a5c127f423fcb6e1ccc46c370cbdfb75dfc" +checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] @@ -86,12 +150,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "az" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7b7e4c2464d97fe331d41de9d5db0def0a96f4d823b8b32a2efd503578988973" - [[package]] name = "base16ct" version = "0.1.1" @@ -112,9 +170,9 @@ checksum = "a4a4ddaa51a5bc52a6948f74c06d20aaaddb71924eab79b8c97a8c556e942d6a" [[package]] name = "base64ct" -version = "1.5.3" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b645a089122eccb6111b4f81cbc1a49f5900ac4666bb93ac027feaecf15607bf" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" [[package]] name = "bitflags" @@ -133,9 +191,9 @@ dependencies = [ [[package]] name = "block-buffer" -version = "0.10.3" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ "generic-array", ] @@ -146,12 +204,6 @@ version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" -[[package]] -name = "bytemuck" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c041d3eab048880cb0b86b256447da3f18859a163c3b8d8893f4e6368abe6393" - [[package]] name = "byteorder" version = "1.4.3" @@ -194,7 +246,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a18446b09be63d457bbec447509e85f662f32952b035ce892290396bc0b0cff5" dependencies = [ - "aead", + "aead 0.4.3", "chacha20", "cipher 0.3.0", "poly1305", @@ -203,9 +255,9 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.23" +version = "0.4.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" +checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" dependencies = [ "iana-time-zone", "js-sys", @@ -248,9 +300,9 @@ dependencies = [ [[package]] name = "cipher" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1873270f8f7942c191139cb8a40fd228da6c3fd2fc376d7e92d47aa14aeb59e" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" dependencies = [ "crypto-common", "inout", @@ -258,48 +310,53 @@ dependencies = [ [[package]] name = "clap" -version = "4.1.8" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" +checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" dependencies = [ - "bitflags", + "clap_builder", "clap_derive", - "clap_lex", - "is-terminal", "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", "strsim", - "termcolor", ] [[package]] name = "clap_derive" -version = "4.1.8" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" +checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" dependencies = [ "heck", - "proc-macro-error", "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] name = "clap_lex" -version = "0.3.1" +version = "0.4.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "783fe232adfca04f90f56201b26d79682d4cd2625e0bc7290b95123afe558ade" -dependencies = [ - "os_str_bytes", -] +checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" [[package]] name = "cmac" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "606383658416244b8dc4b36f864ec1f86cb922b95c41a908fd07aeb01cad06fa" +checksum = "8543454e3c3f5126effff9cd44d562af4e31fb8ce1cc0d3dcd8f084515dbc1aa" dependencies = [ - "cipher 0.4.3", + "cipher 0.4.4", "dbl", "digest 0.10.6", ] @@ -314,6 +371,21 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "concolor-override" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f" + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" +dependencies = [ + "windows-sys 0.45.0", +] + [[package]] name = "console_error_panic_hook" version = "0.1.7" @@ -342,25 +414,19 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" +checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181" dependencies = [ "libc", ] -[[package]] -name = "crunchy" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" - [[package]] name = "crypto-bigint" version = "0.3.2" @@ -380,6 +446,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core", "typenum", ] @@ -408,7 +475,7 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" dependencies = [ - "cipher 0.4.3", + "cipher 0.4.4", ] [[package]] @@ -426,9 +493,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.90" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90d59d9acd2a682b4e40605a242f6670eaa58c5957471cbf85e8aa6a0b97a5e8" +checksum = "f61f1b6389c3fe1c316bf8a4dccc90a38208354b330925bce1f74a6c4756eb93" dependencies = [ "cc", "cxxbridge-flags", @@ -438,9 +505,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.90" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebfa40bda659dd5c864e65f4c9a2b0aff19bea56b017b9b77c73d3766a453a38" +checksum = "12cee708e8962df2aeb38f594aae5d827c022b6460ac71a7a3e2c3c2aae5a07b" dependencies = [ "cc", "codespan-reporting", @@ -448,24 +515,24 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn", + "syn 2.0.13", ] [[package]] name = "cxxbridge-flags" -version = "1.0.90" +version = "1.0.94" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "457ce6757c5c70dc6ecdbda6925b958aae7f959bda7d8fb9bde889e34a09dc03" +checksum = "7944172ae7e4068c533afbb984114a56c46e9ccddda550499caa222902c7f7bb" [[package]] name = "cxxbridge-macro" -version = "1.0.90" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebf883b7aacd7b2aeb2a7b338648ee19f57c140d4ee8e52c68979c6b2f7f2263" +checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] @@ -603,7 +670,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ - "block-buffer 0.10.3", + "block-buffer 0.10.4", "crypto-common", "subtle", ] @@ -649,13 +716,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.2.8" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" dependencies = [ "errno-dragonfly", "libc", - "winapi", + "windows-sys 0.45.0", ] [[package]] @@ -687,18 +754,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "fixed" -version = "1.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55f3be4cf4fc227d3a63bb77512a2b7d364200b2a715f389155785c4d3345495" -dependencies = [ - "az", - "bytemuck", - "half", - "typenum", -] - [[package]] name = "fnv" version = "1.0.7" @@ -731,9 +786,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" dependencies = [ "futures-channel", "futures-core", @@ -746,9 +801,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" dependencies = [ "futures-core", "futures-sink", @@ -756,15 +811,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" [[package]] name = "futures-executor" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" dependencies = [ "futures-core", "futures-task", @@ -773,38 +828,38 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" [[package]] name = "futures-macro" -version = "0.3.26" +version = "0.3.28" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] name = "futures-sink" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" [[package]] name = "futures-task" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" [[package]] name = "futures-util" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ "futures-channel", "futures-core", @@ -820,9 +875,9 @@ dependencies = [ [[package]] name = "generic-array" -version = "0.14.6" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", @@ -848,7 +903,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1583cc1656d7839fd3732b80cf4f38850336cdb9b8ded1cd399ca62958de3c99" dependencies = [ "opaque-debug", - "polyval", + "polyval 0.5.3", +] + +[[package]] +name = "ghash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d930750de5717d2dd0b8c0d42c076c0e884c81a73e6cab859bbd2339c71e3e40" +dependencies = [ + "opaque-debug", + "polyval 0.6.0", ] [[package]] @@ -864,9 +929,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" +checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d" dependencies = [ "bytes", "fnv", @@ -881,15 +946,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "half" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" -dependencies = [ - "crunchy", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -987,7 +1043,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8603fbb7674b627cb9f3de6ae5ba8b7e621ef5fdea6489363cb4bef26e5fce52" dependencies = [ - "aes-gcm", + "aes-gcm 0.9.4", "chacha20poly1305", "getrandom", "hkdf", @@ -1002,9 +1058,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" dependencies = [ "bytes", "fnv", @@ -1036,9 +1092,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version 
= "0.14.24" +version = "0.14.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" +checksum = "cc5e554ff619822309ffd57d8734d77cd5ce6238bc956f037ea06c58238c9899" dependencies = [ "bytes", "futures-channel", @@ -1073,16 +1129,16 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.53" +version = "0.1.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765" +checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "winapi", + "windows", ] [[package]] @@ -1107,9 +1163,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.2" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown", @@ -1135,25 +1191,26 @@ dependencies = [ [[package]] name = "io-lifetimes" -version = "1.0.5" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" dependencies = [ + "hermit-abi 0.3.1", "libc", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] name = "ipnet" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" +checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" [[package]] name = "is-terminal" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857" +checksum = "256017f749ab3117e93acb91063009e1f1bb56d03965b14c2c8df4eb02c524d8" dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", @@ -1163,9 +1220,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "js-sys" @@ -1176,6 +1233,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3afef3b6eff9ce9d8ff9b3601125eec7f0c8cbac7abd14f355d053fa56c98768" +dependencies = [ + "cpufeatures", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1184,9 +1250,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.139" +version = "0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" [[package]] name = "link-cplusplus" @@ -1199,9 +1265,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.1.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "lock_api" @@ -1251,9 +1317,9 @@ checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "mime" -version = "0.3.16" +version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "mio" @@ -1338,9 +1404,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.48" +version = "0.10.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "518915b97df115dd36109bfa429a48b8f737bd05508cf9588977b599648926d2" +checksum = "4d2f106ab837a24e03672c59b1239669a0596406ff657c3c0835b6b7f0f35a33" dependencies = [ "bitflags", "cfg-if", @@ -1353,13 +1419,13 @@ dependencies = [ [[package]] name = "openssl-macros" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b501e44f11665960c7e7fcf062c7d96a14ade4aa98116c004b2e37b5be7d736c" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] @@ -1370,23 +1436,16 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.83" +version = "0.9.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "666416d899cf077260dac8698d60a60b435a46d57e82acb1be3d0dad87284e5b" +checksum = "3a20eace9dc2d82904039cb76dcf50fb1a0bba071cfd1629720b5d6f1ddba0fa" dependencies = [ - "autocfg", "cc", "libc", "pkg-config", "vcpkg", ] -[[package]] -name = "os_str_bytes" -version = "6.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" - [[package]] name = "overload" version = "0.1.1" @@ -1433,7 +1492,7 @@ checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.2.16", "smallvec", "windows-sys 0.45.0", ] @@ -1515,7 +1574,7 @@ checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -1555,7 +1614,7 @@ checksum = "048aeb476be11a4b6ca432ca569e375810de9294ae78f4774e78ea98a9246ede" dependencies = [ "cpufeatures", "opaque-debug", - "universal-hash", + "universal-hash 0.4.1", ] [[package]] @@ -1567,7 +1626,19 @@ dependencies = [ "cfg-if", "cpufeatures", "opaque-debug", - "universal-hash", + "universal-hash 0.4.1", +] + +[[package]] +name = "polyval" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef234e08c11dfcb2e56f79fd70f6f2eb7f025c0ce2333e82f4f0518ecad30c6" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash 0.5.0", ] [[package]] @@ -1578,53 +1649,30 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prio" -version = "0.10.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0453e1ec5f2f48af9a67aab1d812f9aec48619dda0d9a59e2f79ebd235448ec" +checksum = 
"9c2aa1f9faa3fab6f02b54025f411d6f4fcd31765765600db339280e3678ae20" dependencies = [ "aes 0.8.2", - "aes-gcm", - "base64 0.13.1", + "aes-gcm 0.10.1", + "base64 0.21.0", "byteorder", "cmac", "ctr 0.9.2", - "fixed", "getrandom", "ring", "serde", + "sha3", "static_assertions", + "subtle", "thiserror", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] @@ -1652,9 +1700,9 @@ checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" [[package]] name = "quote" -version = "1.0.23" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -1698,11 +1746,20 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ "regex-syntax", ] @@ -1718,24 +1775,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.28" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" - -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi", -] +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "reqwest" -version = "0.11.14" +version = "0.11.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" +checksum = "27b71749df584b7f4cac2c426c127a7c785a5106cc98f7a8feb044115f0fa254" dependencies = [ "base64 0.21.0", "bytes", @@ -1834,23 +1882,23 @@ dependencies = [ [[package]] name = "rustix" -version = "0.36.8" +version = "0.37.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43abb88211988493c1abb44a70efa56ff0ce98f233b7b276146f1f3f7ba9644" +checksum = "1aef160324be24d31a62147fae491c14d2204a3865c7ca8c3b0d7f7bcb3ea635" dependencies = [ "bitflags", "errno", 
"io-lifetimes", "libc", "linux-raw-sys", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "schannel" @@ -1869,9 +1917,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "scratch" -version = "1.0.3" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddccb15bcce173023b3fedd9436f882a0739b8dfb45e4f6b6002bee5929f61b2" +checksum = "1792db035ce95be60c3f8853017b3999209281c24e2ba5bc8e59bf97a0c590c1" [[package]] name = "sec1" @@ -1911,9 +1959,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.154" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cdd151213925e7f1ab45a9bbfb129316bd00799784b174b7cc7bcd16961c49e" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" dependencies = [ "serde_derive", ] @@ -1931,20 +1979,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.154" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fc80d722935453bcafdc2c9a73cd6fac4dc1938f0346035d84bf99fa9e33217" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] name = "serde_json" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" dependencies = [ "itoa", "ryu", @@ -1987,6 +2035,16 @@ dependencies = [ "digest 0.10.6", ] +[[package]] +name = "sha3" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdf0c33fae925bdc080598b84bc15c55e7b9a4a43b3c704da051f977469691c9" +dependencies = [ + "digest 0.10.6", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.4" @@ -2023,9 +2081,9 @@ checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" [[package]] name = "slab" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" dependencies = [ "autocfg", ] @@ -2038,9 +2096,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" dependencies = [ "libc", "winapi", @@ -2098,29 +2156,27 @@ dependencies = [ ] [[package]] -name = "synstructure" -version = "0.12.6" +name = "syn" +version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec" dependencies = [ "proc-macro2", "quote", - "syn", - "unicode-xid", + 
"unicode-ident", ] [[package]] name = "tempfile" -version = "3.3.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" +checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" dependencies = [ "cfg-if", "fastrand", - "libc", - "redox_syscall", - "remove_dir_all", - "winapi", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.45.0", ] [[package]] @@ -2134,22 +2190,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] @@ -2196,19 +2252,18 @@ checksum = "a787719c86efec1535b071bde678f5fd649380e8005cd1ebd0afeb4bcc4d2a85" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] name = "tokio" -version = "1.26.0" +version = "1.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" +checksum = "d0de47a4eecbe11f498978a9b29d792f0d2692d1dd003650c24c76510e3bc001" dependencies = [ "autocfg", "bytes", "libc", - "memchr", "mio", "num_cpus", "parking_lot", @@ -2221,13 +2276,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.8.2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" +checksum = "61a573bdc87985e9d6ddeed1b3d864e8a302c847e40d647746df2f1de209d1ce" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.13", ] [[package]] @@ -2280,7 +2335,7 @@ checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -2345,15 +2400,15 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.10" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" [[package]] name = "unicode-normalization" @@ -2370,12 +2425,6 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" -[[package]] -name = "unicode-xid" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" - [[package]] name = "universal-hash" version = "0.4.1" @@ -2386,6 +2435,16 @@ dependencies = [ "subtle", ] 
+[[package]] +name = "universal-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d3160b73c9a19f7e2939a2fdad446c57c1bbbbf4d919d3213ff1267a580d8b5" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.7.1" @@ -2404,6 +2463,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "valuable" version = "0.1.0" @@ -2461,7 +2526,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-shared", ] @@ -2495,7 +2560,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2560,19 +2625,28 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" +dependencies = [ + "windows-targets 0.48.0", +] + [[package]] name = "windows-sys" version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", ] [[package]] @@ -2581,65 +2655,131 @@ version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" dependencies = [ - "windows-targets", + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", ] [[package]] name = "windows-targets" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 
0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.1" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" [[package]] name = "windows_aarch64_msvc" -version = "0.42.1" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_gnu" -version = "0.42.1" +version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" [[package]] name = "windows_i686_msvc" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" [[package]] name = "windows_x86_64_gnu" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.1" +version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "windows_x86_64_msvc" -version = "0.42.1" +version = "0.48.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" [[package]] name = "winreg" @@ -2700,7 +2840,7 @@ dependencies = [ "async-trait", "proc-macro2", "quote", - "syn", + "syn 1.0.109", "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-macro-support", @@ -2733,21 +2873,20 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.5.7" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c394b5bd0c6f669e7275d9c20aa90ae064cb22e75a1cad54e1b34088034b149f" +checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" dependencies = [ "zeroize_derive", ] [[package]] name = "zeroize_derive" -version = "1.3.3" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bf07cb3e50ea2003396695d58bf46bc9887a1f362260446fad6bc4e79bd36c" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn", - "synstructure", + "syn 2.0.13", ] diff --git a/daphne/Cargo.toml b/daphne/Cargo.toml index c959e3db4..4e194e2e5 100644 --- a/daphne/Cargo.toml +++ b/daphne/Cargo.toml @@ -29,7 +29,7 @@ hpke-rs-rust-crypto = { version = "0.1.1"} lazy_static = "1.4.0" matchit = "0.7.0" paste = "1.0.12" -prio = { version = "0.10.0", features = ["prio2"] } +prio = { version = "0.12.0", features = ["prio2"] } prometheus = "0.13.3" rand = "0.8.5" ring = "0.16.20" diff --git a/daphne/dapf/Cargo.toml b/daphne/dapf/Cargo.toml index 3f2d4b4d4..5986da5ab 100644 --- a/daphne/dapf/Cargo.toml +++ b/daphne/dapf/Cargo.toml @@ -14,7 +14,7 @@ license = "BSD-3-Clause" daphne = { path = ".." } assert_matches = "1.5.0" base64 = "0.21.0" -prio = "0.10.0" +prio = "0.12.0" serde = { version = "1.0.154", features = ["derive"] } serde_json = "1.0.94" url = { version = "2.3.1", features = ["serde"] } diff --git a/daphne/dapf/src/bin/dapf.rs b/daphne/dapf/src/bin/dapf.rs index d9d0433a1..b952a7fc2 100644 --- a/daphne/dapf/src/bin/dapf.rs +++ b/daphne/dapf/src/bin/dapf.rs @@ -6,10 +6,10 @@ use clap::{Parser, Subcommand}; use daphne::{ constants, hpke::HpkeReceiverConfig, - messages::{decode_base64url, BatchSelector, CollectReq, CollectResp, HpkeConfig, Id, Query}, + messages::{BatchSelector, Collection, CollectionReq, HpkeConfig, Query, TaskId}, DapMeasurement, DapVersion, ProblemDetails, VdafConfig, }; -use prio::codec::{Decode, ParameterizedEncode}; +use prio::codec::{Decode, ParameterizedDecode, ParameterizedEncode}; use reqwest::blocking::{Client, ClientBuilder}; use std::{ io::{stdin, Read}, @@ -151,8 +151,12 @@ async fn main() -> Result<()> { serde_json::from_str(&buf).with_context(|| "failed to parse JSON from stdin")?; // Construct collect request. 
- let collect_req = CollectReq { - task_id, + let collect_req = CollectionReq { + draft02_task_id: if version == DapVersion::Draft02 { + Some(task_id) + } else { + None + }, query, agg_param: Vec::default(), }; @@ -212,7 +216,7 @@ async fn main() -> Result<()> { let receiver = cli.hpke_receiver.as_ref().ok_or_else(|| { anyhow!("received response, but cannot decrypt without HPKE receiver config") })?; - let collect_resp = CollectResp::get_decoded(&resp.bytes()?)?; + let collect_resp = Collection::get_decoded_with_param(&version, &resp.bytes()?)?; let agg_res = vdaf .consume_encrypted_agg_shares( receiver, @@ -230,15 +234,14 @@ async fn main() -> Result<()> { } } -fn parse_id(id_str: &str) -> Result<Id> { - let id_bytes = decode_base64url(id_str.as_bytes()) +fn parse_id(id_str: &str) -> Result<TaskId> { + TaskId::try_from_base64url(id_str) .ok_or_else(|| anyhow!("failed to decode ID")) - .with_context(|| "expected URL-safe, base64 string")?; - Ok(Id(id_bytes)) + .with_context(|| "expected URL-safe, base64 string") } // TODO(cjpatton) Refactor integration tests to use this method. -fn get_hpke_config(http_client: &Client, task_id: &Id, base_url: &str) -> Result<HpkeConfig> { +fn get_hpke_config(http_client: &Client, task_id: &TaskId, base_url: &str) -> Result<HpkeConfig> { let url = Url::parse(base_url) .with_context(|| "failed to parse base URL")? .join("hpke_config")?; diff --git a/daphne/src/auth.rs b/daphne/src/auth.rs index 36051ac5a..1e642a6fc 100644 --- a/daphne/src/auth.rs +++ b/daphne/src/auth.rs @@ -5,7 +5,7 @@ use crate::{ constants::sender_for_media_type, - messages::{constant_time_eq, Id}, + messages::{constant_time_eq, TaskId}, DapError, DapRequest, DapSender, }; use async_trait::async_trait; @@ -58,13 +58,13 @@ pub trait BearerTokenProvider<'a> { /// Fetch the Leader's bearer token for the given task, if the task is recognized. async fn get_leader_bearer_token_for( &'a self, - task_id: &'a Id, + task_id: &'a TaskId, ) -> Result, DapError>; /// Fetch the Collector's bearer token for the given task, if the task is recognized. async fn get_collector_bearer_token_for( &'a self, - task_id: &'a Id, + task_id: &'a TaskId, ) -> Result, DapError>; /// Returns true if the given bearer token matches the leader token configured for the "taskprov" extension. @@ -77,7 +77,7 @@ pub trait BearerTokenProvider<'a> { /// media type. async fn authorize_with_bearer_token( &'a self, - task_id: &'a Id, + task_id: &'a TaskId, media_type: &'static str, ) -> Result { if matches!(sender_for_media_type(media_type), Some(DapSender::Leader)) { diff --git a/daphne/src/constants.rs b/daphne/src/constants.rs index 593056d71..ff7947a7d 100644 --- a/daphne/src/constants.rs +++ b/daphne/src/constants.rs @@ -3,23 +3,29 @@ //! Constants used in the DAP protocol. -use crate::DapSender; +use crate::{DapSender, DapVersion}; // Media types for HTTP requests. // // TODO spec: Decide if media type should be enforced. (We currently don't.) In any case, it may be // useful to enforce this for testing purposes.
pub const DRAFT02_MEDIA_TYPE_HPKE_CONFIG: &str = "application/dap-hpke-config"; +pub const DRAFT02_MEDIA_TYPE_AGG_INIT_REQ: &str = "application/dap-aggregate-initialize-req"; +pub const DRAFT02_MEDIA_TYPE_AGG_INIT_RESP: &str = "application/dap-aggregate-initialize-resp"; +pub const DRAFT02_MEDIA_TYPE_AGG_CONT_REQ: &str = "application/dap-aggregate-continue-req"; +pub const DRAFT02_MEDIA_TYPE_AGG_CONT_RESP: &str = "application/dap-aggregate-continue-resp"; +pub const DRAFT02_MEDIA_TYPE_AGG_SHARE_RESP: &str = "application/dap-aggregate-share-resp"; +pub const DRAFT02_MEDIA_TYPE_COLLECT_RESP: &str = "application/dap-collect-resp"; pub const MEDIA_TYPE_HPKE_CONFIG_LIST: &str = "application/dap-hpke-config-list"; pub const MEDIA_TYPE_REPORT: &str = "application/dap-report"; -pub const MEDIA_TYPE_AGG_INIT_REQ: &str = "application/dap-aggregate-initialize-req"; -pub const MEDIA_TYPE_AGG_INIT_RESP: &str = "application/dap-aggregate-initialize-resp"; -pub const MEDIA_TYPE_AGG_CONT_REQ: &str = "application/dap-aggregate-continue-req"; -pub const MEDIA_TYPE_AGG_CONT_RESP: &str = "application/dap-aggregate-continue-resp"; +pub const MEDIA_TYPE_AGG_INIT_REQ: &str = "application/dap-aggregation-job-init-req"; +pub const MEDIA_TYPE_AGG_INIT_RESP: &str = "application/dap-aggregation-job-resp"; +pub const MEDIA_TYPE_AGG_CONT_REQ: &str = "application/dap-aggregation-job-continue-req"; +pub const MEDIA_TYPE_AGG_CONT_RESP: &str = "application/dap-aggregation-job-continue-resp"; pub const MEDIA_TYPE_AGG_SHARE_REQ: &str = "application/dap-aggregate-share-req"; -pub const MEDIA_TYPE_AGG_SHARE_RESP: &str = "application/dap-aggregate-share-resp"; +pub const MEDIA_TYPE_AGG_SHARE_RESP: &str = "application/dap-aggregate-share"; pub const MEDIA_TYPE_COLLECT_REQ: &str = "application/dap-collect-req"; -pub const MEDIA_TYPE_COLLECT_RESP: &str = "application/dap-collect-resp"; +pub const MEDIA_TYPE_COLLECT_RESP: &str = "application/dap-collection"; /// Check if the provided value for the HTTP Content-Type is valid media type for DAP. If so, then /// return a static reference to the media type. @@ -27,27 +33,49 @@ pub fn media_type_for(content_type: &str) -> Option<&'static str> { match content_type { DRAFT02_MEDIA_TYPE_HPKE_CONFIG => Some(DRAFT02_MEDIA_TYPE_HPKE_CONFIG), MEDIA_TYPE_REPORT => Some(MEDIA_TYPE_REPORT), + DRAFT02_MEDIA_TYPE_AGG_INIT_REQ => Some(DRAFT02_MEDIA_TYPE_AGG_INIT_REQ), MEDIA_TYPE_AGG_INIT_REQ => Some(MEDIA_TYPE_AGG_INIT_REQ), + DRAFT02_MEDIA_TYPE_AGG_INIT_RESP => Some(DRAFT02_MEDIA_TYPE_AGG_INIT_RESP), MEDIA_TYPE_AGG_INIT_RESP => Some(MEDIA_TYPE_AGG_INIT_RESP), + DRAFT02_MEDIA_TYPE_AGG_CONT_REQ => Some(DRAFT02_MEDIA_TYPE_AGG_CONT_REQ), MEDIA_TYPE_AGG_CONT_REQ => Some(MEDIA_TYPE_AGG_CONT_REQ), + DRAFT02_MEDIA_TYPE_AGG_CONT_RESP => Some(DRAFT02_MEDIA_TYPE_AGG_CONT_RESP), MEDIA_TYPE_AGG_CONT_RESP => Some(MEDIA_TYPE_AGG_CONT_RESP), MEDIA_TYPE_AGG_SHARE_REQ => Some(MEDIA_TYPE_AGG_SHARE_REQ), + DRAFT02_MEDIA_TYPE_AGG_SHARE_RESP => Some(DRAFT02_MEDIA_TYPE_AGG_SHARE_RESP), MEDIA_TYPE_AGG_SHARE_RESP => Some(MEDIA_TYPE_AGG_SHARE_RESP), MEDIA_TYPE_COLLECT_REQ => Some(MEDIA_TYPE_COLLECT_REQ), + DRAFT02_MEDIA_TYPE_COLLECT_RESP => Some(DRAFT02_MEDIA_TYPE_COLLECT_RESP), MEDIA_TYPE_COLLECT_RESP => Some(MEDIA_TYPE_COLLECT_RESP), _ => None, } } +/// draft02 compatibility: Substitute the content type with the corresponding media type for the +/// older version of the protocol if necessary. 
+pub fn versioned_media_type_for(version: &DapVersion, content_type: &str) -> Option<&'static str> { + match (version, content_type) { + (DapVersion::Draft02, MEDIA_TYPE_AGG_INIT_REQ) => Some(DRAFT02_MEDIA_TYPE_AGG_INIT_REQ), + (DapVersion::Draft02, MEDIA_TYPE_AGG_INIT_RESP) => Some(DRAFT02_MEDIA_TYPE_AGG_INIT_RESP), + (DapVersion::Draft02, MEDIA_TYPE_AGG_CONT_REQ) => Some(DRAFT02_MEDIA_TYPE_AGG_CONT_REQ), + (DapVersion::Draft02, MEDIA_TYPE_AGG_CONT_RESP) => Some(DRAFT02_MEDIA_TYPE_AGG_CONT_RESP), + (DapVersion::Draft02, MEDIA_TYPE_AGG_SHARE_RESP) => Some(DRAFT02_MEDIA_TYPE_AGG_SHARE_RESP), + (DapVersion::Draft02, MEDIA_TYPE_COLLECT_RESP) => Some(DRAFT02_MEDIA_TYPE_COLLECT_RESP), + _ => media_type_for(content_type), + } +} + /// Return the sender that would send a message with the given media type (or none if the sender /// can't be determined). pub fn sender_for_media_type(media_type: &'static str) -> Option { match media_type { DRAFT02_MEDIA_TYPE_HPKE_CONFIG | MEDIA_TYPE_REPORT => Some(DapSender::Client), MEDIA_TYPE_COLLECT_REQ => Some(DapSender::Collector), - MEDIA_TYPE_AGG_INIT_REQ | MEDIA_TYPE_AGG_CONT_REQ | MEDIA_TYPE_AGG_SHARE_REQ => { - Some(DapSender::Leader) - } + MEDIA_TYPE_AGG_INIT_REQ + | MEDIA_TYPE_AGG_CONT_REQ + | MEDIA_TYPE_AGG_SHARE_REQ + | DRAFT02_MEDIA_TYPE_AGG_INIT_REQ + | DRAFT02_MEDIA_TYPE_AGG_CONT_REQ => Some(DapSender::Leader), _ => None, } } diff --git a/daphne/src/hpke.rs b/daphne/src/hpke.rs index dae58efd7..af8d1762c 100644 --- a/daphne/src/hpke.rs +++ b/daphne/src/hpke.rs @@ -14,7 +14,7 @@ use hpke_rs_rust_crypto::HpkeRustCrypto as ImplHpkeCrypto; use crate::{ messages::{ decode_u16_bytes, encode_u16_bytes, HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeKdfId, - HpkeKemId, Id, TransitionFailure, + HpkeKemId, TaskId, TransitionFailure, }, DapError, DapVersion, }; @@ -99,16 +99,16 @@ pub trait HpkeDecrypter<'a> { async fn get_hpke_config_for( &'a self, version: DapVersion, - task_id: Option<&Id>, + task_id: Option<&TaskId>, ) -> Result; /// Returns `true` if a ciphertext with the HPKE config ID can be consumed in the current task. - async fn can_hpke_decrypt(&self, task_id: &Id, config_id: u8) -> Result; + async fn can_hpke_decrypt(&self, task_id: &TaskId, config_id: u8) -> Result; /// Decrypt the given HPKE ciphertext using the given info and AAD string. 
async fn hpke_decrypt( &self, - task_id: &Id, + task_id: &TaskId, info: &[u8], aad: &[u8], ciphertext: &HpkeCiphertext, @@ -208,18 +208,18 @@ impl<'a> HpkeDecrypter<'a> for HpkeReceiverConfig { async fn get_hpke_config_for( &'a self, _version: DapVersion, - _task_id: Option<&Id>, + _task_id: Option<&TaskId>, ) -> Result { unreachable!("not implemented"); } - async fn can_hpke_decrypt(&self, _task_id: &Id, config_id: u8) -> Result { + async fn can_hpke_decrypt(&self, _task_id: &TaskId, config_id: u8) -> Result { Ok(config_id == self.config.id) } async fn hpke_decrypt( &self, - _task_id: &Id, + _task_id: &TaskId, info: &[u8], aad: &[u8], ciphertext: &HpkeCiphertext, diff --git a/daphne/src/lib.rs b/daphne/src/lib.rs index 3da70c68d..759d4d66f 100644 --- a/daphne/src/lib.rs +++ b/daphne/src/lib.rs @@ -37,26 +37,29 @@ use crate::{ hpke::HpkeReceiverConfig, messages::{ - BatchSelector, CollectResp, Duration, HpkeConfig, Id, Interval, PartialBatchSelector, - ReportId, ReportMetadata, Time, TransitionFailure, + AggregationJobId, BatchId, BatchSelector, Collection, CollectionJobId, + Draft02AggregationJobId, Duration, HpkeConfig, HpkeKemId, Interval, PartialBatchSelector, + ReportId, ReportMetadata, TaskId, Time, TransitionFailure, }, + taskprov::TaskprovVersion, vdaf::{ prio2::prio2_decode_prepare_state, prio3::{prio3_append_prepare_state, prio3_decode_prepare_state}, VdafAggregateShare, VdafError, VdafMessage, VdafState, VdafVerifyKey, }, }; -use messages::HpkeKemId; use prio::{ codec::{CodecError, Decode, Encode}, vdaf::Aggregatable as AggregatableTrait, }; +use rand::prelude::*; use serde::{Deserialize, Serialize}; use std::{ + borrow::Cow, + cmp::{max, min}, collections::{HashMap, HashSet}, fmt::Debug, }; -use taskprov::TaskprovVersion; use url::Url; /// DAP errors. @@ -175,6 +178,11 @@ pub enum DapAbort { #[error("reportTooLate")] ReportTooLate, + /// Round mismatch. The aggregators disagree on the current round of the VDAF preparation protocol. + /// This abort occurs during the aggregation sub-protocol. + #[error("roundMismatch")] + RoundMismatch, + /// Stale report. Sent in response to an upload request containing a report pertaining to a /// batch that has already been collected. #[error("staleReport")] @@ -217,6 +225,7 @@ impl DapAbort { | Self::InvalidProtocolVersion | Self::InvalidTask | Self::QueryMismatch + | Self::RoundMismatch | Self::MissingTaskId | Self::ReplayedReport | Self::ReportTooLate @@ -287,8 +296,8 @@ pub enum DapVersion { #[serde(rename = "v02")] Draft02, - #[serde(rename = "v03")] - Draft03, + #[serde(rename = "v04")] + Draft04, #[serde(other)] #[serde(rename = "unknown_version")] @@ -299,7 +308,7 @@ impl From<&str> for DapVersion { fn from(version: &str) -> Self { match version { "v02" => DapVersion::Draft02, - "v03" => DapVersion::Draft03, + "v04" => DapVersion::Draft04, _ => DapVersion::Unknown, } } @@ -309,7 +318,7 @@ impl AsRef for DapVersion { fn as_ref(&self) -> &str { match self { DapVersion::Draft02 => "v02", - DapVersion::Draft03 => "v03", + DapVersion::Draft04 => "v04", _ => panic!("tried to construct string from unknown DAP version"), } } @@ -432,7 +441,7 @@ impl DapQueryConfig { /// bucket, which is the batch determined by the batch ID (i.e., the partial batch selector). #[derive(Clone, Eq, Hash, PartialEq)] pub enum DapBatchBucket<'a> { - FixedSize { batch_id: &'a Id }, + FixedSize { batch_id: &'a BatchId }, TimeInterval { batch_window: Time }, } @@ -476,7 +485,7 @@ impl DapTaskConfig { /// numbre of seconds since the beginning of UNIX time. 
#[cfg(test)] pub fn query_for_current_batch_window(&self, now: u64) -> crate::messages::Query { - let start = now - (now % self.time_precision); + let start = self.quantized_time_lower_bound(now); crate::messages::Query::TimeInterval { batch_interval: crate::messages::Interval { start, @@ -485,10 +494,17 @@ impl DapTaskConfig { } } - pub(crate) fn truncate_time(&self, time: Time) -> Time { + /// Return the greatest multiple of the time_precision which is less than or equal to the + /// specified time. + pub fn quantized_time_lower_bound(&self, time: Time) -> Time { time - (time % self.time_precision) } + /// Return the least multiple of the time_precision which is greater than the specified time. + pub fn quantized_time_upper_bound(&self, time: Time) -> Time { + self.quantized_time_lower_bound(time) + self.time_precision + } + /// Compute the "batch span" of a set of output shares and, for each bucket in the span, + aggregate the output shares into an aggregate share. pub fn batch_span_for_out_shares<'a>( @@ -506,7 +522,7 @@ impl DapTaskConfig { for out_share in out_shares.into_iter() { let bucket = match part_batch_sel { PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval { - batch_window: self.truncate_time(out_share.time), + batch_window: self.quantized_time_lower_bound(out_share.time), }, PartialBatchSelector::FixedSizeByBatchId { batch_id } => { DapBatchBucket::FixedSize { batch_id } @@ -516,6 +532,8 @@ impl DapTaskConfig { let agg_share = span.entry(bucket).or_default(); agg_share.merge(DapAggregateShare { report_count: 1, + min_time: out_share.time, + max_time: out_share.time, checksum: out_share.checksum, data: Some(out_share.data), })?; @@ -569,7 +587,7 @@ impl DapTaskConfig { for metadata in report_meta { let bucket = match part_batch_sel { PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval { - batch_window: self.quantized_time_lower_bound(metadata.time), + }, PartialBatchSelector::FixedSizeByBatchId { batch_id } => { DapBatchBucket::FixedSize { batch_id } @@ -707,6 +725,8 @@ pub struct DapOutputShare { #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct DapAggregateShare { pub(crate) report_count: u64, + pub(crate) min_time: Time, + pub(crate) max_time: Time, pub(crate) checksum: [u8; 32], pub(crate) data: Option, } @@ -744,6 +764,18 @@ impl DapAggregateShare { _ => return Err(DapError::fatal("invalid aggregate share merge")), }; + if self.report_count == 0 { + // No interval yet, just copy other's interval + self.min_time = other.min_time; + self.max_time = other.max_time; + } else if other.report_count > 0 { + // Note that we don't merge if other.report_count == 0, as in that case the timestamps + // are 0 too, and thus bad to merge! + self.min_time = min(self.min_time, other.min_time); + self.max_time = max(self.max_time, other.max_time); + } else { + // Do nothing! + } self.report_count += other.report_count; for (x, y) in self.checksum.iter_mut().zip(other.checksum) { *x ^= y; @@ -759,6 +791,8 @@ impl DapAggregateShare { /// Set the aggregate share to zero.
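// Editorial aside, not part of the patch: a worked example of the two quantization helpers
// above, assuming a task with time_precision = 3600 seconds. For a report timestamp of 5000:
//
//   quantized_time_lower_bound(5000) = 5000 - (5000 % 3600) = 3600
//   quantized_time_upper_bound(5000) = 3600 + 3600          = 7200
//
// so the report falls in the half-open window [3600, 7200). The min_time and max_time fields
// added to DapAggregateShare track the raw (unquantized) timestamps; they are only quantized
// later, when the collection interval is computed.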
pub fn reset(&mut self) { self.report_count = 0; + self.min_time = 0; + self.max_time = 0; self.checksum = [0; 32]; self.data = None; } @@ -771,6 +805,8 @@ impl DapAggregateShare { for out_share in out_shares.into_iter() { agg_share.merge(DapAggregateShare { report_count: 1, + min_time: out_share.time, + max_time: out_share.time, checksum: out_share.checksum, data: Some(out_share.data), })?; @@ -812,7 +848,7 @@ pub enum DapHelperTransition { #[serde(rename_all = "snake_case")] pub enum VdafConfig { Prio3(Prio3Config), - Prio2 { dimension: u32 }, + Prio2 { dimension: usize }, } impl std::str::FromStr for VdafConfig { @@ -837,7 +873,7 @@ pub enum Prio3Config { /// The sum of 64-bit, unsigned integers. Each measurement is an integer in range `[0, /// 2^bits)`. - Sum { bits: u32 }, + Sum { bits: usize }, } /// DAP sender role. @@ -848,24 +884,87 @@ pub enum DapSender { Leader, } +/// Types of resources associated with DAP tasks. +#[derive(Debug)] +pub enum DapResource { + /// Aggregation job resource. + AggregationJob(AggregationJobId), + + /// Collection job resource. + CollectionJob(CollectionJobId), + + /// Undefined (or undetermined) resource. + /// + /// The resource of a DAP request is undefined if there is not a unique object (in the context + /// of a DAP task) that the request pertains to. For example: + /// + /// * The Client->Aggregator request for the HPKE config or to upload a report + /// * The Leader->Helper request for an aggregate share + /// + /// The resource of a DAP request is undetermined if its identifier could not be parsed from + /// the request path. + /// + /// draft02 compatibility: In draft02, the resource of a DAP request is undetermined until the + /// request payload is parsed. Defer determination of the resource until then. + Undefined, +} + /// DAP request. #[derive(Debug)] pub struct DapRequest { + /// Protocol version indicated by the request. pub version: DapVersion, + + /// Request media type, sent in the "content-type" header of the HTTP request. pub media_type: Option<&'static str>, - pub task_id: Option, + + /// ID of the task with which the request is associated. This field is optional, since some + /// requests may apply to all tasks, e.g., the request for the HPKE configuration. pub task_id: Option, + + /// The resource with which this request is associated. + pub resource: DapResource, + + /// Request payload. pub payload: Vec, + + /// Request path (i.e., URL). pub url: Url, + + /// Sender authorization, e.g., a bearer token. pub sender_auth: Option, } impl DapRequest { - pub(crate) fn task_id(&self) -> Result<&Id, DapAbort> { + /// Return the task ID, handling a missing ID as a user error. + pub fn task_id(&self) -> Result<&TaskId, DapAbort> { if let Some(ref id) = self.task_id { Ok(id) - } else { - // Handle missing task ID as decoding failure. + } else if self.version == DapVersion::Draft02 { + // draft02: Handle missing task ID as decoding failure. Normally the task ID would be + // encoded in the message payload; it may be missing because parsing failed earlier on + // in the request. Err(DapAbort::UnrecognizedMessage) + } else { + // Handle missing task ID as a bad request. The task ID is normally conveyed by the + // request path; if missing at this point, it is because it was missing or couldn't be + // parsed from the request path. + Err(DapAbort::BadRequest("missing or malformed task ID".into())) + } + } + + /// Return the collection job ID, handling a missing ID as a user error.
+ /// + /// Note: the semantics of this method are only well-defined if the caller is the Collector and + /// the version in use is not draft02. If the caller is not the Collector, or draft02 is in + /// use, we expect the collection job ID to be missing. + pub fn collection_job_id(&self) -> Result<&CollectionJobId, DapAbort> { + if let DapResource::CollectionJob(ref collection_job_id) = self.resource { + Ok(collection_job_id) + } else { + Err(DapAbort::BadRequest( + "missing or malformed collection job ID".into(), + )) } } } @@ -881,7 +980,7 @@ pub struct DapResponse { #[derive(Debug, Deserialize, PartialEq, Eq, Serialize)] #[serde(rename_all = "snake_case")] pub enum DapCollectJob { - Done(CollectResp), + Done(Collection), Pending, Unknown, } @@ -901,6 +1000,64 @@ pub struct DapLeaderProcessTelemetry { pub reports_processed: u64, } +/// draft02 compatibility: A logical aggregation job ID. In draft02, this is a 32-byte +/// string included in the HTTP request payload; in draft04, this is a 16-byte string included in +/// the HTTP request path. This type unifies these into one type so that any protocol logic that +/// is agnostic to these details can use the same object. +pub enum MetaAggregationJobId<'a> { +    Draft02(Cow<'a, Draft02AggregationJobId>), +    Draft04(Cow<'a, AggregationJobId>), +} + +impl MetaAggregationJobId<'_> { + /// Generate a random ID of the type required for the version. + pub(crate) fn gen_for_version(version: &DapVersion) -> Self { + let mut rng = thread_rng(); + match version { + DapVersion::Draft02 => Self::Draft02(Cow::Owned(Draft02AggregationJobId(rng.gen()))), + DapVersion::Draft04 => Self::Draft04(Cow::Owned(AggregationJobId(rng.gen()))), + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + } + } + + /// Convert this aggregation job ID into the type that would be included in the payload of + /// the HTTP request. + pub(crate) fn for_request_payload(&self) -> Option { + match self { + Self::Draft02(agg_job_id) => Some(agg_job_id.clone().into_owned()), + Self::Draft04(..) => None, + } + } + + /// Convert this aggregation job ID into the type that would be included in the HTTP request + /// path. + pub(crate) fn for_request_path(&self) -> DapResource { + match self { + // In draft02, the aggregation job ID is not determined until the payload is parsed. + Self::Draft02(..) => DapResource::Undefined, + Self::Draft04(agg_job_id) => { + DapResource::AggregationJob(agg_job_id.clone().into_owned()) + } + } + } + + /// Convert this aggregation job ID into hex. + pub fn to_hex(&self) -> String { + match self { + Self::Draft02(agg_job_id) => agg_job_id.to_hex(), + Self::Draft04(agg_job_id) => agg_job_id.to_hex(), + } + } + + /// Convert this aggregation job ID into base64url form. + pub fn to_base64url(&self) -> String { + match self { + Self::Draft02(agg_job_id) => agg_job_id.to_base64url(), + Self::Draft04(agg_job_id) => agg_job_id.to_base64url(), + } + } +} + pub mod auth; pub mod constants; pub mod hpke; diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index 7dec03845..d0bcac715 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -35,45 +35,75 @@ const FIXED_SIZE_QUERY_TYPE_CURRENT_BATCH: u8 = 0x01; // Known extension types. const EXTENSION_TASKPROV: u16 = 0xff00; -/// The identifier for a DAP task.
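// Editorial aside, not part of the patch: a minimal sketch of how the Leader uses
// MetaAggregationJobId, assuming it runs inside the daphne crate (gen_for_version and the
// for_request_* helpers are pub(crate)); example_meta_agg_job_id is a made-up name.
fn example_meta_agg_job_id() {
    // draft04: a fresh 16-byte ID that will travel in the request path.
    let agg_job_id = MetaAggregationJobId::gen_for_version(&DapVersion::Draft04);
    assert!(agg_job_id.for_request_payload().is_none());
    assert!(matches!(
        agg_job_id.for_request_path(),
        DapResource::AggregationJob(..)
    ));
    // draft02: a fresh 32-byte ID that will travel in the request payload instead.
    let agg_job_id = MetaAggregationJobId::gen_for_version(&DapVersion::Draft02);
    assert!(agg_job_id.for_request_payload().is_some());
    assert!(matches!(agg_job_id.for_request_path(), DapResource::Undefined));
}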
-#[derive(Clone, Debug, Default, Deserialize, Hash, PartialEq, Eq, Serialize)] -pub struct Id(#[serde(with = "hex")] pub [u8; 32]); +// Serde doesn't support derivations from const generics properly, so we have to use a macro. +macro_rules! id_struct { + ($sname:ident, $len:expr, $doc:expr) => { + #[doc=$doc] + #[derive(Clone, Debug, Default, Deserialize, Hash, PartialEq, Eq, Serialize)] + pub struct $sname(#[serde(with = "hex")] pub [u8; $len]); + + impl $sname { + /// Return the URL-safe, base64 encoding of the ID. + pub fn to_base64url(&self) -> String { + encode_base64url(self.0) + } -impl Id { - /// Return the URL-safe, base64 encoding of the task ID. - pub fn to_base64url(&self) -> String { - encode_base64url(self.0) - } + /// Return the ID encoded as a hex string. + pub fn to_hex(&self) -> String { + hex::encode(self.0) + } - /// Return the ID encoded as a hex string. - pub fn to_hex(&self) -> String { - hex::encode(self.0) - } -} + /// Decode from URL-safe, base64. + pub fn try_from_base64url>(id_base64url: T) -> Option { + Some($sname(decode_base64url(id_base64url.as_ref())?)) + } + } -impl Encode for Id { - fn encode(&self, bytes: &mut Vec) { - bytes.extend_from_slice(&self.0); - } -} + impl Encode for $sname { + fn encode(&self, bytes: &mut Vec) { + bytes.extend_from_slice(&self.0); + } + } -impl Decode for Id { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { - let mut data = [0; 32]; - bytes.read_exact(&mut data[..])?; - Ok(Id(data)) - } -} + impl Decode for $sname { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + let mut data = [0; $len]; + bytes.read_exact(&mut data[..])?; + Ok($sname(data)) + } + } -impl AsRef<[u8]> for Id { - fn as_ref(&self) -> &[u8] { - &self.0 - } + impl AsRef<[u8]> for $sname { + fn as_ref(&self) -> &[u8] { + &self.0 + } + } + + impl fmt::Display for $sname { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_hex()) + } + } + }; } -impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.to_hex()) +id_struct!(AggregationJobId, 16, "Aggregation Job ID"); +id_struct!(BatchId, 32, "Batch ID"); +id_struct!(CollectionJobId, 16, "Collection Job ID"); +id_struct!(Draft02AggregationJobId, 32, "Aggregation Job ID"); +id_struct!(ReportId, 16, "Report ID (draft02)"); +id_struct!(TaskId, 32, "Task ID"); + +impl TaskId { + /// draft02 compatibility: Convert the task ID to the field that would be added to the DAP + /// request for the given version. In draft02, the task ID is generally included in the HTTP + /// request payload; in draft04, the task ID is included in the HTTP request path. + pub fn for_request_payload(&self, version: &DapVersion) -> Option { + match version { + DapVersion::Draft02 => Some(self.clone()), + DapVersion::Draft04 => None, + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + } } } @@ -83,44 +113,6 @@ pub type Duration = u64; /// The timestamp sent in a [`Report`]. pub type Time = u64; -/// A report ID. -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Hash, Serialize)] -#[allow(missing_docs)] -pub struct ReportId(pub [u8; 16]); - -impl ReportId { - /// Return the ID encoded as a hex string. 
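// Editorial aside, not part of the patch: every type generated by id_struct! above exposes the
// same small API; a sketch, assuming the daphne::messages types and the hex crate are in scope
// (example_id_roundtrip is a made-up name):
fn example_id_roundtrip() {
    let task_id = TaskId([7; 32]);
    assert_eq!(
        TaskId::try_from_base64url(task_id.to_base64url()),
        Some(task_id.clone())
    );
    // The same methods exist on BatchId, AggregationJobId, CollectionJobId, and so on.
    let batch_id = BatchId([7; 32]);
    assert_eq!(batch_id.to_hex(), hex::encode([7u8; 32]));
}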
- pub fn to_hex(&self) -> String { - hex::encode(self.0) - } -} - -impl Encode for ReportId { - fn encode(&self, bytes: &mut Vec) { - bytes.extend_from_slice(&self.0); - } -} - -impl Decode for ReportId { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { - let mut id = [0; 16]; - bytes.read_exact(&mut id)?; - Ok(ReportId(id)) - } -} - -impl AsRef<[u8]> for ReportId { - fn as_ref(&self) -> &[u8] { - &self.0 - } -} - -impl fmt::Display for ReportId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.to_hex()) - } -} - /// Report extensions. #[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] #[serde(rename_all = "snake_case")] @@ -172,7 +164,7 @@ pub struct ReportMetadata { pub id: ReportId, pub time: Time, - /// Report extensions, only used in draft-02. In draft-03 and above, extensions are carried in encrypted input share. + /// Report extensions, only used in draft02. In draft-03 and above, extensions are carried in encrypted input share. pub extensions: Vec, } @@ -183,7 +175,7 @@ impl ParameterizedEncode for ReportMetadata { if matches!(version, DapVersion::Draft02) { encode_u16_items(bytes, &(), &self.extensions); } else if !self.extensions.is_empty() { - panic!("tried to encode extensions in the ReportMetadata for DAP > draft-02") + panic!("tried to encode extensions in the ReportMetadata for DAP > draft02") } } } @@ -220,16 +212,22 @@ impl ParameterizedDecode for ReportMetadata { #[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] #[allow(missing_docs)] pub struct Report { - pub task_id: Id, - pub metadata: ReportMetadata, + pub draft02_task_id: Option, // Set in draft02 + pub report_metadata: ReportMetadata, pub public_share: Vec, pub encrypted_input_shares: Vec, } impl ParameterizedEncode for Report { fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { - self.task_id.encode(bytes); - self.metadata.encode_with_param(version, bytes); + if *version == DapVersion::Draft02 { + if let Some(id) = &self.draft02_task_id { + id.encode(bytes); + } else { + unreachable!("draft02: tried to serialize Report with missing task ID"); + } + } + self.report_metadata.encode_with_param(version, bytes); encode_u32_bytes(bytes, &self.public_share); encode_u32_items(bytes, &(), &self.encrypted_input_shares); } @@ -240,28 +238,33 @@ impl ParameterizedDecode for Report { version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { + let draft02_task_id = if *version == DapVersion::Draft02 { + Some(TaskId::decode(bytes)?) + } else { + None + }; Ok(Self { - task_id: Id::decode(bytes)?, - metadata: ReportMetadata::decode_with_param(version, bytes)?, + draft02_task_id, + report_metadata: ReportMetadata::decode_with_param(version, bytes)?, public_share: decode_u32_bytes(bytes)?, encrypted_input_shares: decode_u32_items(&(), bytes)?, }) } } -/// An initial aggregate sub-request sent in an [`AggregateInitializeReq`]. The contents of this +/// An initial aggregate sub-request sent in an [`AggregationJobInitReq`]. The contents of this /// structure pertain to a single report. 
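// Editorial aside, not part of the patch: a sketch of the per-version split for Report,
// mirroring the round-trip tests further down; assumes the prio codec traits are in scope and
// example_report_roundtrip is a made-up name.
fn example_report_roundtrip(version: DapVersion) {
    let report = Report {
        // Only set for draft02; in draft04 the task ID travels in the request path.
        draft02_task_id: if version == DapVersion::Draft02 {
            Some(TaskId([1; 32]))
        } else {
            None
        },
        report_metadata: ReportMetadata {
            id: ReportId([23; 16]),
            time: 1_637_364_244,
            extensions: vec![],
        },
        public_share: b"public share".to_vec(),
        encrypted_input_shares: vec![],
    };
    let encoded = report.get_encoded_with_param(&version);
    assert_eq!(
        Report::get_decoded_with_param(&version, &encoded).unwrap(),
        report
    );
}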
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] #[allow(missing_docs)] pub struct ReportShare { - pub metadata: ReportMetadata, + pub report_metadata: ReportMetadata, pub public_share: Vec, pub encrypted_input_share: HpkeCiphertext, } impl ParameterizedEncode for ReportShare { fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { - self.metadata.encode_with_param(version, bytes); + self.report_metadata.encode_with_param(version, bytes); encode_u32_bytes(bytes, &self.public_share); self.encrypted_input_share.encode(bytes); } @@ -273,7 +276,7 @@ impl ParameterizedDecode for ReportShare { bytes: &mut Cursor<&[u8]>, ) -> Result { Ok(Self { - metadata: ReportMetadata::decode_with_param(version, bytes)?, + report_metadata: ReportMetadata::decode_with_param(version, bytes)?, public_share: decode_u32_bytes(bytes)?, encrypted_input_share: HpkeCiphertext::decode(bytes)?, }) @@ -281,12 +284,12 @@ impl ParameterizedDecode for ReportShare { } /// Batch parameter conveyed to the Helper by the Leader in the aggregation sub-protocol. Used to -/// identify which batch the reports in the [`AggregateInitializeReq`] are intended for. +/// identify which batch the reports in the [`AggregationJobInitReq`] are intended for. #[derive(Clone, Debug, Eq, Deserialize, Hash, PartialEq, Serialize)] #[serde(rename_all = "snake_case")] pub enum PartialBatchSelector { TimeInterval, - FixedSizeByBatchId { batch_id: Id }, + FixedSizeByBatchId { batch_id: BatchId }, } impl From for PartialBatchSelector { @@ -315,7 +318,7 @@ impl Decode for PartialBatchSelector { match u8::decode(bytes)? { QUERY_TYPE_TIME_INTERVAL => Ok(Self::TimeInterval), QUERY_TYPE_FIXED_SIZE => Ok(Self::FixedSizeByBatchId { - batch_id: Id::decode(bytes)?, + batch_id: BatchId::decode(bytes)?, }), _ => Err(CodecError::UnexpectedValue), } @@ -327,7 +330,7 @@ impl Decode for PartialBatchSelector { #[serde(rename_all = "snake_case")] pub enum BatchSelector { TimeInterval { batch_interval: Interval }, - FixedSizeByBatchId { batch_id: Id }, + FixedSizeByBatchId { batch_id: BatchId }, } impl Encode for BatchSelector { @@ -352,7 +355,7 @@ impl Decode for BatchSelector { batch_interval: Interval::decode(bytes)?, }), QUERY_TYPE_FIXED_SIZE => Ok(Self::FixedSizeByBatchId { - batch_id: Id::decode(bytes)?, + batch_id: BatchId::decode(bytes)?, }), _ => Err(CodecError::UnexpectedValue), } @@ -383,41 +386,55 @@ impl TryFrom for BatchSelector { /// Aggregate initialization request. 
#[derive(Clone, Debug, PartialEq, Eq)] -pub struct AggregateInitializeReq { - pub task_id: Id, - pub agg_job_id: Id, +pub struct AggregationJobInitReq { + pub draft02_task_id: Option, // Set in draft02 + pub draft02_agg_job_id: Option, // Set in draft02 pub agg_param: Vec, pub part_batch_sel: PartialBatchSelector, pub report_shares: Vec, } -impl ParameterizedEncode for AggregateInitializeReq { +impl ParameterizedEncode for AggregationJobInitReq { fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { - self.task_id.encode(bytes); - self.agg_job_id.encode(bytes); match version { - DapVersion::Draft02 => encode_u16_bytes(bytes, &self.agg_param), - DapVersion::Draft03 => encode_u32_bytes(bytes, &self.agg_param), - _ => unreachable!("unimplemented version"), + DapVersion::Draft02 => { + self.draft02_task_id + .as_ref() + .expect("draft02: missing task ID") + .encode(bytes); + self.draft02_agg_job_id + .as_ref() + .expect("draft02: missing aggregation job ID") + .encode(bytes); + encode_u16_bytes(bytes, &self.agg_param); + } + DapVersion::Draft04 => encode_u32_bytes(bytes, &self.agg_param), + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), }; self.part_batch_sel.encode(bytes); encode_u32_items(bytes, version, &self.report_shares); } } -impl ParameterizedDecode for AggregateInitializeReq { +impl ParameterizedDecode for AggregationJobInitReq { fn decode_with_param( version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { + let (draft02_task_id, draft02_agg_job_id, agg_param) = match version { + DapVersion::Draft02 => ( + Some(TaskId::decode(bytes)?), + Some(Draft02AggregationJobId::decode(bytes)?), + decode_u16_bytes(bytes)?, + ), + DapVersion::Draft04 => (None, None, decode_u32_bytes(bytes)?), + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + }; + Ok(Self { - task_id: Id::decode(bytes)?, - agg_job_id: Id::decode(bytes)?, - agg_param: match version { - DapVersion::Draft02 => decode_u16_bytes(bytes)?, - DapVersion::Draft03 => decode_u32_bytes(bytes)?, - _ => unreachable!("unimplemented version"), - }, + draft02_task_id, + draft02_agg_job_id, + agg_param, part_batch_sel: PartialBatchSelector::decode(bytes)?, report_shares: decode_u32_items(version, bytes)?, }) @@ -426,25 +443,56 @@ impl ParameterizedDecode for AggregateInitializeReq { /// Aggregate continuation request. 
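// Editorial aside, not part of the patch: for draft04 both draft02_* fields are simply left as
// None, since the task ID and aggregation job ID travel in the request path; the draft04 encoder
// above never writes them, while the draft02 encoder panics (via expect) if either is missing.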
#[derive(Clone, Debug, PartialEq, Eq)] -pub struct AggregateContinueReq { - pub task_id: Id, - pub agg_job_id: Id, +pub struct AggregationJobContinueReq { + pub draft02_task_id: Option, // Set in draft02 + pub draft02_agg_job_id: Option, // Set in draft02 + pub round: Option, // Not set in draft02 pub transitions: Vec, } -impl Encode for AggregateContinueReq { - fn encode(&self, bytes: &mut Vec) { - self.task_id.encode(bytes); - self.agg_job_id.encode(bytes); +impl ParameterizedEncode for AggregationJobContinueReq { + fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { + match version { + DapVersion::Draft02 => { + self.draft02_task_id + .as_ref() + .expect("draft02: missing task ID") + .encode(bytes); + self.draft02_agg_job_id + .as_ref() + .expect("draft02: missing aggregation job ID") + .encode(bytes); + } + DapVersion::Draft04 => { + self.round + .as_ref() + .expect("draft04: missing round") + .encode(bytes); + } + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + }; encode_u32_items(bytes, &(), &self.transitions); } } -impl Decode for AggregateContinueReq { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { +impl ParameterizedDecode for AggregationJobContinueReq { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let (draft02_task_id, draft02_agg_job_id, round) = match version { + DapVersion::Draft02 => ( + Some(TaskId::decode(bytes)?), + Some(Draft02AggregationJobId::decode(bytes)?), + None, + ), + DapVersion::Draft04 => (None, None, Some(u16::decode(bytes)?)), + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + }; Ok(Self { - task_id: Id::decode(bytes)?, - agg_job_id: Id::decode(bytes)?, + draft02_task_id, + draft02_agg_job_id, + round, transitions: decode_u32_items(&(), bytes)?, }) } @@ -453,8 +501,7 @@ impl Decode for AggregateContinueReq { /// Transition message. This conveys a message sent from one Aggregator to another during the /// preparation phase of VDAF evaluation. // -// TODO spec: This is called `PrepareStep` in draft-ietf-ppm-dap-03. This is confusing because it -// overloads a term used in draft-irtf-cfrg-draft-02. +// TODO Consider renaming this to `PrepareStep` to align with draft04. #[derive(Clone, Debug, PartialEq, Eq)] pub struct Transition { pub report_id: ReportId, @@ -582,17 +629,17 @@ impl std::fmt::Display for TransitionFailure { /// An aggregate response sent from the Helper to the Leader.
#[derive(Debug, PartialEq, Eq, Default)] #[allow(missing_docs)] -pub struct AggregateResp { +pub struct AggregationJobResp { pub transitions: Vec, } -impl Encode for AggregateResp { +impl Encode for AggregationJobResp { fn encode(&self, bytes: &mut Vec) { encode_u32_items(bytes, &(), &self.transitions); } } -impl Decode for AggregateResp { +impl Decode for AggregationJobResp { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { Ok(Self { transitions: decode_u32_items(&(), bytes)?, @@ -636,7 +683,7 @@ impl Decode for Interval { #[serde(rename_all = "snake_case")] pub enum Query { TimeInterval { batch_interval: Interval }, - FixedSizeByBatchId { batch_id: Id }, + FixedSizeByBatchId { batch_id: BatchId }, FixedSizeCurrentBatch, } @@ -677,13 +724,13 @@ impl ParameterizedDecode for Query { QUERY_TYPE_FIXED_SIZE => { if *decoding_parameter == DapVersion::Draft02 { Ok(Self::FixedSizeByBatchId { - batch_id: Id::decode(bytes)?, + batch_id: BatchId::decode(bytes)?, }) } else { let subtype = u8::decode(bytes)?; match subtype { FIXED_SIZE_QUERY_TYPE_BY_BATCH_ID => Ok(Self::FixedSizeByBatchId { - batch_id: Id::decode(bytes)?, + batch_id: BatchId::decode(bytes)?, }), FIXED_SIZE_QUERY_TYPE_CURRENT_BATCH => Ok(Self::FixedSizeCurrentBatch), _ => Err(CodecError::UnexpectedValue), @@ -707,35 +754,49 @@ impl Default for Query { // // TODO Add serialization tests. #[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] -pub struct CollectReq { - pub task_id: Id, +pub struct CollectionReq { + pub draft02_task_id: Option, // Set in draft02 pub query: Query, pub agg_param: Vec, } -impl ParameterizedEncode for CollectReq { +impl ParameterizedEncode for CollectionReq { fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { - self.task_id.encode(bytes); + match version { + DapVersion::Draft02 => { + self.draft02_task_id + .as_ref() + .expect("draft02: missing task ID") + .encode(bytes); + } + DapVersion::Draft04 => {} + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + } self.query.encode_with_param(version, bytes); match version { DapVersion::Draft02 => encode_u16_bytes(bytes, &self.agg_param), - DapVersion::Draft03 => encode_u32_bytes(bytes, &self.agg_param), + DapVersion::Draft04 => encode_u32_bytes(bytes, &self.agg_param), _ => panic!("unimplemented DapVersion"), }; } } -impl ParameterizedDecode for CollectReq { +impl ParameterizedDecode for CollectionReq { fn decode_with_param( - decoding_parameter: &DapVersion, + version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { + let draft02_task_id = match version { + DapVersion::Draft02 => Some(TaskId::decode(bytes)?), + DapVersion::Draft04 => None, + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + }; Ok(Self { - task_id: Id::decode(bytes)?, - query: Query::decode_with_param(decoding_parameter, bytes)?, - agg_param: match decoding_parameter { + draft02_task_id, + query: Query::decode_with_param(version, bytes)?, + agg_param: match version { DapVersion::Draft02 => decode_u16_bytes(bytes)?, - DapVersion::Draft03 => decode_u32_bytes(bytes)?, + DapVersion::Draft04 => decode_u32_bytes(bytes)?, _ => panic!("unimplemented DapVersion"), }, }) @@ -746,25 +807,44 @@ impl ParameterizedDecode for CollectReq { // // TODO Add serialization tests. 
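// Editorial aside, not part of the patch: a sketch of a draft04 collection request for the
// "current batch" of a fixed-size task; no task ID appears in the payload (it travels in the
// request path) and the aggregation parameter gets a u32 length prefix rather than draft02's
// u16 (example_collection_req is a made-up name).
fn example_collection_req() -> Vec<u8> {
    let req = CollectionReq {
        draft02_task_id: None,
        query: Query::FixedSizeCurrentBatch,
        agg_param: Vec::new(),
    };
    req.get_encoded_with_param(&DapVersion::Draft04)
}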
#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] -pub struct CollectResp { +pub struct Collection { pub part_batch_sel: PartialBatchSelector, pub report_count: u64, + pub interval: Option, // Not set in draft02 pub encrypted_agg_shares: Vec, } -impl Encode for CollectResp { - fn encode(&self, bytes: &mut Vec) { +impl ParameterizedEncode for Collection { + fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { self.part_batch_sel.encode(bytes); self.report_count.encode(bytes); + match version { + DapVersion::Draft02 => {} + DapVersion::Draft04 => { + self.interval + .as_ref() + .expect("draft04: missing interval") + .encode(bytes); + } + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + }; encode_u32_items(bytes, &(), &self.encrypted_agg_shares); } } -impl Decode for CollectResp { - fn decode(bytes: &mut Cursor<&[u8]>) -> Result { +impl ParameterizedDecode for Collection { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { Ok(Self { part_batch_sel: PartialBatchSelector::decode(bytes)?, report_count: u64::decode(bytes)?, + interval: match version { + DapVersion::Draft02 => None, + DapVersion::Draft04 => Some(Interval::decode(bytes)?), + _ => panic!("unimplemented DapVersion"), + }, encrypted_agg_shares: decode_u32_items(&(), bytes)?, }) } @@ -775,7 +855,7 @@ impl Decode for CollectResp { // TODO Add serialization tests. #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct AggregateShareReq { - pub task_id: Id, + pub draft02_task_id: Option, // Set in draft02 pub batch_sel: BatchSelector, pub agg_param: Vec, pub report_count: u64, @@ -784,12 +864,20 @@ pub struct AggregateShareReq { impl ParameterizedEncode for AggregateShareReq { fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { - self.task_id.encode(bytes); - self.batch_sel.encode_with_param(version, bytes); match version { - DapVersion::Draft02 => encode_u16_bytes(bytes, &self.agg_param), - DapVersion::Draft03 => encode_u32_bytes(bytes, &self.agg_param), - _ => panic!("unimplemented DapVersion"), + DapVersion::Draft02 => { + self.draft02_task_id + .as_ref() + .expect("draft02: missing task ID") + .encode(bytes); + self.batch_sel.encode_with_param(version, bytes); + encode_u16_bytes(bytes, &self.agg_param); + } + DapVersion::Draft04 => { + self.batch_sel.encode_with_param(version, bytes); + encode_u32_bytes(bytes, &self.agg_param); + } + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), }; self.report_count.encode(bytes); bytes.extend_from_slice(&self.checksum); @@ -798,17 +886,26 @@ impl ParameterizedEncode for AggregateShareReq { impl ParameterizedDecode for AggregateShareReq { fn decode_with_param( - decoding_parameter: &DapVersion, + version: &DapVersion, bytes: &mut Cursor<&[u8]>, ) -> Result { + let (draft02_task_id, batch_sel, agg_param) = match version { + DapVersion::Draft02 => ( + Some(TaskId::decode(bytes)?), + BatchSelector::decode_with_param(version, bytes)?, + decode_u16_bytes(bytes)?, + ), + DapVersion::Draft04 => ( + None, + BatchSelector::decode_with_param(version, bytes)?, + decode_u32_bytes(bytes)?, + ), + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + }; Ok(Self { - task_id: Id::decode(bytes)?, - batch_sel: BatchSelector::decode_with_param(decoding_parameter, bytes)?, - agg_param: match decoding_parameter { - DapVersion::Draft02 => decode_u16_bytes(bytes)?, - DapVersion::Draft03 => decode_u32_bytes(bytes)?, - _ => panic!("unimplemented DapVersion"), - }, + 
draft02_task_id, + batch_sel, + agg_param, report_count: u64::decode(bytes)?, checksum: { let mut checksum = [0u8; 32]; @@ -823,17 +920,17 @@ impl ParameterizedDecode for AggregateShareReq { // // TODO Add serialization tests. #[derive(Debug)] -pub struct AggregateShareResp { +pub struct AggregateShare { pub encrypted_agg_share: HpkeCiphertext, } -impl Encode for AggregateShareResp { +impl Encode for AggregateShare { fn encode(&self, bytes: &mut Vec) { self.encrypted_agg_share.encode(bytes); } } -impl Decode for AggregateShareResp { +impl Decode for AggregateShare { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { Ok(Self { encrypted_agg_share: HpkeCiphertext::decode(bytes)?, @@ -1112,7 +1209,9 @@ pub fn encode_base64url>(input: T) -> String { } /// Decode the input as a URL-safe, base64 encoding of an `OUT_LEN`-length byte string. -pub fn decode_base64url, const OUT_LEN: usize>(input: T) -> Option<[u8; OUT_LEN]> { +pub(crate) fn decode_base64url, const OUT_LEN: usize>( + input: T, +) -> Option<[u8; OUT_LEN]> { let mut bytes = [0; OUT_LEN]; // NOTE(cjpatton) It would be better to use `decode_slice` here, but this function uses a // conservative estimate of the decoded length (`decoded_len_estimate`). See diff --git a/daphne/src/messages/mod_test.rs b/daphne/src/messages/mod_test.rs index b76c64e72..38fdaff08 100644 --- a/daphne/src/messages/mod_test.rs +++ b/daphne/src/messages/mod_test.rs @@ -5,10 +5,11 @@ use crate::messages::taskprov::{ DpConfig, QueryConfig, QueryConfigVar, TaskConfig, UrlBytes, VdafConfig, VdafTypeVar, }; use crate::messages::{ - decode_base64url, decode_base64url_vec, encode_base64url, AggregateContinueReq, - AggregateInitializeReq, AggregateResp, AggregateShareReq, BatchSelector, DapVersion, Extension, - HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeKdfId, HpkeKemId, Id, PartialBatchSelector, Report, - ReportId, ReportMetadata, ReportShare, Transition, TransitionVar, + decode_base64url, decode_base64url_vec, encode_base64url, AggregateShareReq, + AggregationJobContinueReq, AggregationJobId, AggregationJobInitReq, AggregationJobResp, + BatchId, BatchSelector, CollectionJobId, DapVersion, Draft02AggregationJobId, Extension, + HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeKdfId, HpkeKemId, PartialBatchSelector, Report, + ReportId, ReportMetadata, ReportShare, TaskId, Transition, TransitionVar, }; use crate::taskprov::{compute_task_id, TaskprovVersion}; use crate::{test_version, test_versions}; @@ -17,13 +18,18 @@ use paste::paste; use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode}; use rand::prelude::*; +fn task_id_for_version(version: DapVersion) -> Option { + if version == DapVersion::Draft02 { + Some(TaskId([1; 32])) + } else { + None + } +} + fn read_report(version: DapVersion) { let report = Report { - task_id: Id([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, - ]), - metadata: ReportMetadata { + draft02_task_id: task_id_for_version(version), + report_metadata: ReportMetadata { id: ReportId([23; 16]), time: 1637364244, extensions: vec![], @@ -53,11 +59,8 @@ test_versions! 
{read_report} #[test] fn read_report_with_unknown_extensions_draft02() { let report = Report { - task_id: Id([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, 16, - ]), - metadata: ReportMetadata { + draft02_task_id: task_id_for_version(DapVersion::Draft02), + report_metadata: ReportMetadata { id: ReportId([23; 16]), time: 1637364244, extensions: vec![Extension::Unhandled { @@ -86,17 +89,17 @@ fn read_report_with_unknown_extensions_draft02() { } #[test] -fn read_agg_init_req() { - let want = AggregateInitializeReq { - task_id: Id([23; 32]), - agg_job_id: Id([1; 32]), +fn read_agg_job_init_req() { + let want = AggregationJobInitReq { + draft02_task_id: Some(TaskId([23; 32])), + draft02_agg_job_id: Some(Draft02AggregationJobId([1; 32])), agg_param: b"this is an aggregation parameter".to_vec(), part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { - batch_id: Id([0; 32]), + batch_id: BatchId([0; 32]), }, report_shares: vec![ ReportShare { - metadata: ReportMetadata { + report_metadata: ReportMetadata { id: ReportId([99; 16]), time: 1637361337, extensions: Vec::default(), @@ -109,7 +112,7 @@ fn read_agg_init_req() { }, }, ReportShare { - metadata: ReportMetadata { + report_metadata: ReportMetadata { id: ReportId([17; 16]), time: 163736423, extensions: Vec::default(), @@ -124,25 +127,89 @@ fn read_agg_init_req() { ], }; - let got = AggregateInitializeReq::get_decoded_with_param( + let got = AggregationJobInitReq::get_decoded_with_param( &crate::DapVersion::Draft02, &want.get_encoded_with_param(&crate::DapVersion::Draft02), ) .unwrap(); assert_eq!(got, want); - let got = AggregateInitializeReq::get_decoded_with_param( - &crate::DapVersion::Draft03, - &want.get_encoded_with_param(&crate::DapVersion::Draft03), + + let want = AggregationJobInitReq { + draft02_task_id: None, + draft02_agg_job_id: None, + agg_param: b"this is an aggregation parameter".to_vec(), + part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { + batch_id: BatchId([0; 32]), + }, + report_shares: vec![ + ReportShare { + report_metadata: ReportMetadata { + id: ReportId([99; 16]), + time: 1637361337, + extensions: Vec::default(), + }, + public_share: b"public share".to_vec(), + encrypted_input_share: HpkeCiphertext { + config_id: 23, + enc: b"encapsulated key".to_vec(), + payload: b"ciphertext".to_vec(), + }, + }, + ReportShare { + report_metadata: ReportMetadata { + id: ReportId([17; 16]), + time: 163736423, + extensions: Vec::default(), + }, + public_share: b"public share".to_vec(), + encrypted_input_share: HpkeCiphertext { + config_id: 0, + enc: vec![], + payload: b"ciphertext".to_vec(), + }, + }, + ], + }; + + let got = AggregationJobInitReq::get_decoded_with_param( + &DapVersion::Draft04, + &want.get_encoded_with_param(&DapVersion::Draft04), ) .unwrap(); assert_eq!(got, want); } #[test] -fn read_agg_cont_req() { - let want = AggregateContinueReq { - task_id: Id([23; 32]), - agg_job_id: Id([1; 32]), +fn read_agg_job_cont_req() { + let want = AggregationJobContinueReq { + draft02_task_id: Some(TaskId([23; 32])), + draft02_agg_job_id: Some(Draft02AggregationJobId([1; 32])), + round: None, + transitions: vec![ + Transition { + report_id: ReportId([0; 16]), + var: TransitionVar::Continued(b"this is a VDAF-specific message".to_vec()), + }, + Transition { + report_id: ReportId([1; 16]), + var: TransitionVar::Continued( + b"believe it or not this is *also* a VDAF-specific message".to_vec(), + ), + }, + ], + }; + + let got = 
AggregationJobContinueReq::get_decoded_with_param( + &DapVersion::Draft02, + &want.get_encoded_with_param(&DapVersion::Draft02), + ) + .unwrap(); + assert_eq!(got, want); + + let want = AggregationJobContinueReq { + draft02_task_id: None, + draft02_agg_job_id: None, + round: Some(1), transitions: vec![ Transition { report_id: ReportId([0; 16]), @@ -157,16 +224,20 @@ fn read_agg_cont_req() { ], }; - let got = AggregateContinueReq::get_decoded(&want.get_encoded()).unwrap(); + let got = AggregationJobContinueReq::get_decoded_with_param( + &DapVersion::Draft04, + &want.get_encoded_with_param(&DapVersion::Draft04), + ) + .unwrap(); assert_eq!(got, want); } #[test] fn read_agg_share_req() { let want = AggregateShareReq { - task_id: Id([23; 32]), + draft02_task_id: Some(TaskId([23; 32])), batch_sel: BatchSelector::FixedSizeByBatchId { - batch_id: Id([23; 32]), + batch_id: BatchId([23; 32]), }, agg_param: b"this is an aggregation parameter".to_vec(), report_count: 100, @@ -179,17 +250,27 @@ fn read_agg_share_req() { ) .unwrap(); assert_eq!(got, want); + + let want = AggregateShareReq { + draft02_task_id: None, + batch_sel: BatchSelector::FixedSizeByBatchId { + batch_id: BatchId([23; 32]), + }, + agg_param: b"this is an aggregation parameter".to_vec(), + report_count: 100, + checksum: [0; 32], + }; let got = AggregateShareReq::get_decoded_with_param( - &DapVersion::Draft03, - &want.get_encoded_with_param(&DapVersion::Draft03), + &DapVersion::Draft04, + &want.get_encoded_with_param(&DapVersion::Draft04), ) .unwrap(); assert_eq!(got, want); } #[test] -fn read_agg_resp() { - let want = AggregateResp { +fn read_agg_job_resp() { + let want = AggregationJobResp { transitions: vec![ Transition { report_id: ReportId([22; 16]), @@ -204,7 +285,7 @@ fn read_agg_resp() { ], }; - let got = AggregateResp::get_decoded(&want.get_encoded()).unwrap(); + let got = AggregationJobResp::get_decoded(&want.get_encoded()).unwrap(); assert_eq!(got, want); } @@ -326,3 +407,33 @@ fn test_base64url() { assert_eq!(decode_base64url(encode_base64url(id)).unwrap(), id); assert_eq!(decode_base64url_vec(encode_base64url(id)).unwrap(), id); } + +#[test] +fn roundtrip_id_base64url() { + let id = AggregationJobId([7; 16]); + assert_eq!( + AggregationJobId::try_from_base64url(id.to_base64url()).unwrap(), + id + ); + + let id = BatchId([7; 32]); + assert_eq!(BatchId::try_from_base64url(id.to_base64url()).unwrap(), id); + + let id = CollectionJobId([7; 16]); + assert_eq!( + CollectionJobId::try_from_base64url(id.to_base64url()).unwrap(), + id + ); + + let id = Draft02AggregationJobId([13; 32]); + assert_eq!( + Draft02AggregationJobId::try_from_base64url(id.to_base64url()).unwrap(), + id + ); + + let id = ReportId([7; 16]); + assert_eq!(ReportId::try_from_base64url(id.to_base64url()).unwrap(), id); + + let id = TaskId([7; 32]); + assert_eq!(TaskId::try_from_base64url(id.to_base64url()).unwrap(), id); +} diff --git a/daphne/src/roles.rs b/daphne/src/roles.rs index 1dbeba098..a2fe260e1 100644 --- a/daphne/src/roles.rs +++ b/daphne/src/roles.rs @@ -5,25 +5,27 @@ use crate::{ constants::{ + versioned_media_type_for, DRAFT02_MEDIA_TYPE_AGG_CONT_REQ, DRAFT02_MEDIA_TYPE_AGG_INIT_REQ, DRAFT02_MEDIA_TYPE_HPKE_CONFIG, MEDIA_TYPE_AGG_CONT_REQ, MEDIA_TYPE_AGG_CONT_RESP, MEDIA_TYPE_AGG_INIT_REQ, MEDIA_TYPE_AGG_INIT_RESP, MEDIA_TYPE_AGG_SHARE_REQ, MEDIA_TYPE_AGG_SHARE_RESP, MEDIA_TYPE_HPKE_CONFIG_LIST, }, hpke::HpkeDecrypter, messages::{ - constant_time_eq, decode_base64url, AggregateContinueReq, AggregateInitializeReq, - AggregateResp, 
AggregateShareReq, AggregateShareResp, BatchSelector, CollectReq, - CollectResp, HpkeConfigList, Id, PartialBatchSelector, Query, Report, ReportId, - ReportMetadata, Time, TransitionFailure, TransitionVar, + constant_time_eq, decode_base64url, AggregateShare, AggregateShareReq, + AggregationJobContinueReq, AggregationJobInitReq, AggregationJobResp, BatchId, + BatchSelector, Collection, CollectionJobId, CollectionReq, HpkeConfigList, Interval, + PartialBatchSelector, Query, Report, ReportId, ReportMetadata, TaskId, Time, + TransitionFailure, TransitionVar, }, metrics::DaphneMetrics, DapAbort, DapAggregateShare, DapCollectJob, DapError, DapGlobalConfig, DapHelperState, DapHelperTransition, DapLeaderProcessTelemetry, DapLeaderTransition, DapOutputShare, - DapQueryConfig, DapRequest, DapResponse, DapTaskConfig, DapVersion, + DapQueryConfig, DapRequest, DapResource, DapResponse, DapTaskConfig, DapVersion, + MetaAggregationJobId, }; use async_trait::async_trait; use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode}; -use rand::prelude::*; use std::borrow::Cow; use std::collections::HashMap; use tracing::debug; @@ -35,7 +37,7 @@ pub trait DapAuthorizedSender { /// Add authorization to an outbound DAP request with the given task ID, media type, and payload. async fn authorize( &self, - task_id: &Id, + task_id: &TaskId, media_type: &'static str, payload: &[u8], ) -> Result; @@ -74,14 +76,14 @@ where async fn get_task_config_considering_taskprov( &'srv self, version: DapVersion, - task_id: Cow<'req, Id>, + task_id: Cow<'req, TaskId>, report: Option<&ReportMetadata>, ) -> Result, DapError>; /// Look up the DAP task configuration for the given task ID. async fn get_task_config_for( &'srv self, - task_id: Cow<'req, Id>, + task_id: Cow<'req, TaskId>, ) -> Result, DapError> { // We use DapVersion::Unknown here as we don't know it and we don't need to // know it as we will not be doing any taskprov task creation. @@ -96,18 +98,18 @@ where /// batch. async fn is_batch_overlapping( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result; /// Check whether the given batch ID has been observed before. This is called by the Leader /// (resp. Helper) in response to a CollectReq (resp. AggregateShareReq) for fixed-size tasks. - async fn batch_exists(&self, task_id: &Id, batch_id: &Id) -> Result; + async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result; /// Store a set of output shares. async fn put_out_shares( &self, - task_id: &Id, + task_id: &TaskId, part_batch_sel: &PartialBatchSelector, out_shares: Vec, ) -> Result<(), DapError>; @@ -115,7 +117,7 @@ where /// Fetch the aggregate share for the given batch. async fn get_agg_share( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result; @@ -124,14 +126,17 @@ where /// report being collected, etc. async fn check_early_reject<'b>( &self, - task_id: &Id, + task_id: &TaskId, part_batch_sel: &'b PartialBatchSelector, report_meta: impl Iterator, ) -> Result, DapError>; /// Mark a batch as collected. - async fn mark_collected(&self, task_id: &Id, batch_sel: &BatchSelector) - -> Result<(), DapError>; + async fn mark_collected( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<(), DapError>; /// Handle HTTP GET to `/hpke_config?task_id=`. 
async fn http_get_hpke_config( @@ -154,7 +159,7 @@ where "failed to parse query parameter as URL-safe Base64".into(), ))?; - id = Some(Id(bytes)) + id = Some(TaskId(bytes)) } let hpke_config = self.get_hpke_config_for(req.version, id.as_ref()).await?; @@ -176,7 +181,7 @@ where media_type: Some(DRAFT02_MEDIA_TYPE_HPKE_CONFIG), payload: hpke_config.as_ref().get_encoded(), }), - DapVersion::Draft03 => { + DapVersion::Draft04 => { let hpke_config_list = HpkeConfigList { hpke_configs: vec![hpke_config.as_ref().clone()], }; @@ -191,7 +196,7 @@ where } } - async fn current_batch(&self, task_id: &Id) -> std::result::Result; + async fn current_batch(&self, task_id: &TaskId) -> Result; /// Access the Prometheus metrics. fn metrics(&self) -> &DaphneMetrics; @@ -204,7 +209,9 @@ macro_rules! leader_post { $task_config:expr, $path:expr, $media_type:expr, - $req_data:expr + $resource:expr, + $req_data:expr, + $is_put:expr ) => {{ let url = $task_config .helper_url @@ -214,11 +221,16 @@ macro_rules! leader_post { version: $task_config.version.clone(), media_type: Some($media_type), task_id: Some($task_id.clone()), + resource: $resource, payload: $req_data, url, sender_auth: Some($role.authorize(&$task_id, $media_type, &$req_data).await?), }; - $role.send_http_post(req).await? + if $is_put { + $role.send_http_put(req).await? + } else { + $role.send_http_post(req).await? + } }}; } @@ -232,42 +244,52 @@ where type ReportSelector; /// Store a report for use later on. - async fn put_report(&self, report: &Report) -> Result<(), DapError>; + async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError>; /// Fetch a sequence of reports to aggregate, grouped by task ID, then by partial batch /// selector. The reports returned are removed from persistent storage. async fn get_reports( &self, selector: &Self::ReportSelector, - ) -> Result>>, DapError>; + ) -> Result>>, DapError>; /// Create a collect job. // // TODO spec: Figure out if the hostname for the collect URI needs to match the Leader. - async fn init_collect_job(&self, collect_req: &CollectReq) -> Result; + async fn init_collect_job( + &self, + task_id: &TaskId, + collect_job_id: &Option, + collect_req: &CollectionReq, + ) -> Result; /// Check the status of a collect job. async fn poll_collect_job( &self, - task_id: &Id, - collect_id: &Id, + task_id: &TaskId, + collect_id: &CollectionJobId, ) -> Result; /// Fetch the current collect job queue. The result is the sequence of collect ID and request /// pairs, in order of priority. - async fn get_pending_collect_jobs(&self) -> Result, DapError>; + async fn get_pending_collect_jobs( + &self, + ) -> Result, DapError>; /// Complete a collect job by assigning it the completed [`CollectResp`](crate::messages::CollectResp). async fn finish_collect_job( &self, - task_id: &Id, - collect_id: &Id, - collect_resp: &CollectResp, + task_id: &TaskId, + collect_id: &CollectionJobId, + collect_resp: &Collection, ) -> Result<(), DapError>; /// Send an HTTP POST request. async fn send_http_post(&self, req: DapRequest) -> Result; + /// Send an HTTP PUT request. + async fn send_http_put(&self, req: DapRequest) -> Result; + /// Handle HTTP POST to `/upload`. The input is the encoded report sent in the body of the HTTP /// request. 
async fn http_post_upload(&'srv self, req: &'req DapRequest) -> Result<(), DapAbort> { @@ -279,12 +301,12 @@ where } let report = Report::get_decoded_with_param(&req.version, req.payload.as_ref())?; - debug!("report id is {}", report.metadata.id); + debug!("report id is {}", report.report_metadata.id); let task_config = self .get_task_config_considering_taskprov( req.version, Cow::Borrowed(req.task_id()?), - Some(&report.metadata), + Some(&report.report_metadata), ) .await? .ok_or(DapAbort::UnrecognizedTask)?; @@ -303,28 +325,29 @@ where // // TODO spec: It's not clear if this behavior is MUST, SHOULD, or MAY. if !self - .can_hpke_decrypt(&report.task_id, report.encrypted_input_shares[0].config_id) + .can_hpke_decrypt(req.task_id()?, report.encrypted_input_shares[0].config_id) .await? { return Err(DapAbort::UnrecognizedHpkeConfig); } // Check that the task has not expired. - if report.metadata.time >= task_config.as_ref().expiration { + if report.report_metadata.time >= task_config.as_ref().expiration { return Err(DapAbort::ReportTooLate); } // Store the report for future processing. At this point, the report may be rejected if // the Leader detects that the report was replayed or pertains to a batch that has already // been collected. - Ok(self.put_report(&report).await?) + Ok(self.put_report(&report, req.task_id()?).await?) } /// Handle HTTP POST to `/collect`. The input is a [`CollectReq`](crate::messages::CollectReq). /// The return value is a URI that the Collector can poll later on to get the corresponding /// [`CollectResp`](crate::messages::CollectResp). async fn http_post_collect(&'srv self, req: &'req DapRequest) -> Result { - debug!("collect for task {}", req.task_id()?); + let task_id = req.task_id()?; + debug!("collect for task {task_id}"); let now = self.get_current_time(); // Check whether the DAP version indicated by the sender is supported. @@ -338,7 +361,7 @@ where } let mut collect_req = - CollectReq::get_decoded_with_param(&req.version, req.payload.as_ref())?; + CollectionReq::get_decoded_with_param(&req.version, req.payload.as_ref())?; let wrapped_task_config = self .get_task_config_for(Cow::Borrowed(req.task_id()?)) .await? @@ -360,7 +383,7 @@ where // batches for a task to be collected concurrently for the same task, // we'd need a more complex DO state that allowed us to have batch // state go from unassigned -> in-progress -> complete. - let batch_id = self.current_batch(req.task_id()?).await?; + let batch_id = self.current_batch(task_id).await?; debug!("FixedSize batch id is {batch_id}"); collect_req.query = Query::FixedSizeByBatchId { batch_id }; } @@ -371,32 +394,46 @@ where check_batch( self, task_config, - &collect_req.task_id, + task_id, &batch_selector, &collect_req.agg_param, now, ) .await?; - Ok(self.init_collect_job(&collect_req).await?) + // draft02 compatibility: In draft02, the collection job ID is generated as a result of the + // initial collection request, whereas in the latest draft, the collection job ID is parsed + // from the request path. + let collect_job_id = match (req.version, &req.resource) { + (DapVersion::Draft02, DapResource::Undefined) => None, + (DapVersion::Draft04, DapResource::CollectionJob(ref collect_job_id)) => { + Some(collect_job_id.clone()) + } + (DapVersion::Draft04, DapResource::Undefined) => { + return Err(DapAbort::BadRequest("undefined resource".into())); + } + _ => unreachable!("unhandled resource {:?}", req.resource), + }; + + Ok(self + .init_collect_job(task_id, &collect_job_id, &collect_req) + .await?) 
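// Editorial aside, not part of the patch: the (version, resource) match above means that in
// draft04 the collection job ID is chosen by the Collector and conveyed in the request path
// (surfacing here as DapResource::CollectionJob), so init_collect_job receives Some(id); in
// draft02 the resource stays Undefined and the Leader mints the ID itself when it creates the
// job.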
} /// Run the aggregation sub-protocol for the given set of reports. Return the number of reports /// that were aggregated successfully. // // TODO Handle non-encodable messages gracefully. The length of `reports` may be too long to - // encode in `AggregateInitializeReq`, in which case this method will panic. We should increase + // encode in `AggregationJobInitReq`, in which case this method will panic. We should increase // the capacity of this message in the spec. In the meantime, we should at a minimum log this // when it happens. async fn run_agg_job( &self, - task_id: &Id, + task_id: &TaskId, task_config: &DapTaskConfig, part_batch_sel: &PartialBatchSelector, reports: Vec, ) -> Result { - let mut rng = thread_rng(); - // Filter out early rejected reports. // // TODO Add a test similar to http_post_aggregate_init_expired_task() in roles_test.rs that @@ -406,13 +443,13 @@ where .check_early_reject( task_id, part_batch_sel, - reports.iter().map(|report| &report.metadata), + reports.iter().map(|report| &report.report_metadata), ) .await?; let reports = reports .into_iter() .filter(|report| { - if let Some(failure) = early_rejects.get(&report.metadata.id) { + if let Some(failure) = early_rejects.get(&report.report_metadata.id) { self.metrics() .report_counter .with_label_values(&[&format!("rejected_{failure}")]) @@ -423,11 +460,11 @@ where }) .collect(); - // Prepare AggregateInitializeReq. - let agg_job_id = Id(rng.gen()); + // Prepare AggregationJobInitReq. + let agg_job_id = MetaAggregationJobId::gen_for_version(&task_config.version); let transition = task_config .vdaf - .produce_agg_init_req( + .produce_agg_job_init_req( self, task_id, task_config, @@ -437,36 +474,49 @@ where self.metrics(), ) .await?; - let (state, agg_init_req) = match transition { - DapLeaderTransition::Continue(state, agg_init_req) => (state, agg_init_req), + let (state, agg_job_init_req) = match transition { + DapLeaderTransition::Continue(state, agg_job_init_req) => (state, agg_job_init_req), DapLeaderTransition::Skip => return Ok(0), DapLeaderTransition::Uncommitted(..) => { return Err(DapError::fatal("unexpected state transition (uncommitted)").into()) } }; + let is_put = task_config.version != DapVersion::Draft02; + let url_path = if task_config.version == DapVersion::Draft02 { + "aggregate".to_string() + } else { + format!( + "tasks/{}/aggregation_jobs/{}", + task_id.to_base64url(), + agg_job_id.to_base64url() + ) + }; - // Send AggregateInitializeReq and receive AggregateResp. + // Send AggregationJobInitReq and receive AggregationJobResp. let resp = leader_post!( self, task_id, task_config, - "aggregate", - MEDIA_TYPE_AGG_INIT_REQ, - agg_init_req.get_encoded_with_param(&task_config.version) + &url_path, + versioned_media_type_for(&task_config.version, MEDIA_TYPE_AGG_INIT_REQ).unwrap(), + agg_job_id.for_request_path(), + agg_job_init_req.get_encoded_with_param(&task_config.version), + is_put ); - let agg_resp = AggregateResp::get_decoded(&resp.payload)?; + let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload)?; // Prepare AggreagteContinueReq. 
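// Editorial aside, not part of the patch: for a draft04 task the request built above goes out
// as an HTTP PUT to
//
//   {helper_url}/tasks/{task_id base64url}/aggregation_jobs/{agg_job_id base64url}
//
// with the AggregationJobInitReq media type as its content type, whereas a draft02 task keeps
// the legacy POST to {helper_url}/aggregate with both IDs inside the payload. The
// AggregationJobContinueReq below reuses the same URL but is always sent as a POST.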
- let transition = task_config.vdaf.handle_agg_resp( + let transition = task_config.vdaf.handle_agg_job_resp( task_id, &agg_job_id, state, - agg_resp, + agg_job_resp, + task_config.version, self.metrics(), )?; - let (uncommited, agg_cont_req) = match transition { - DapLeaderTransition::Uncommitted(uncommited, agg_cont_req) => { - (uncommited, agg_cont_req) + let (uncommited, agg_job_cont_req) = match transition { + DapLeaderTransition::Uncommitted(uncommited, agg_job_cont_req) => { + (uncommited, agg_job_cont_req) } DapLeaderTransition::Skip => return Ok(0), DapLeaderTransition::Continue(..) => { @@ -474,22 +524,24 @@ where } }; - // Send AggregateContinueReq and receive AggregateResp. + // Send AggregationJobContinueReq and receive AggregationJobResp. let resp = leader_post!( self, task_id, task_config, - "aggregate", + &url_path, MEDIA_TYPE_AGG_CONT_REQ, - agg_cont_req.get_encoded() + agg_job_id.for_request_path(), + agg_job_cont_req.get_encoded_with_param(&task_config.version), + false ); - let agg_resp = AggregateResp::get_decoded(&resp.payload)?; + let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload)?; // Commit the output shares. let out_shares = task_config .vdaf - .handle_final_agg_resp(uncommited, agg_resp, self.metrics())?; + .handle_final_agg_job_resp(uncommited, agg_job_resp, self.metrics())?; let out_shares_count = out_shares.len() as u64; self.put_out_shares(task_id, part_batch_sel, out_shares) .await?; @@ -507,15 +559,14 @@ where /// reports in the batch. async fn run_collect_job( &self, - collect_id: &Id, + task_id: &TaskId, + collect_id: &CollectionJobId, task_config: &DapTaskConfig, - collect_req: &CollectReq, + collect_req: &CollectionReq, ) -> Result { debug!("collecting id {collect_id}"); let batch_selector = BatchSelector::try_from(collect_req.query.clone())?; - let leader_agg_share = self - .get_agg_share(&collect_req.task_id, &batch_selector) - .await?; + let leader_agg_share = self.get_agg_share(task_id, &batch_selector).await?; // Check the batch size. If not not ready, then return early. // @@ -529,7 +580,7 @@ where // Prepare the Leader's aggregate share. let leader_enc_agg_share = task_config.vdaf.produce_leader_encrypted_agg_share( &task_config.collector_hpke_config, - &collect_req.task_id, + task_id, &batch_selector, &leader_agg_share, task_config.version, @@ -537,35 +588,63 @@ where // Prepare AggregateShareReq. let agg_share_req = AggregateShareReq { - task_id: collect_req.task_id.clone(), + draft02_task_id: task_id.for_request_payload(&task_config.version), batch_sel: batch_selector.clone(), agg_param: collect_req.agg_param.clone(), report_count: leader_agg_share.report_count, checksum: leader_agg_share.checksum, }; + let url_path = if task_config.version == DapVersion::Draft02 { + "aggregate_share".to_string() + } else { + format!("tasks/{}/aggregate_shares", task_id.to_base64url()) + }; + // Send AggregateShareReq and receive AggregateShareResp. let resp = leader_post!( self, - &collect_req.task_id, + task_id, task_config, - "aggregate_share", + &url_path, MEDIA_TYPE_AGG_SHARE_REQ, - agg_share_req.get_encoded_with_param(&task_config.version) + DapResource::Undefined, + agg_share_req.get_encoded_with_param(&task_config.version), + false ); - let agg_share_resp = AggregateShareResp::get_decoded(&resp.payload)?; + let agg_share_resp = AggregateShare::get_decoded(&resp.payload)?; + // For draft04 and later, the Collection message includes the smallest quantized time + // interval containing all reports in the batch. 
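// Editorial aside, not part of the patch: a worked example of the interval computed below,
// assuming time_precision = 3600. If the reports in the batch have min_time = 5000 and
// max_time = 9200, then
//
//   low  = quantized_time_lower_bound(5000) = 3600
//   high = quantized_time_upper_bound(9200) = 7200 + 3600 = 10800
//
// and the Collection carries Interval { start: 3600, duration: 7200 }, the smallest
// precision-aligned window covering every report timestamp.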
+ let interval = match task_config.version { + DapVersion::Draft02 => None, + DapVersion::Draft04 => { + let low = task_config.quantized_time_lower_bound(leader_agg_share.min_time); + let high = task_config.quantized_time_upper_bound(leader_agg_share.max_time); + Some(Interval { + start: low, + duration: if high > low { + high - low + } else { + // This should never happen! + task_config.time_precision + }, + }) + } + _ => unreachable!("unhandled version {}", task_config.version), + }; // Complete the collect job. - let collect_resp = CollectResp { + let collection = Collection { part_batch_sel: batch_selector.into(), report_count: leader_agg_share.report_count, + interval, encrypted_agg_shares: vec![leader_enc_agg_share, agg_share_resp.encrypted_agg_share], }; - self.finish_collect_job(&collect_req.task_id, collect_id, &collect_resp) + self.finish_collect_job(task_id, collect_id, &collection) .await?; // Mark reports as collected. - self.mark_collected(&agg_share_req.task_id, &agg_share_req.batch_sel) + self.mark_collected(task_id, &agg_share_req.batch_sel) .await?; self.metrics() @@ -612,19 +691,18 @@ where } } } - // Process pending collect jobs. We wait until all aggregation jobs are finished before // proceeding to this step. This is to prevent a race condition involving an aggregate // share computed during a collect job and any output shares computed during an aggregation // job. - for (collect_id, collect_req) in self.get_pending_collect_jobs().await? { + for (task_id, collect_id, collect_req) in self.get_pending_collect_jobs().await? { let task_config = self - .get_task_config_for(Cow::Owned(collect_req.task_id.clone())) + .get_task_config_for(Cow::Owned(task_id.clone())) .await? .ok_or(DapAbort::UnrecognizedTask)?; telem.reports_collected += self - .run_collect_job(&collect_id, task_config.as_ref(), &collect_req) + .run_collect_job(&task_id, &collect_id, task_config.as_ref(), &collect_req) .await?; } @@ -641,8 +719,8 @@ where /// Store the Helper's aggregation-flow state. async fn put_helper_state( &self, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, helper_state: &DapHelperState, ) -> Result<(), DapError>; @@ -650,12 +728,12 @@ where /// associated with the given task and aggregation job. async fn get_helper_state( &self, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, ) -> Result, DapError>; - /// Handle an HTTP POST to `/aggregate`. The input is either an AggregateInitializeReq or - /// AggregateContinueReq and the response is an AggregateResp. + /// Handle an HTTP POST to `/aggregate`. The input is either an AggregationJobInitReq or + /// AggregationJobContinueReq and the response is an AggregationJobResp. /// /// This is called during the Initialization and Continuation phases. async fn http_post_aggregate( @@ -668,14 +746,16 @@ where } if !self.authorized(req).await? 
{ - debug!("aborted unathorized aggregate request"); + debug!("aborted unauthorized aggregate request"); return Err(DapAbort::UnauthorizedRequest); } + let task_id = req.task_id()?; + match req.media_type { - Some(MEDIA_TYPE_AGG_INIT_REQ) => { - let agg_init_req = - AggregateInitializeReq::get_decoded_with_param(&req.version, &req.payload)?; + Some(MEDIA_TYPE_AGG_INIT_REQ) | Some(DRAFT02_MEDIA_TYPE_AGG_INIT_REQ) => { + let agg_job_init_req = + AggregationJobInitReq::get_decoded_with_param(&req.version, &req.payload)?; let mut first_metadata: Option<&ReportMetadata> = None; @@ -683,24 +763,23 @@ where // do (section 6 of draft-wang-ppm-dap-taskprov-02). let global_config = self.get_global_config(); if global_config.allow_taskprov { - let task_id = req.task_id()?; - let using_taskprov = agg_init_req + let using_taskprov = agg_job_init_req .report_shares .iter() .filter(|share| { share - .metadata + .report_metadata .is_taskprov(global_config.taskprov_version, task_id) }) .count(); - if using_taskprov == agg_init_req.report_shares.len() { + if using_taskprov == agg_job_init_req.report_shares.len() { // All the extensions use taskprov and look ok, so compute first_metadata. // Note this will always be Some(). - first_metadata = agg_init_req + first_metadata = agg_job_init_req .report_shares .first() - .map(|report_share| &report_share.metadata); + .map(|report_share| &report_share.report_metadata); } else if using_taskprov != 0 { // It's not all taskprov or no taskprov, so it's an error. return Err(DapAbort::UnrecognizedMessage); @@ -710,14 +789,34 @@ where let wrapped_task_config = self .get_task_config_considering_taskprov( req.version, - Cow::Borrowed(req.task_id()?), + Cow::Borrowed(task_id), first_metadata, ) .await? .ok_or(DapAbort::UnrecognizedTask)?; let task_config = wrapped_task_config.as_ref(); - let helper_state = - self.get_helper_state(&agg_init_req.task_id, &agg_init_req.agg_job_id); + + // draft02 compatibility: In draft02, the aggregation job ID is parsed from the + // HTTP request payload; in the latest draft, the aggregation job ID is parsed from + // the request path. + let agg_job_id = match ( + req.version, + &req.resource, + &agg_job_init_req.draft02_agg_job_id, + ) { + (DapVersion::Draft02, DapResource::Undefined, Some(ref agg_job_id)) => { + MetaAggregationJobId::Draft02(Cow::Borrowed(agg_job_id)) + } + (DapVersion::Draft04, DapResource::AggregationJob(ref agg_job_id), None) => { + MetaAggregationJobId::Draft04(Cow::Borrowed(agg_job_id)) + } + (DapVersion::Draft04, DapResource::Undefined, None) => { + return Err(DapAbort::BadRequest("undefined resource".into())); + } + _ => unreachable!("unhandled resource {:?}", req.resource), + }; + + let helper_state = self.get_helper_state(task_id, &agg_job_id); // Check whether the DAP version in the request matches the task config. if task_config.version != req.version { @@ -727,25 +826,32 @@ where // Ensure we know which batch the request pertains to. 
check_part_batch( task_config, - &agg_init_req.part_batch_sel, - &agg_init_req.agg_param, + &agg_job_init_req.part_batch_sel, + &agg_job_init_req.agg_param, )?; let early_rejects_future = self.check_early_reject( - &agg_init_req.task_id, - &agg_init_req.part_batch_sel, - agg_init_req + task_id, + &agg_job_init_req.part_batch_sel, + agg_job_init_req .report_shares .iter() - .map(|report_share| &report_share.metadata), + .map(|report_share| &report_share.report_metadata), ); let transition = task_config .vdaf - .handle_agg_init_req(self, task_config, &agg_init_req, self.metrics()) + .handle_agg_job_init_req( + self, + task_id, + task_config, + &agg_job_init_req, + self.metrics(), + ) .await?; - // Check that helper state with task_id and agg_job_id does not exist. + // Check that helper state with the given task ID and aggregation job ID does not + // exist. if helper_state.await?.is_some() { // TODO spec: Consider an explicit abort for this case. return Err(DapAbort::BadRequest( @@ -753,12 +859,12 @@ where )); } - let agg_resp = match transition { - DapHelperTransition::Continue(mut state, mut agg_resp) => { + let agg_job_resp = match transition { + DapHelperTransition::Continue(mut state, mut agg_job_resp) => { // Filter out early rejected reports. let early_rejects = early_rejects_future.await?; let mut state_index = 0; - for transition in agg_resp.transitions.iter_mut() { + for transition in agg_job_resp.transitions.iter_mut() { let early_failure = early_rejects.get(&transition.report_id); if !matches!(transition.var, TransitionVar::Failed(..)) && early_failure.is_some() @@ -796,13 +902,8 @@ where } } - self.put_helper_state( - &agg_init_req.task_id, - &agg_init_req.agg_job_id, - &state, - ) - .await?; - agg_resp + self.put_helper_state(task_id, &agg_job_id, &state).await?; + agg_job_resp } DapHelperTransition::Finish(..) => { return Err(DapError::fatal("unexpected transition (finished)").into()); @@ -812,14 +913,15 @@ where self.metrics().aggregation_job_gauge.inc(); Ok(DapResponse { - media_type: Some(MEDIA_TYPE_AGG_INIT_RESP), - payload: agg_resp.get_encoded(), + media_type: versioned_media_type_for(&req.version, MEDIA_TYPE_AGG_INIT_RESP), + payload: agg_job_resp.get_encoded(), }) } - Some(MEDIA_TYPE_AGG_CONT_REQ) => { - let agg_cont_req = AggregateContinueReq::get_decoded(&req.payload)?; + Some(MEDIA_TYPE_AGG_CONT_REQ) | Some(DRAFT02_MEDIA_TYPE_AGG_CONT_REQ) => { + let agg_job_cont_req = + AggregationJobContinueReq::get_decoded_with_param(&req.version, &req.payload)?; let wrapped_task_config = self - .get_task_config_for(Cow::Borrowed(req.task_id()?)) + .get_task_config_for(Cow::Borrowed(task_id)) .await? .ok_or(DapAbort::UnrecognizedTask)?; let task_config = wrapped_task_config.as_ref(); @@ -829,25 +931,46 @@ where return Err(DapAbort::InvalidProtocolVersion); } + // draft02 compatibility: In draft02, the aggregation job ID is parsed from the + // HTTP request payload; in the latest, the aggregation job ID is parsed from the + // request path. 
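The match that follows implements the draft02-compatibility comment above. A self-contained sketch of the same resolution, with simplified stand-in types (the real inner ID widths and error variants may differ): draft02 carries the aggregation job ID in the message body, later drafts carry it in the URL resource.

#[derive(Clone)]
struct Draft02AggregationJobId([u8; 32]);

#[derive(Clone)]
struct AggregationJobId([u8; 16]);

enum Version {
    Draft02,
    Draft04,
}

enum Resource {
    Undefined,
    AggregationJob(AggregationJobId),
}

enum MetaAggregationJobId {
    Draft02(Draft02AggregationJobId),
    Draft04(AggregationJobId),
}

#[derive(Debug)]
enum Abort {
    BadRequest(String),
}

fn resolve_agg_job_id(
    version: &Version,
    resource: &Resource,
    draft02_agg_job_id: &Option<Draft02AggregationJobId>,
) -> Result<MetaAggregationJobId, Abort> {
    match (version, resource, draft02_agg_job_id) {
        // draft02: the ID is a field of the request payload.
        (Version::Draft02, Resource::Undefined, Some(id)) => {
            Ok(MetaAggregationJobId::Draft02(id.clone()))
        }
        // Later drafts: the ID comes from the request path.
        (Version::Draft04, Resource::AggregationJob(id), None) => {
            Ok(MetaAggregationJobId::Draft04(id.clone()))
        }
        // Anything else (e.g. a draft04 request with no resource) is malformed.
        _ => Err(Abort::BadRequest("undefined resource".to_string())),
    }
}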
+ let agg_job_id = match ( + req.version, + &req.resource, + &agg_job_cont_req.draft02_agg_job_id, + ) { + (DapVersion::Draft02, DapResource::Undefined, Some(ref agg_job_id)) => { + MetaAggregationJobId::Draft02(Cow::Borrowed(agg_job_id)) + } + (DapVersion::Draft04, DapResource::AggregationJob(ref agg_job_id), None) => { + MetaAggregationJobId::Draft04(Cow::Borrowed(agg_job_id)) + } + (DapVersion::Draft04, DapResource::Undefined, None) => { + return Err(DapAbort::BadRequest("undefined resource".into())); + } + _ => unreachable!("unhandled resource {:?}", req.resource), + }; + let state = self - .get_helper_state(&agg_cont_req.task_id, &agg_cont_req.agg_job_id) + .get_helper_state(task_id, &agg_job_id) .await? .ok_or(DapAbort::UnrecognizedAggregationJob)?; let part_batch_sel = state.part_batch_sel.clone(); - let transition = - task_config - .vdaf - .handle_agg_cont_req(state, &agg_cont_req, self.metrics())?; + let transition = task_config.vdaf.handle_agg_job_cont_req( + state, + &agg_job_cont_req, + self.metrics(), + )?; - let (agg_resp, out_shares_count) = match transition { + let (agg_job_resp, out_shares_count) = match transition { DapHelperTransition::Continue(..) => { return Err(DapError::fatal("unexpected transition (continued)").into()); } - DapHelperTransition::Finish(out_shares, agg_resp) => { + DapHelperTransition::Finish(out_shares, agg_job_resp) => { let out_shares_count = u64::try_from(out_shares.len()).unwrap(); - self.put_out_shares(&agg_cont_req.task_id, &part_batch_sel, out_shares) + self.put_out_shares(task_id, &part_batch_sel, out_shares) .await?; - (agg_resp, out_shares_count) + (agg_job_resp, out_shares_count) } }; @@ -860,7 +983,7 @@ where Ok(DapResponse { media_type: Some(MEDIA_TYPE_AGG_CONT_RESP), - payload: agg_resp.get_encoded(), + payload: agg_job_resp.get_encoded(), }) } //TODO spec: Specify this behavior. @@ -887,6 +1010,8 @@ where return Err(DapAbort::UnauthorizedRequest); } + let task_id = req.task_id()?; + let agg_share_req = AggregateShareReq::get_decoded_with_param(&req.version, &req.payload)?; let wrapped_task_config = self .get_task_config_for(Cow::Borrowed(req.task_id()?)) @@ -904,7 +1029,7 @@ where check_batch( self, task_config, - &agg_share_req.task_id, + task_id, &agg_share_req.batch_sel, &agg_share_req.agg_param, now, @@ -912,7 +1037,7 @@ where .await?; let agg_share = self - .get_agg_share(&agg_share_req.task_id, &agg_share_req.batch_sel) + .get_agg_share(task_id, &agg_share_req.batch_sel) .await?; // Check that we have aggreagted the same set of reports as the leader. @@ -931,18 +1056,18 @@ where } // Mark each aggregated report as collected. 
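Behind the "same set of reports as the leader" comment a few lines above, the Helper is assumed to compare the report count and checksum from the Leader's AggregateShareReq against its own aggregate share before answering. A simplified sketch of that check (stand-in types; the concrete abort variant may differ):

#[derive(Debug)]
enum Abort {
    BatchMismatch,
}

// One side is populated from AggregateShareReq, the other from the Helper's
// locally stored aggregate share for the same batch.
struct BatchDigest {
    report_count: u64,
    checksum: [u8; 32],
}

fn check_leader_matches_helper(leader: &BatchDigest, helper: &BatchDigest) -> Result<(), Abort> {
    if leader.report_count != helper.report_count || leader.checksum != helper.checksum {
        return Err(Abort::BatchMismatch);
    }
    Ok(())
}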
- self.mark_collected(&agg_share_req.task_id, &agg_share_req.batch_sel) + self.mark_collected(task_id, &agg_share_req.batch_sel) .await?; let encrypted_agg_share = task_config.vdaf.produce_helper_encrypted_agg_share( &task_config.collector_hpke_config, - &agg_share_req.task_id, + task_id, &agg_share_req.batch_sel, &agg_share, task_config.version, )?; - let agg_share_resp = AggregateShareResp { + let agg_share_resp = AggregateShare { encrypted_agg_share, }; @@ -952,7 +1077,7 @@ where .inc_by(agg_share_req.report_count); Ok(DapResponse { - media_type: Some(MEDIA_TYPE_AGG_SHARE_RESP), + media_type: versioned_media_type_for(&task_config.version, MEDIA_TYPE_AGG_SHARE_RESP), payload: agg_share_resp.get_encoded(), }) } @@ -979,7 +1104,7 @@ fn check_part_batch( async fn check_batch<'srv, 'req, S>( agg: &impl DapAggregator<'srv, 'req, S>, task_config: &DapTaskConfig, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, agg_param: &[u8], now: Time, diff --git a/daphne/src/roles_test.rs b/daphne/src/roles_test.rs index c92935eda..10c568565 100644 --- a/daphne/src/roles_test.rs +++ b/daphne/src/roles_test.rs @@ -6,15 +6,16 @@ use crate::{ async_test_versions, auth::BearerToken, constants::{ - DRAFT02_MEDIA_TYPE_HPKE_CONFIG, MEDIA_TYPE_AGG_CONT_REQ, MEDIA_TYPE_AGG_INIT_REQ, - MEDIA_TYPE_AGG_SHARE_REQ, MEDIA_TYPE_COLLECT_REQ, MEDIA_TYPE_REPORT, + versioned_media_type_for, DRAFT02_MEDIA_TYPE_HPKE_CONFIG, MEDIA_TYPE_AGG_CONT_REQ, + MEDIA_TYPE_AGG_INIT_REQ, MEDIA_TYPE_AGG_SHARE_REQ, MEDIA_TYPE_COLLECT_REQ, + MEDIA_TYPE_REPORT, }, hpke::{HpkeDecrypter, HpkeReceiverConfig}, messages::{ - taskprov, AggregateContinueReq, AggregateInitializeReq, AggregateResp, AggregateShareReq, - BatchSelector, CollectReq, CollectResp, Extension, HpkeKemId, Id, Interval, - PartialBatchSelector, Query, Report, ReportId, ReportMetadata, ReportShare, Time, - Transition, TransitionFailure, TransitionVar, + taskprov, AggregateShareReq, AggregationJobContinueReq, AggregationJobInitReq, + AggregationJobResp, BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq, + Extension, HpkeKemId, Interval, PartialBatchSelector, Query, Report, ReportId, + ReportMetadata, ReportShare, TaskId, Time, Transition, TransitionFailure, TransitionVar, }, metrics::DaphneMetrics, roles::{early_metadata_check, DapAggregator, DapAuthorizedSender, DapHelper, DapLeader}, @@ -23,12 +24,13 @@ use crate::{ testing::{AggStore, DapBatchBucketOwned, MockAggregator, MockAggregatorReportSelector}, vdaf::VdafVerifyKey, DapAbort, DapAggregateShare, DapCollectJob, DapGlobalConfig, DapMeasurement, DapQueryConfig, - DapRequest, DapTaskConfig, DapVersion, Prio3Config, VdafConfig, + DapRequest, DapResource, DapTaskConfig, DapVersion, MetaAggregationJobId, Prio3Config, + VdafConfig, }; use assert_matches::assert_matches; use matchit::Router; use paste::paste; -use prio::codec::{Decode, Encode, ParameterizedEncode}; +use prio::codec::{Decode, ParameterizedEncode}; use rand::{thread_rng, Rng}; use std::{ borrow::Cow, @@ -55,9 +57,9 @@ struct Test { leader: Arc, helper: Arc, collector_token: BearerToken, - time_interval_task_id: Id, - fixed_size_task_id: Id, - expired_task_id: Id, + time_interval_task_id: TaskId, + fixed_size_task_id: TaskId, + expired_task_id: TaskId, version: DapVersion, prometheus_registry: prometheus::Registry, } @@ -92,9 +94,9 @@ impl Test { HpkeReceiverConfig::gen(rng.gen(), HpkeKemId::X25519HkdfSha256).unwrap(); // Create the task list. 
- let time_interval_task_id = Id(rng.gen()); - let fixed_size_task_id = Id(rng.gen()); - let expired_task_id = Id(rng.gen()); + let time_interval_task_id = TaskId(rng.gen()); + let fixed_size_task_id = TaskId(rng.gen()); + let expired_task_id = TaskId(rng.gen()); let mut tasks = HashMap::new(); tasks.insert( time_interval_task_id.clone(), @@ -204,23 +206,29 @@ impl Test { } } - async fn gen_test_upload_req(&self, report: Report) -> DapRequest { - let task_config = self.leader.unchecked_get_task_config(&report.task_id).await; + async fn gen_test_upload_req( + &self, + report: Report, + task_id: &TaskId, + ) -> DapRequest { + let task_config = self.leader.unchecked_get_task_config(task_id).await; let version = task_config.version; DapRequest { version, media_type: Some(MEDIA_TYPE_REPORT), - task_id: Some(report.task_id.clone()), + task_id: Some(task_id.clone()), + resource: DapResource::Undefined, payload: report.get_encoded_with_param(&version), url: task_config.leader_url.join("upload").unwrap(), sender_auth: None, } } - async fn gen_test_agg_init_req( + async fn gen_test_agg_job_init_req( &self, - task_id: &Id, + task_id: &TaskId, + version: DapVersion, report_shares: Vec, ) -> DapRequest { let mut rng = thread_rng(); @@ -228,17 +236,19 @@ impl Test { let part_batch_sel = match task_config.query { DapQueryConfig::TimeInterval { .. } => PartialBatchSelector::TimeInterval, DapQueryConfig::FixedSize { .. } => PartialBatchSelector::FixedSizeByBatchId { - batch_id: Id(rng.gen()), + batch_id: BatchId(rng.gen()), }, }; + let agg_job_id = MetaAggregationJobId::gen_for_version(&version); self.leader_authorized_req_with_version( task_id, + Some(&agg_job_id), task_config.version, - MEDIA_TYPE_AGG_INIT_REQ, - AggregateInitializeReq { - task_id: task_id.clone(), - agg_job_id: Id(rng.gen()), + versioned_media_type_for(&task_config.version, MEDIA_TYPE_AGG_INIT_REQ).unwrap(), + AggregationJobInitReq { + draft02_task_id: task_id.for_request_payload(&version), + draft02_agg_job_id: agg_job_id.for_request_payload(), agg_param: Vec::default(), part_batch_sel, report_shares, @@ -248,21 +258,24 @@ impl Test { .await } - async fn gen_test_agg_cont_req( + async fn gen_test_agg_job_cont_req_with_round( &self, - agg_job_id: Id, + agg_job_id: &MetaAggregationJobId<'_>, transitions: Vec, + round: Option, ) -> DapRequest { let task_id = &self.time_interval_task_id; let task_config = self.leader.unchecked_get_task_config(task_id).await; self.leader_authorized_req( task_id, + Some(agg_job_id), task_config.version, - MEDIA_TYPE_AGG_CONT_REQ, - AggregateContinueReq { - task_id: task_id.clone(), - agg_job_id, + versioned_media_type_for(&task_config.version, MEDIA_TYPE_AGG_CONT_REQ).unwrap(), + AggregationJobContinueReq { + draft02_task_id: task_id.for_request_payload(&task_config.version), + draft02_agg_job_id: agg_job_id.for_request_payload(), + round, transitions, }, task_config.helper_url.join("aggregate").unwrap(), @@ -270,6 +283,21 @@ impl Test { .await } + async fn gen_test_agg_job_cont_req( + &self, + agg_job_id: &MetaAggregationJobId<'_>, + transitions: Vec, + version: DapVersion, + ) -> DapRequest { + let round = if version == DapVersion::Draft02 { + None + } else { + Some(1) + }; + self.gen_test_agg_job_cont_req_with_round(agg_job_id, transitions, round) + .await + } + async fn gen_test_agg_share_req( &self, report_count: u64, @@ -278,23 +306,30 @@ impl Test { let task_id = &self.time_interval_task_id; let task_config = self.leader.unchecked_get_task_config(task_id).await; + let url_path = if task_config.version 
== DapVersion::Draft02 { + "aggregate_shares".to_string() + } else { + format!("tasks/{}/aggregate_shares", task_id.to_base64url()) + }; + self.leader_authorized_req_with_version( task_id, + None, task_config.version, MEDIA_TYPE_AGG_SHARE_REQ, AggregateShareReq { - task_id: task_id.clone(), + draft02_task_id: task_id.for_request_payload(&task_config.version), batch_sel: BatchSelector::default(), agg_param: Vec::default(), report_count, checksum, }, - task_config.helper_url.join("aggregate_share").unwrap(), + task_config.helper_url.join(&url_path).unwrap(), ) .await } - async fn gen_test_report(&self, task_id: &Id) -> Report { + async fn gen_test_report(&self, task_id: &TaskId) -> Report { let version = self.leader.unchecked_get_task_config(task_id).await.version; // Construct HPKE config list. @@ -326,7 +361,7 @@ impl Test { .unwrap() } - async fn run_agg_job(&self, task_id: &Id) -> Result<(), DapAbort> { + async fn run_agg_job(&self, task_id: &TaskId) -> Result<(), DapAbort> { let wrapped = self .leader .get_task_config_for(Cow::Owned(task_id.clone())) @@ -345,7 +380,7 @@ impl Test { Ok(()) } - async fn run_col_job(&self, task_id: &Id, query: &Query) -> Result<(), DapAbort> { + async fn run_col_job(&self, task_id: &TaskId, query: &Query) -> Result<(), DapAbort> { let wrapped = self .leader .get_task_config_for(Cow::Owned(task_id.clone())) @@ -357,10 +392,10 @@ impl Test { let req = self .collector_authorized_req( task_config.version, - MEDIA_TYPE_COLLECT_REQ, + versioned_media_type_for(&task_config.version, MEDIA_TYPE_COLLECT_REQ).unwrap(), task_id, - CollectReq { - task_id: task_id.clone(), + CollectionReq { + draft02_task_id: task_id.for_request_payload(&task_config.version), query: query.clone(), agg_param: Vec::default(), }, @@ -371,25 +406,26 @@ impl Test { // Leader: Handle request from Collector. self.leader.http_post_collect(&req).await?; let resp = self.leader.get_pending_collect_jobs().await?; - let (collect_id, collect_req) = &resp[0]; + let (task_id, collect_id, collect_req) = &resp[0]; // Leader->Helper: Complete collection job. 
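A recurring pattern in the hunks above is for_request_payload: draft02 embeds the task ID (and aggregation job ID) in the message body, while later drafts carry them in the URL, so the payload fields become optional. A minimal sketch of the assumed behavior, using simplified types:

#[derive(Clone, PartialEq)]
struct TaskId([u8; 32]);

#[derive(PartialEq)]
enum Version {
    Draft02,
    Draft04,
}

impl TaskId {
    // Mirrors the role of `for_request_payload()`: only draft02 serializes the
    // task ID inside the message body; later drafts leave the field empty
    // because the ID already appears in the request path.
    fn for_request_payload(&self, version: &Version) -> Option<TaskId> {
        if *version == Version::Draft02 {
            Some(self.clone())
        } else {
            None
        }
    }
}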
let _reports_collected = self .leader - .run_collect_job(collect_id, task_config, collect_req) + .run_collect_job(task_id, collect_id, task_config, collect_req) .await?; Ok(()) } - async fn leader_authorized_req( + async fn leader_authorized_req>( &self, - task_id: &Id, + task_id: &TaskId, + agg_job_id: Option<&MetaAggregationJobId<'_>>, version: DapVersion, media_type: &'static str, msg: M, url: Url, ) -> DapRequest { - let payload = msg.get_encoded(); + let payload = msg.get_encoded_with_param(&version); let sender_auth = Some( self.leader .authorize(task_id, media_type, &payload) @@ -400,6 +436,7 @@ impl Test { version, media_type: Some(media_type), task_id: Some(task_id.clone()), + resource: agg_job_id.map_or(DapResource::Undefined, |id| id.for_request_path()), payload, url, sender_auth, @@ -408,7 +445,8 @@ impl Test { async fn leader_authorized_req_with_version>( &self, - task_id: &Id, + task_id: &TaskId, + agg_job_id: Option<&MetaAggregationJobId<'_>>, version: DapVersion, media_type: &'static str, msg: M, @@ -425,6 +463,7 @@ impl Test { version, media_type: Some(media_type), task_id: Some(task_id.clone()), + resource: agg_job_id.map_or(DapResource::Undefined, |id| id.for_request_path()), payload, url, sender_auth, @@ -435,14 +474,21 @@ impl Test { &self, version: DapVersion, media_type: &'static str, - task_id: &Id, + task_id: &TaskId, msg: M, url: Url, ) -> DapRequest { + let mut rng = thread_rng(); + let collect_job_id = CollectionJobId(rng.gen()); DapRequest { version, media_type: Some(media_type), task_id: Some(task_id.clone()), + resource: if version == DapVersion::Draft02 { + DapResource::Undefined + } else { + DapResource::CollectionJob(collect_job_id) + }, payload: msg.get_encoded_with_param(&version), url, sender_auth: Some(self.collector_token.clone()), @@ -450,25 +496,27 @@ impl Test { } } -// Test that the Helper properly handles the batch parameter in the AggregateInitializeReq. +// Test that the Helper properly handles the batch parameter in the AggregationJobInitReq. async fn http_post_aggregate_invalid_batch_sel(version: DapVersion) { let mut rng = thread_rng(); let t = Test::new(version); let task_id = &t.time_interval_task_id; let task_config = t.leader.unchecked_get_task_config(task_id).await; + let agg_job_id = MetaAggregationJobId::gen_for_version(&version); // Helper expects "time_interval" query, but Leader indicates "fixed_size". let req = t .leader_authorized_req_with_version( task_id, + Some(&agg_job_id), task_config.version, - MEDIA_TYPE_AGG_INIT_REQ, - AggregateInitializeReq { - task_id: task_id.clone(), - agg_job_id: Id(rng.gen()), + versioned_media_type_for(&task_config.version, MEDIA_TYPE_AGG_INIT_REQ).unwrap(), + AggregationJobInitReq { + draft02_task_id: task_id.for_request_payload(&version), + draft02_agg_job_id: agg_job_id.for_request_payload(), agg_param: Vec::default(), part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { - batch_id: Id(rng.gen()), + batch_id: BatchId(rng.gen()), }, report_shares: Vec::default(), }, @@ -486,7 +534,7 @@ async_test_versions! 
{ http_post_aggregate_invalid_batch_sel } async fn http_post_aggregate_init_unauthorized_request(version: DapVersion) { let t = Test::new(version); let mut req = t - .gen_test_agg_init_req(&t.time_interval_task_id, Vec::default()) + .gen_test_agg_job_init_req(&t.time_interval_task_id, version, Vec::default()) .await; req.sender_auth = None; @@ -512,34 +560,117 @@ async fn http_post_aggregate_init_expired_task(version: DapVersion) { let report = t.gen_test_report(&t.expired_task_id).await; let report_share = ReportShare { - metadata: report.metadata, + report_metadata: report.report_metadata, public_share: report.public_share, encrypted_input_share: report.encrypted_input_shares[1].clone(), }; let req = t - .gen_test_agg_init_req(&t.expired_task_id, vec![report_share]) + .gen_test_agg_job_init_req(&t.expired_task_id, version, vec![report_share]) .await; let resp = t.helper.http_post_aggregate(&req).await.unwrap(); - let agg_resp = AggregateResp::get_decoded(&resp.payload).unwrap(); - assert_eq!(agg_resp.transitions.len(), 1); + let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload).unwrap(); + assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( - agg_resp.transitions[0].var, + agg_job_resp.transitions[0].var, TransitionVar::Failed(TransitionFailure::TaskExpired) ); } async_test_versions! { http_post_aggregate_init_expired_task } +// Test that the Helper rejects reports with a bad round number. +async fn http_post_aggregate_bad_round(version: DapVersion) { + let t = Test::new(version); + if version == DapVersion::Draft02 { + // Nothing to test. + return; + } + + let report = t.gen_test_report(&t.time_interval_task_id).await; + let report_share = ReportShare { + report_metadata: report.report_metadata, + public_share: report.public_share, + encrypted_input_share: report.encrypted_input_shares[1].clone(), + }; + let req = t + .gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report_share]) + .await; + let agg_job_id = match &req.resource { + DapResource::AggregationJob(agg_job_id) => agg_job_id.clone(), + _ => panic!("agg_job_id resource missing!"), + }; + let resp = t.helper.http_post_aggregate(&req).await.unwrap(); + let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload).unwrap(); + assert_eq!(agg_job_resp.transitions.len(), 1); + assert_matches!(agg_job_resp.transitions[0].var, TransitionVar::Continued(_)); + // Test wrong round + let req = t + .gen_test_agg_job_cont_req_with_round( + &MetaAggregationJobId::Draft04(Cow::Borrowed(&agg_job_id)), + Vec::default(), + Some(2), + ) + .await; + assert_matches!( + t.helper.http_post_aggregate(&req).await, + Err(DapAbort::RoundMismatch) + ); +} + +async_test_versions! { http_post_aggregate_bad_round } + +// Test that the Helper rejects reports with a bad round id +async fn http_post_aggregate_zero_round(version: DapVersion) { + let t = Test::new(version); + if version == DapVersion::Draft02 { + // Nothing to test. 
+ return; + } + + let report = t.gen_test_report(&t.time_interval_task_id).await; + let report_share = ReportShare { + report_metadata: report.report_metadata, + public_share: report.public_share, + encrypted_input_share: report.encrypted_input_shares[1].clone(), + }; + let req = t + .gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report_share]) + .await; + let agg_job_id = match &req.resource { + DapResource::AggregationJob(agg_job_id) => agg_job_id.clone(), + _ => panic!("agg_job_id resource missing!"), + }; + let resp = t.helper.http_post_aggregate(&req).await.unwrap(); + let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload).unwrap(); + assert_eq!(agg_job_resp.transitions.len(), 1); + assert_matches!(agg_job_resp.transitions[0].var, TransitionVar::Continued(_)); + // Test wrong round + let req = t + .gen_test_agg_job_cont_req_with_round( + &MetaAggregationJobId::Draft04(Cow::Borrowed(&agg_job_id)), + Vec::default(), + Some(0), + ) + .await; + assert_matches!( + t.helper.http_post_aggregate(&req).await, + Err(DapAbort::UnrecognizedMessage) + ); +} + +async_test_versions! { http_post_aggregate_zero_round } + async fn http_get_hpke_config_unrecognized_task(version: DapVersion) { let t = Test::new(version); let mut rng = thread_rng(); - let task_id = Id(rng.gen()); + let task_id = TaskId(rng.gen()); let req = DapRequest { version: DapVersion::Draft02, media_type: Some(DRAFT02_MEDIA_TYPE_HPKE_CONFIG), payload: Vec::new(), task_id: Some(task_id.clone()), + resource: DapResource::Undefined, url: Url::parse(&format!( "http://aggregator.biz/v02/hpke_config?task_id={}", task_id.to_base64url() @@ -562,6 +693,7 @@ async fn http_get_hpke_config_missing_task_id(version: DapVersion) { version: DapVersion::Draft02, media_type: Some(DRAFT02_MEDIA_TYPE_HPKE_CONFIG), task_id: Some(t.time_interval_task_id.clone()), + resource: DapResource::Undefined, payload: Vec::new(), url: Url::parse("http://aggregator.biz/v02/hpke_config").unwrap(), sender_auth: None, @@ -580,8 +712,10 @@ async_test_versions! { http_get_hpke_config_missing_task_id } async fn http_post_aggregate_cont_unauthorized_request(version: DapVersion) { let t = Test::new(version); - let mut rng = thread_rng(); - let mut req = t.gen_test_agg_cont_req(Id(rng.gen()), Vec::default()).await; + let agg_job_id = MetaAggregationJobId::gen_for_version(&version); + let mut req = t + .gen_test_agg_job_cont_req(&agg_job_id, Vec::default(), version) + .await; req.sender_auth = None; // Expect failure due to missing bearer token. 
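The http_post_aggregate_bad_round and http_post_aggregate_zero_round tests above pin down how the Helper is expected to validate the new round field on AggregationJobContinueReq. A sketch of the assumed rule (simplified types; the error names are taken from the tests):

enum Version {
    Draft02,
    Draft04,
}

#[derive(Debug, PartialEq)]
enum Abort {
    UnrecognizedMessage,
    RoundMismatch,
}

fn check_round(version: &Version, req_round: Option<u16>, current_round: u16) -> Result<(), Abort> {
    match (version, req_round) {
        // draft02 has no round field at all.
        (Version::Draft02, None) => Ok(()),
        // Later drafts: round numbers start at 1, so 0 is malformed ...
        (_, Some(0)) => Err(Abort::UnrecognizedMessage),
        // ... and the value must match the round the Helper expects next.
        (_, Some(r)) if r == current_round => Ok(()),
        (_, Some(_)) => Err(Abort::RoundMismatch),
        // A missing round on the newer drafts is also malformed.
        (_, None) => Err(Abort::UnrecognizedMessage),
    }
}

With a single-round VDAF the Helper expects current_round to be 1, which is why gen_test_agg_job_cont_req defaults the field to Some(1) for every version after draft02.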
@@ -634,12 +768,13 @@ async fn http_post_aggregate_share_invalid_batch_sel(version: DapVersion) { let req = t .leader_authorized_req_with_version( &t.time_interval_task_id, + None, task_config.version, - MEDIA_TYPE_AGG_SHARE_REQ, + versioned_media_type_for(&task_config.version, MEDIA_TYPE_AGG_SHARE_REQ).unwrap(), AggregateShareReq { - task_id: t.time_interval_task_id.clone(), + draft02_task_id: t.time_interval_task_id.for_request_payload(&version), batch_sel: BatchSelector::FixedSizeByBatchId { - batch_id: Id(rng.gen()), + batch_id: BatchId(rng.gen()), }, agg_param: Vec::default(), report_count: 0, @@ -661,12 +796,13 @@ async fn http_post_aggregate_share_invalid_batch_sel(version: DapVersion) { let req = t .leader_authorized_req_with_version( &t.fixed_size_task_id, + None, task_config.version, - MEDIA_TYPE_AGG_SHARE_REQ, + versioned_media_type_for(&task_config.version, MEDIA_TYPE_AGG_SHARE_REQ).unwrap(), AggregateShareReq { - task_id: t.fixed_size_task_id.clone(), + draft02_task_id: t.fixed_size_task_id.for_request_payload(&version), batch_sel: BatchSelector::FixedSizeByBatchId { - batch_id: Id(rng.gen()), // Unrecognized batch ID + batch_id: BatchId(rng.gen()), // Unrecognized batch ID }, agg_param: Vec::default(), report_count: 0, @@ -684,20 +820,36 @@ async fn http_post_aggregate_share_invalid_batch_sel(version: DapVersion) { async_test_versions! { http_post_aggregate_share_invalid_batch_sel } async fn http_post_collect_unauthorized_request(version: DapVersion) { + let mut rng = thread_rng(); let t = Test::new(version); let task_id = &t.time_interval_task_id; let task_config = t.leader.unchecked_get_task_config(task_id).await; + let collect_job_id = CollectionJobId(rng.gen()); + let url_path = if task_config.version == DapVersion::Draft02 { + "collect".to_string() + } else { + format!( + "tasks/{}/collection_jobs/{}", + task_id.to_base64url(), + collect_job_id.to_base64url() + ) + }; let mut req = DapRequest { version: task_config.version, - media_type: Some(MEDIA_TYPE_COLLECT_REQ), + media_type: versioned_media_type_for(&task_config.version, MEDIA_TYPE_COLLECT_REQ), task_id: Some(task_id.clone()), - payload: CollectReq { - task_id: task_id.clone(), + resource: if version == DapVersion::Draft02 { + DapResource::Undefined + } else { + DapResource::CollectionJob(collect_job_id) + }, + payload: CollectionReq { + draft02_task_id: task_id.for_request_payload(&version), query: Query::default(), agg_param: Vec::default(), } .get_encoded_with_param(&task_config.version), - url: task_config.leader_url.join("collect").unwrap(), + url: task_config.leader_url.join(&url_path).unwrap(), sender_auth: None, // Unauthorized request. }; @@ -722,24 +874,26 @@ async fn http_post_aggregate_failure_hpke_decrypt_error(version: DapVersion) { let task_id = &t.time_interval_task_id; let report = t.gen_test_report(task_id).await; - let (metadata, public_share, mut encrypted_input_share) = ( - report.metadata, + let (report_metadata, public_share, mut encrypted_input_share) = ( + report.report_metadata, report.public_share, report.encrypted_input_shares[1].clone(), ); encrypted_input_share.payload[0] ^= 0xff; // Cause decryption to fail let report_shares = vec![ReportShare { - metadata, + report_metadata, public_share, encrypted_input_share, }]; - let req = t.gen_test_agg_init_req(task_id, report_shares).await; + let req = t + .gen_test_agg_job_init_req(task_id, version, report_shares) + .await; - // Get AggregateResp and then extract the transition data from inside. 
- let agg_resp = - AggregateResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) + // Get AggregationJobResp and then extract the transition data from inside. + let agg_job_resp = + AggregationJobResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) .unwrap(); - let transition = &agg_resp.transitions[0]; + let transition = &agg_job_resp.transitions[0]; // Expect failure due to invalid ciphertext. assert_matches!( @@ -756,18 +910,20 @@ async fn http_post_aggregate_transition_continue(version: DapVersion) { let report = t.gen_test_report(task_id).await; let report_shares = vec![ReportShare { - metadata: report.metadata.clone(), + report_metadata: report.report_metadata.clone(), public_share: report.public_share, // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). encrypted_input_share: report.encrypted_input_shares[1].clone(), }]; - let req = t.gen_test_agg_init_req(task_id, report_shares).await; + let req = t + .gen_test_agg_job_init_req(task_id, version, report_shares) + .await; - // Get AggregateResp and then extract the transition data from inside. - let agg_resp = - AggregateResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) + // Get AggregationJobResp and then extract the transition data from inside. + let agg_job_resp = + AggregationJobResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) .unwrap(); - let transition = &agg_resp.transitions[0]; + let transition = &agg_job_resp.transitions[0]; // Expect success due to valid ciphertext. assert_matches!(transition.var, TransitionVar::Continued(_)); @@ -781,12 +937,14 @@ async fn http_post_aggregate_failure_report_replayed(version: DapVersion) { let report = t.gen_test_report(task_id).await; let report_shares = vec![ReportShare { - metadata: report.metadata.clone(), + report_metadata: report.report_metadata.clone(), public_share: report.public_share, // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). encrypted_input_share: report.encrypted_input_shares[1].clone(), }]; - let req = t.gen_test_agg_init_req(task_id, report_shares).await; + let req = t + .gen_test_agg_job_init_req(task_id, version, report_shares) + .await; // Add dummy data to report store backend. This is done in a new scope so that the lock on the // report store is released before running the test. @@ -797,14 +955,16 @@ async fn http_post_aggregate_failure_report_replayed(version: DapVersion) { .lock() .expect("report_store: failed to lock"); let report_store = guard.entry(task_id.clone()).or_default(); - report_store.processed.insert(report.metadata.id.clone()); + report_store + .processed + .insert(report.report_metadata.id.clone()); } - // Get AggregateResp and then extract the transition data from inside. - let agg_resp = - AggregateResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) + // Get AggregationJobResp and then extract the transition data from inside. + let agg_job_resp = + AggregationJobResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) .unwrap(); - let transition = &agg_resp.transitions[0]; + let transition = &agg_job_resp.transitions[0]; // Expect failure due to report store marked as collected. 
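The replay test above seeds report_store.processed with the report's ID before aggregating. A minimal sketch of the early-reject check that relies on (assumed shape): the Helper tracks already-processed report IDs and flags any repeat instead of aggregating it twice.

use std::collections::HashSet;

#[derive(Clone, Eq, Hash, PartialEq)]
struct ReportId([u8; 16]);

#[derive(Debug, PartialEq)]
enum TransitionFailure {
    ReportReplayed,
}

fn early_reject(processed: &mut HashSet<ReportId>, id: &ReportId) -> Option<TransitionFailure> {
    if !processed.insert(id.clone()) {
        // insert() returns false when the ID was already present.
        Some(TransitionFailure::ReportReplayed)
    } else {
        None
    }
}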
assert_matches!( @@ -827,12 +987,14 @@ async fn http_post_aggregate_failure_batch_collected(version: DapVersion) { let report = t.gen_test_report(task_id).await; let report_shares = vec![ReportShare { - metadata: report.metadata.clone(), + report_metadata: report.report_metadata.clone(), public_share: report.public_share, // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). encrypted_input_share: report.encrypted_input_shares[1].clone(), }]; - let req = t.gen_test_agg_init_req(task_id, report_shares).await; + let req = t + .gen_test_agg_job_init_req(task_id, version, report_shares) + .await; // Add mock data to the aggreagte store backend. This is done in its own scope so that the lock // is released before running the test. Otherwise the test will deadlock. @@ -846,7 +1008,7 @@ async fn http_post_aggregate_failure_batch_collected(version: DapVersion) { agg_store.insert( DapBatchBucketOwned::TimeInterval { - batch_window: task_config.truncate_time(t.now), + batch_window: task_config.quantized_time_lower_bound(t.now), }, AggStore { agg_share: DapAggregateShare::default(), @@ -855,11 +1017,11 @@ async fn http_post_aggregate_failure_batch_collected(version: DapVersion) { ); } - // Get AggregateResp and then extract the transition data from inside. - let agg_resp = - AggregateResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) + // Get AggregationJobResp and then extract the transition data from inside. + let agg_job_resp = + AggregationJobResp::get_decoded(&t.helper.http_post_aggregate(&req).await.unwrap().payload) .unwrap(); - let transition = &agg_resp.transitions[0]; + let transition = &agg_job_resp.transitions[0]; // Expect failure due to report store marked as collected. assert_matches!( @@ -881,12 +1043,14 @@ async fn http_post_aggregate_abort_helper_state_overwritten(version: DapVersion) let report = t.gen_test_report(task_id).await; let report_shares = vec![ReportShare { - metadata: report.metadata.clone(), + report_metadata: report.report_metadata.clone(), public_share: report.public_share, // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). encrypted_input_share: report.encrypted_input_shares[1].clone(), }]; - let req = t.gen_test_agg_init_req(task_id, report_shares).await; + let req = t + .gen_test_agg_job_init_req(task_id, version, report_shares) + .await; // Send aggregate request. let _ = t.helper.http_post_aggregate(&req).await; @@ -904,8 +1068,10 @@ async_test_versions! { http_post_aggregate_abort_helper_state_overwritten } async fn http_post_aggregate_fail_send_cont_req(version: DapVersion) { let t = Test::new(version); - let mut rng = thread_rng(); - let req = t.gen_test_agg_cont_req(Id(rng.gen()), Vec::default()).await; + let agg_job_id = MetaAggregationJobId::gen_for_version(&version); + let req = t + .gen_test_agg_job_cont_req(&agg_job_id, Vec::default(), version) + .await; // Send aggregate continue request to helper. let err = t.helper.http_post_aggregate(&req).await.unwrap_err(); @@ -922,12 +1088,12 @@ async fn http_post_upload_fail_send_invalid_report(version: DapVersion) { let task_config = t.leader.unchecked_get_task_config(task_id).await; // Construct a report payload with an invalid task ID. 
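quantized_time_lower_bound replaces the old truncate_time throughout (see the batch_window computation in the batch-collected test above), and the collection-interval code pairs it with quantized_time_upper_bound. A sketch of the assumed semantics: the lower bound matches the `t.now - (t.now % time_precision)` expressions it replaces elsewhere in this diff, while the upper bound shown here is an assumption.

struct TaskConfig {
    time_precision: u64,
}

impl TaskConfig {
    // Truncate a timestamp down to the start of its time_precision window.
    fn quantized_time_lower_bound(&self, time: u64) -> u64 {
        time - (time % self.time_precision)
    }

    // Round a timestamp up to the end of its window (assumed counterpart).
    fn quantized_time_upper_bound(&self, time: u64) -> u64 {
        self.quantized_time_lower_bound(time) + self.time_precision
    }
}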
- let mut report_invalid_task_id = t.gen_test_report(task_id).await; - report_invalid_task_id.task_id = Id([0; 32]); + let report_invalid_task_id = t.gen_test_report(task_id).await; let req = DapRequest { version: task_config.version, media_type: Some(MEDIA_TYPE_REPORT), - task_id: Some(report_invalid_task_id.task_id.clone()), + task_id: Some(TaskId([0; 32])), + resource: DapResource::Undefined, payload: report_invalid_task_id.get_encoded_with_param(&task_config.version), url: task_config.leader_url.join("upload").unwrap(), sender_auth: None, @@ -943,7 +1109,7 @@ async fn http_post_upload_fail_send_invalid_report(version: DapVersion) { let mut report_one_input_share = t.gen_test_report(task_id).await; report_one_input_share.encrypted_input_shares = vec![report_one_input_share.encrypted_input_shares[0].clone()]; - let req = t.gen_test_upload_req(report_one_input_share).await; + let req = t.gen_test_upload_req(report_one_input_share, task_id).await; // Expect failure due to incorrect number of input shares assert_matches!( @@ -965,6 +1131,7 @@ async fn http_post_upload_task_expired(version: DapVersion) { version: task_config.version, media_type: Some(MEDIA_TYPE_REPORT), task_id: Some(task_id.clone()), + resource: DapResource::Undefined, payload: report.get_encoded_with_param(&version), url: task_config.leader_url.join("upload").unwrap(), sender_auth: None, @@ -983,7 +1150,7 @@ async fn get_reports_empty_response(version: DapVersion) { let task_id = &t.time_interval_task_id; let report = t.gen_test_report(task_id).await; - let req = t.gen_test_upload_req(report.clone()).await; + let req = t.gen_test_upload_req(report.clone(), task_id).await; // Upload report. t.leader @@ -1021,8 +1188,8 @@ async fn poll_collect_job_test_results(version: DapVersion) { version, MEDIA_TYPE_COLLECT_REQ, task_id, - CollectReq { - task_id: task_id.clone(), + CollectionReq { + draft02_task_id: task_id.for_request_payload(&version), query: task_config.query_for_current_batch_window(t.now), agg_param: Vec::default(), }, @@ -1036,7 +1203,7 @@ async fn poll_collect_job_test_results(version: DapVersion) { // Expect DapCollectJob::Unknown due to invalid collect ID. assert_eq!( t.leader - .poll_collect_job(task_id, &Id::default()) + .poll_collect_job(task_id, &CollectionJobId::default()) .await .unwrap(), DapCollectJob::Unknown @@ -1044,10 +1211,18 @@ async fn poll_collect_job_test_results(version: DapVersion) { // Leader: Get pending collect job to obtain collect_id let resp = t.leader.get_pending_collect_jobs().await.unwrap(); - let (collect_id, _collect_req) = &resp[0]; - let collect_resp = CollectResp { + let (_task_id, collect_id, _collect_req) = &resp[0]; + let collect_resp = Collection { part_batch_sel: PartialBatchSelector::TimeInterval, report_count: 0, + interval: if version == DapVersion::Draft02 { + None + } else { + Some(Interval { + start: 0, + duration: 2000000000, + }) + }, encrypted_agg_shares: Vec::default(), }; @@ -1086,14 +1261,14 @@ async fn http_post_collect_fail_invalid_batch_interval(version: DapVersion) { // Collector: Create a CollectReq with a very large batch interval. 
let req = t .collector_authorized_req( - task_config.version, + version, MEDIA_TYPE_COLLECT_REQ, task_id, - CollectReq { - task_id: task_id.clone(), + CollectionReq { + draft02_task_id: task_id.for_request_payload(&version), query: Query::TimeInterval { batch_interval: Interval { - start: t.now - (t.now % task_config.time_precision), + start: task_config.quantized_time_lower_bound(t.now), duration: t.leader.global_config.max_batch_duration + task_config.time_precision, }, @@ -1113,15 +1288,14 @@ async fn http_post_collect_fail_invalid_batch_interval(version: DapVersion) { // Collector: Create a CollectReq with a batch interval in the past. let req = t .collector_authorized_req( - task_config.version, + version, MEDIA_TYPE_COLLECT_REQ, task_id, - CollectReq { - task_id: task_id.clone(), + CollectionReq { + draft02_task_id: task_id.for_request_payload(&version), query: Query::TimeInterval { batch_interval: Interval { - start: t.now - - (t.now % task_config.time_precision) + start: task_config.quantized_time_lower_bound(t.now) - t.leader.global_config.min_batch_interval_start - task_config.time_precision, duration: task_config.time_precision * 2, @@ -1142,14 +1316,14 @@ async fn http_post_collect_fail_invalid_batch_interval(version: DapVersion) { // Collector: Create a CollectReq with a batch interval in the future. let req = t .collector_authorized_req( - task_config.version, + version, MEDIA_TYPE_COLLECT_REQ, task_id, - CollectReq { - task_id: task_id.clone(), + CollectionReq { + draft02_task_id: task_id.for_request_payload(&version), query: Query::TimeInterval { batch_interval: Interval { - start: t.now - (t.now % task_config.time_precision) + start: task_config.quantized_time_lower_bound(t.now) + t.leader.global_config.max_batch_interval_end - task_config.time_precision, duration: task_config.time_precision * 2, @@ -1178,15 +1352,14 @@ async fn http_post_collect_succeed_max_batch_interval(version: DapVersion) { // Collector: Create a CollectReq with a very large batch interval. let req = t .collector_authorized_req( - task_config.version, + version, MEDIA_TYPE_COLLECT_REQ, task_id, - CollectReq { - task_id: task_id.clone(), + CollectionReq { + draft02_task_id: task_id.for_request_payload(&version), query: Query::TimeInterval { batch_interval: Interval { - start: t.now - - (t.now % task_config.time_precision) + start: task_config.quantized_time_lower_bound(t.now) - t.leader.global_config.max_batch_duration / 2, duration: t.leader.global_config.max_batch_duration, }, @@ -1211,7 +1384,7 @@ async fn http_post_collect_fail_overlapping_batch_interval(version: DapVersion) // Create a report. let report = t.gen_test_report(task_id).await; - let req = t.gen_test_upload_req(report.clone()).await; + let req = t.gen_test_upload_req(report.clone(), task_id).await; // Client: Send upload request to Leader. t.leader.http_post_upload(&req).await.unwrap(); @@ -1240,14 +1413,14 @@ async fn http_post_collect_success(version: DapVersion) { let task_config = t.leader.unchecked_get_task_config(task_id).await; // Collector: Create a CollectReq. 
- let collector_collect_req = CollectReq { - task_id: task_id.clone(), + let collector_collect_req = CollectionReq { + draft02_task_id: task_id.for_request_payload(&version), query: task_config.query_for_current_batch_window(t.now), agg_param: Vec::default(), }; let req = t .collector_authorized_req( - task_config.version, + version, MEDIA_TYPE_COLLECT_REQ, task_id, collector_collect_req.clone(), @@ -1258,7 +1431,7 @@ async fn http_post_collect_success(version: DapVersion) { // Leader: Handle the CollectReq received from Collector. let url = t.leader.http_post_collect(&req).await.unwrap(); let resp = t.leader.get_pending_collect_jobs().await.unwrap(); - let (leader_collect_id, leader_collect_req) = &resp[0]; + let (_leader_task_id, leader_collect_id, leader_collect_req) = &resp[0]; // Check that the CollectReq sent by Collector is the same that is received by Leader. assert_eq!(&collector_collect_req, leader_collect_req); @@ -1295,10 +1468,10 @@ async fn http_post_collect_invalid_query(version: DapVersion) { task_config.version, MEDIA_TYPE_COLLECT_REQ, &t.time_interval_task_id, - CollectReq { - task_id: t.time_interval_task_id.clone(), + CollectionReq { + draft02_task_id: t.time_interval_task_id.for_request_payload(&version), query: Query::FixedSizeByBatchId { - batch_id: Id(rng.gen()), + batch_id: BatchId(rng.gen()), }, agg_param: Vec::default(), }, @@ -1320,10 +1493,10 @@ async fn http_post_collect_invalid_query(version: DapVersion) { task_config.version, MEDIA_TYPE_COLLECT_REQ, &t.fixed_size_task_id, - CollectReq { - task_id: t.fixed_size_task_id.clone(), + CollectionReq { + draft02_task_id: t.fixed_size_task_id.for_request_payload(&version), query: Query::FixedSizeByBatchId { - batch_id: Id(rng.gen()), // Unrecognized batch ID + batch_id: BatchId(rng.gen()), // Unrecognized batch ID }, agg_param: Vec::default(), }, @@ -1346,7 +1519,7 @@ async fn http_post_fail_wrong_dap_version(version: DapVersion) { // Send a request with the wrong DAP version. let report = t.gen_test_report(task_id).await; - let mut req = t.gen_test_upload_req(report).await; + let mut req = t.gen_test_upload_req(report, task_id).await; req.version = DapVersion::Unknown; req.url = task_config.leader_url.join("upload").unwrap(); @@ -1361,7 +1534,7 @@ async fn http_post_upload(version: DapVersion) { let task_id = &t.time_interval_task_id; let report = t.gen_test_report(task_id).await; - let req = t.gen_test_upload_req(report).await; + let req = t.gen_test_upload_req(report, task_id).await; t.leader .http_post_upload(&req) @@ -1377,7 +1550,7 @@ async fn e2e_time_interval(version: DapVersion) { let task_config = t.leader.unchecked_get_task_config(task_id).await; let report = t.gen_test_report(task_id).await; - let req = t.gen_test_upload_req(report).await; + let req = t.gen_test_upload_req(report, task_id).await; // Client: Send upload request to Leader. t.leader.http_post_upload(&req).await.unwrap(); @@ -1406,7 +1579,7 @@ async fn e2e_fixed_size(version: DapVersion) { let task_config = t.leader.unchecked_get_task_config(task_id).await; let report = t.gen_test_report(task_id).await; - let req = t.gen_test_upload_req(report).await; + let req = t.gen_test_upload_req(report, task_id).await; // Client: Send upload request to Leader. 
t.leader.http_post_upload(&req).await.unwrap(); @@ -1493,11 +1666,11 @@ async fn e2e_taskprov(version: DapVersion) { ) .unwrap(); - let task_id = &report.task_id; let req = DapRequest { version, media_type: Some(MEDIA_TYPE_REPORT), - task_id: Some(task_id.clone()), + task_id: Some(taskprov_id.clone()), + resource: DapResource::Undefined, payload: report.get_encoded_with_param(&version), url: Url::parse("https://cool.biz/upload").unwrap(), sender_auth: None, @@ -1505,16 +1678,19 @@ async fn e2e_taskprov(version: DapVersion) { t.leader.http_post_upload(&req).await.unwrap(); // Leader: Run aggregation job. - t.run_agg_job(task_id).await.unwrap(); + t.run_agg_job(&taskprov_id).await.unwrap(); // The Leader is now configured with the task. - let task_config = t.leader.unchecked_get_task_config(task_id).await; + let task_config = t.leader.unchecked_get_task_config(&taskprov_id).await; // Collector: Create collection job and poll result. let query = Query::FixedSizeByBatchId { - batch_id: t.leader.current_batch_id(task_id, &task_config).unwrap(), + batch_id: t + .leader + .current_batch_id(&taskprov_id, &task_config) + .unwrap(), }; - t.run_col_job(task_id, &query).await.unwrap(); + t.run_col_job(&taskprov_id, &query).await.unwrap(); assert_metrics_include!(t.prometheus_registry, { r#"test_leader_report_counter{status="aggregated"}"#: 1, diff --git a/daphne/src/taskprov.rs b/daphne/src/taskprov.rs index ef36db114..3cd2a0e9c 100644 --- a/daphne/src/taskprov.rs +++ b/daphne/src/taskprov.rs @@ -4,7 +4,7 @@ use crate::{ messages::{ taskprov::{QueryConfigVar, TaskConfig, VdafType, VdafTypeVar}, - Extension, HpkeConfig, Id, ReportMetadata, + Extension, HpkeConfig, ReportMetadata, TaskId, }, vdaf::VdafVerifyKey, DapAbort, DapError, DapQueryConfig, DapTaskConfig, DapVersion, Prio3Config, VdafConfig, @@ -35,16 +35,16 @@ pub(crate) const TASK_PROV_SALT_DRAFT02: [u8; 32] = [ 0x74, 0x01, 0x7a, 0x52, 0xcb, 0x4c, 0xf6, 0x39, 0xfb, 0x83, 0xe0, 0x47, 0x72, 0x3a, 0x0f, 0xfe, ]; -fn compute_task_id_draft02(serialized: &[u8]) -> Id { +fn compute_task_id_draft02(serialized: &[u8]) -> TaskId { let d = digest::digest(&digest::SHA256, serialized); let dref = d.as_ref(); let mut b: [u8; 32] = [0; 32]; b[..32].copy_from_slice(&dref[..32]); - Id(b) + TaskId(b) } /// Compute the task id of a serialized task config. -pub fn compute_task_id(version: TaskprovVersion, serialized: &[u8]) -> Result { +pub fn compute_task_id(version: TaskprovVersion, serialized: &[u8]) -> Result { match version { TaskprovVersion::Draft02 => Ok(compute_task_id_draft02(serialized)), TaskprovVersion::Unknown => Err(DapError::fatal( @@ -74,7 +74,7 @@ pub(crate) fn extract_prk_from_verify_key_init( /// Expand a pseudorandom key into the VDAF verification key for a given task. pub(crate) fn expand_prk_into_verify_key( prk: &Prk, - task_id: &Id, + task_id: &TaskId, vdaf_type: VdafType, ) -> VdafVerifyKey { let info = [task_id.as_ref()]; @@ -102,7 +102,7 @@ pub(crate) fn expand_prk_into_verify_key( pub(crate) fn compute_vdaf_verify_key( version: TaskprovVersion, verify_key_init: &[u8; 32], - task_id: &Id, + task_id: &TaskId, vdaf_type: VdafType, ) -> VdafVerifyKey { expand_prk_into_verify_key( @@ -119,7 +119,7 @@ pub fn bad_request(detail: &str) -> DapError { /// Check for a taskprov extension in the report, and return it if found. 
pub fn get_taskprov_task_config( version: TaskprovVersion, - task_id: &Id, + task_id: &TaskId, metadata: &ReportMetadata, ) -> Result, DapError> { let taskprovs: Vec<&Extension> = metadata @@ -186,7 +186,7 @@ impl DapTaskConfig { pub fn try_from_taskprov( dap_version: DapVersion, taskprov_version: TaskprovVersion, - task_id: &Id, + task_id: &TaskId, task_config: TaskConfig, vdaf_verify_key_init: &[u8; 32], collector_hpke_config: &HpkeConfig, @@ -217,7 +217,7 @@ impl DapTaskConfig { impl ReportMetadata { /// Does this metatdata have a taskprov extension and does it match the specified id? - pub fn is_taskprov(&self, version: TaskprovVersion, task_id: &Id) -> bool { + pub fn is_taskprov(&self, version: TaskprovVersion, task_id: &TaskId) -> bool { // Don't check for taskprov usage if we don't know the version. if matches!(version, TaskprovVersion::Unknown) { return false; diff --git a/daphne/src/taskprov_test.rs b/daphne/src/taskprov_test.rs index 3396d9e71..fc41c8933 100644 --- a/daphne/src/taskprov_test.rs +++ b/daphne/src/taskprov_test.rs @@ -3,14 +3,14 @@ use crate::{ messages::taskprov::VdafType, - messages::Id, + messages::TaskId, taskprov::{compute_vdaf_verify_key, TaskprovVersion}, vdaf::VdafVerifyKey, }; #[test] fn check_vdaf_key_computation() { - let task_id = Id([ + let task_id = TaskId([ 0xb4, 0x76, 0x9b, 0xb0, 0x63, 0xa8, 0xb3, 0x31, 0x2a, 0xf7, 0x42, 0x97, 0xf3, 0x0f, 0xdb, 0xf8, 0xe0, 0xb7, 0x1c, 0x2e, 0xb2, 0x48, 0x1f, 0x59, 0x1d, 0x1d, 0x7d, 0xe6, 0x6a, 0x4c, 0xe3, 0x4f, diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs index 4245d7659..861795cdc 100644 --- a/daphne/src/testing.rs +++ b/daphne/src/testing.rs @@ -8,14 +8,15 @@ use crate::{ constants, hpke::{HpkeDecrypter, HpkeReceiverConfig}, messages::{ - BatchSelector, CollectReq, CollectResp, HpkeCiphertext, HpkeConfig, Id, - PartialBatchSelector, Report, ReportId, ReportMetadata, Time, TransitionFailure, + AggregationJobId, BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq, + Draft02AggregationJobId, HpkeCiphertext, HpkeConfig, PartialBatchSelector, Report, + ReportId, ReportMetadata, TaskId, Time, TransitionFailure, }, metrics::DaphneMetrics, roles::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader}, taskprov, DapAbort, DapAggregateShare, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig, DapHelperState, DapOutputShare, DapQueryConfig, DapRequest, DapResponse, - DapTaskConfig, DapVersion, + DapTaskConfig, DapVersion, MetaAggregationJobId, }; use assert_matches::assert_matches; use async_trait::async_trait; @@ -31,9 +32,28 @@ use std::{ }; use url::Url; +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub(crate) enum MetaAggregationJobIdOwned { + Draft02(Draft02AggregationJobId), + Draft04(AggregationJobId), +} + +impl From<&MetaAggregationJobId<'_>> for MetaAggregationJobIdOwned { + fn from(agg_job_id: &MetaAggregationJobId<'_>) -> Self { + match agg_job_id { + MetaAggregationJobId::Draft02(agg_job_id) => { + Self::Draft02(agg_job_id.clone().into_owned()) + } + MetaAggregationJobId::Draft04(agg_job_id) => { + Self::Draft04(agg_job_id.clone().into_owned()) + } + } + } +} + #[derive(Eq, Hash, PartialEq)] pub(crate) enum DapBatchBucketOwned { - FixedSize { batch_id: Id }, + FixedSize { batch_id: BatchId }, TimeInterval { batch_window: Time }, } @@ -62,18 +82,18 @@ impl<'a> DapBatchBucket<'a> { } } -pub(crate) struct MockAggregatorReportSelector(pub(crate) Id); +pub(crate) struct MockAggregatorReportSelector(pub(crate) TaskId); pub(crate) struct MockAggregator { 
pub(crate) global_config: DapGlobalConfig, - pub(crate) tasks: Arc>>, + pub(crate) tasks: Arc>>, pub(crate) hpke_receiver_config_list: Vec, pub(crate) leader_token: BearerToken, pub(crate) collector_token: Option, // Not set by Helper - pub(crate) report_store: Arc>>, - pub(crate) leader_state_store: Arc>>, + pub(crate) report_store: Arc>>, + pub(crate) leader_state_store: Arc>>, pub(crate) helper_state_store: Arc>>, - pub(crate) agg_store: Arc>>>, + pub(crate) agg_store: Arc>>>, pub(crate) collector_hpke_config: HpkeConfig, pub(crate) taskprov_vdaf_verify_key_init: [u8; 32], pub(crate) metrics: DaphneMetrics, @@ -89,7 +109,7 @@ impl MockAggregator { /// 2) the report has been submitted by the client in the past. async fn check_report_early_fail( &self, - task_id: &Id, + task_id: &TaskId, bucket: &DapBatchBucketOwned, metadata: &ReportMetadata, ) -> Option { @@ -123,10 +143,14 @@ impl MockAggregator { /// Assign the report to a bucket. /// /// TODO(cjpatton) Figure out if we can avoid returning and owned thing here. - async fn assign_report_to_bucket(&self, report: &Report) -> Option { + async fn assign_report_to_bucket( + &self, + report: &Report, + task_id: &TaskId, + ) -> Option { let mut rng = thread_rng(); let task_config = self - .get_task_config_for(Cow::Borrowed(&report.task_id)) + .get_task_config_for(Cow::Borrowed(task_id)) .await .unwrap() .expect("tasks: unrecognized task"); @@ -138,7 +162,7 @@ impl MockAggregator { .leader_state_store .lock() .expect("leader_state_store: failed to lock"); - let leader_state_store = guard.entry(report.task_id.clone()).or_default(); + let leader_state_store = guard.entry(task_id.clone()).or_default(); // Assign the report to the first unsaturated batch. for (batch_id, report_count) in leader_state_store.batch_queue.iter_mut() { @@ -151,7 +175,7 @@ impl MockAggregator { } // No unsaturated batch exists, so create a new batch. - let batch_id = Id(rng.gen()); + let batch_id = BatchId(rng.gen()); leader_state_store .batch_queue .push_back((batch_id.clone(), 1)); @@ -161,14 +185,18 @@ impl MockAggregator { // For time-interval queries, the bucket is the batch window computed by truncating the // report timestamp. DapQueryConfig::TimeInterval => Some(DapBatchBucketOwned::TimeInterval { - batch_window: task_config.truncate_time(report.metadata.time), + batch_window: task_config.quantized_time_lower_bound(report.report_metadata.time), }), } } /// Return the ID of the batch currently being filled with reports. Panics unless the task is /// configured for fixed-size queries. - pub(crate) fn current_batch_id(&self, task_id: &Id, task_config: &DapTaskConfig) -> Option { + pub(crate) fn current_batch_id( + &self, + task_id: &TaskId, + task_config: &DapTaskConfig, + ) -> Option { // Calling current_batch() is only well-defined for fixed-size tasks. assert_matches!(task_config.query, DapQueryConfig::FixedSize { .. 
}); @@ -187,7 +215,7 @@ impl MockAggregator { .map(|(batch_id, _report_count)| batch_id) } - pub(crate) async fn unchecked_get_task_config(&self, task_id: &Id) -> DapTaskConfig { + pub(crate) async fn unchecked_get_task_config(&self, task_id: &TaskId) -> DapTaskConfig { self.get_task_config_for(Cow::Borrowed(task_id)) .await .expect("encountered unexpected error") @@ -201,14 +229,14 @@ impl<'a> BearerTokenProvider<'a> for MockAggregator { async fn get_leader_bearer_token_for( &'a self, - _task_id: &'a Id, + _task_id: &'a TaskId, ) -> Result, DapError> { Ok(Some(&self.leader_token)) } async fn get_collector_bearer_token_for( &'a self, - _task_id: &'a Id, + _task_id: &'a TaskId, ) -> Result, DapError> { if let Some(ref collector_token) = self.collector_token { Ok(Some(collector_token)) @@ -241,7 +269,7 @@ impl<'a> HpkeDecrypter<'a> for MockAggregator { async fn get_hpke_config_for( &'a self, _version: DapVersion, - task_id: Option<&Id>, + task_id: Option<&TaskId>, ) -> Result<&'a HpkeConfig, DapError> { if self.hpke_receiver_config_list.is_empty() { return Err(DapError::fatal("emtpy HPKE receiver config list")); @@ -260,13 +288,13 @@ impl<'a> HpkeDecrypter<'a> for MockAggregator { Ok(&self.hpke_receiver_config_list[0].config) } - async fn can_hpke_decrypt(&self, _task_id: &Id, config_id: u8) -> Result { + async fn can_hpke_decrypt(&self, _task_id: &TaskId, config_id: u8) -> Result { Ok(self.get_hpke_receiver_config_for(config_id).is_some()) } async fn hpke_decrypt( &self, - _task_id: &Id, + _task_id: &TaskId, info: &[u8], aad: &[u8], ciphertext: &HpkeCiphertext, @@ -284,7 +312,7 @@ impl<'a> HpkeDecrypter<'a> for MockAggregator { impl DapAuthorizedSender for MockAggregator { async fn authorize( &self, - task_id: &Id, + task_id: &TaskId, media_type: &'static str, _payload: &[u8], ) -> Result { @@ -320,7 +348,7 @@ where async fn get_task_config_considering_taskprov( &'srv self, version: DapVersion, - task_id: Cow<'req, Id>, + task_id: Cow<'req, TaskId>, metadata: Option<&ReportMetadata>, ) -> Result, DapError> { let taskprov_version = self.global_config.taskprov_version; @@ -374,7 +402,7 @@ where async fn is_batch_overlapping( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result { let task_config = self @@ -400,7 +428,7 @@ where Ok(false) } - async fn batch_exists(&self, task_id: &Id, batch_id: &Id) -> Result { + async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result { let guard = self.agg_store.lock().expect("agg_store: failed to lock"); if let Some(agg_store) = guard.get(task_id) { Ok(agg_store @@ -415,7 +443,7 @@ where async fn put_out_shares( &self, - task_id: &Id, + task_id: &TaskId, part_batch_sel: &PartialBatchSelector, out_shares: Vec, ) -> Result<(), DapError> { @@ -439,7 +467,7 @@ where async fn get_agg_share( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result { let task_config = self @@ -467,7 +495,7 @@ where async fn check_early_reject<'b>( &self, - task_id: &Id, + task_id: &TaskId, part_batch_sel: &'b PartialBatchSelector, report_meta: impl Iterator, ) -> Result, DapError> { @@ -503,7 +531,7 @@ where async fn mark_collected( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result<(), DapError> { let task_config = self.unchecked_get_task_config(task_id).await; @@ -519,7 +547,7 @@ where Ok(()) } - async fn current_batch(&self, task_id: &Id) -> std::result::Result { + async fn current_batch(&self, task_id: &TaskId) -> std::result::Result { let task_config = 
self.unchecked_get_task_config(task_id).await; if let Some(id) = self.current_batch_id(task_id, &task_config) { Ok(id) @@ -542,13 +570,13 @@ where { async fn put_helper_state( &self, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, helper_state: &DapHelperState, ) -> Result<(), DapError> { let helper_state_info = HelperStateInfo { task_id: task_id.clone(), - agg_job_id: agg_job_id.clone(), + agg_job_id_owned: agg_job_id.into(), }; let mut helper_state_store_mutex_guard = self @@ -573,12 +601,12 @@ where async fn get_helper_state( &self, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, ) -> Result, DapError> { let helper_state_info = HelperStateInfo { task_id: task_id.clone(), - agg_job_id: agg_job_id.clone(), + agg_job_id_owned: agg_job_id.into(), }; let mut helper_state_store_mutex_guard = self @@ -607,15 +635,15 @@ where { type ReportSelector = MockAggregatorReportSelector; - async fn put_report(&self, report: &Report) -> Result<(), DapError> { + async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> { let bucket = self - .assign_report_to_bucket(report) + .assign_report_to_bucket(report, task_id) .await .expect("could not determine batch for report"); // Check whether Report has been collected or replayed. if let Some(transition_failure) = self - .check_report_early_fail(&report.task_id, bucket.borrow(), &report.metadata) + .check_report_early_fail(task_id, bucket.borrow(), &report.report_metadata) .await { return Err(DapError::Transition(transition_failure)); @@ -627,7 +655,7 @@ where .lock() .expect("report_store: failed to lock"); let queue = guard - .get_mut(&report.task_id) + .get_mut(task_id) .expect("report_store: unrecognized task") .pending .entry(bucket) @@ -639,7 +667,7 @@ where async fn get_reports( &self, report_sel: &MockAggregatorReportSelector, - ) -> Result>>, DapError> { + ) -> Result>>, DapError> { let task_id = &report_sel.0; let task_config = self.unchecked_get_task_config(task_id).await; let mut guard = self @@ -687,10 +715,15 @@ where } // Called after receiving a CollectReq from Collector. - async fn init_collect_job(&self, collect_req: &CollectReq) -> Result { + async fn init_collect_job( + &self, + task_id: &TaskId, + collect_job_id: &Option, + collect_req: &CollectionReq, + ) -> Result { let mut rng = thread_rng(); let task_config = self - .get_task_config_for(Cow::Borrowed(&collect_req.task_id)) + .get_task_config_for(Cow::Borrowed(task_id)) .await? .ok_or_else(|| DapError::fatal("task not found"))?; @@ -701,20 +734,20 @@ where let leader_state_store = leader_state_store_mutex_guard.deref_mut(); // Construct a new Collect URI for this CollectReq. - let collect_id = Id(rng.gen()); + let collect_id = collect_job_id + .as_ref() + .map_or_else(|| CollectionJobId(rng.gen()), |cid| cid.clone()); let collect_uri = task_config .leader_url .join(&format!( "collect/task/{}/req/{}", - collect_req.task_id.to_base64url(), + task_id.to_base64url(), collect_id.to_base64url(), )) .map_err(|e| DapError::Fatal(e.to_string()))?; // Store Collect ID and CollectReq into LeaderState. 
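// Illustrative sketch, not part of the patch: the draft02 vs. draft04 split that
// init_collect_job now handles when choosing the collection job ID. The helper name is
// hypothetical; CollectionJobId and the generate-or-reuse behavior are the ones
// introduced in this diff.
fn choose_collection_job_id<R: rand::Rng>(
    requested: &Option<CollectionJobId>,
    rng: &mut R,
) -> CollectionJobId {
    match requested {
        // draft04: the Collector names the job in the request path, so reuse its ID.
        Some(cid) => cid.clone(),
        // draft02: the request carries no ID, so the Leader generates a fresh one.
        None => CollectionJobId(rng.gen()),
    }
}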
- let leader_state = leader_state_store - .entry(collect_req.task_id.clone()) - .or_default(); + let leader_state = leader_state_store.entry(task_id.clone()).or_default(); leader_state.collect_ids.push_back(collect_id.clone()); let collect_job_state = CollectJobState::Pending(collect_req.clone()); leader_state @@ -727,8 +760,8 @@ where // Called to retrieve completed CollectResp at the request of Collector. async fn poll_collect_job( &self, - task_id: &Id, - collect_id: &Id, + task_id: &TaskId, + collect_id: &CollectionJobId, ) -> Result { let mut leader_state_store_mutex_guard = self .leader_state_store @@ -750,7 +783,9 @@ where } // Called to retrieve pending CollectReq. - async fn get_pending_collect_jobs(&self) -> Result, DapError> { + async fn get_pending_collect_jobs( + &self, + ) -> Result, DapError> { let mut leader_state_store_mutex_guard = self .leader_state_store .lock() @@ -758,13 +793,13 @@ where let leader_state_store = leader_state_store_mutex_guard.deref_mut(); let mut res = Vec::new(); - for (_task_id, leader_state) in leader_state_store.iter() { + for (task_id, leader_state) in leader_state_store.iter() { // Iterate over collect IDs and copy them and their associated requests to the response. for collect_id in leader_state.collect_ids.iter() { if let CollectJobState::Pending(collect_req) = leader_state.collect_jobs.get(collect_id).unwrap() { - res.push((collect_id.clone(), collect_req.clone())); + res.push((task_id.clone(), collect_id.clone(), collect_req.clone())); } } } @@ -773,9 +808,9 @@ where async fn finish_collect_job( &self, - task_id: &Id, - collect_id: &Id, - collect_resp: &CollectResp, + task_id: &TaskId, + collect_id: &CollectionJobId, + collect_resp: &Collection, ) -> Result<(), DapError> { let mut leader_state_store_mutex_guard = self .leader_state_store @@ -826,7 +861,9 @@ where .media_type .expect("tried to send request without media type") { - constants::MEDIA_TYPE_AGG_INIT_REQ | constants::MEDIA_TYPE_AGG_CONT_REQ => Ok(self + constants::MEDIA_TYPE_AGG_INIT_REQ + | constants::DRAFT02_MEDIA_TYPE_AGG_INIT_REQ + | constants::MEDIA_TYPE_AGG_CONT_REQ => Ok(self .peer .as_ref() .expect("peer not configured") @@ -843,13 +880,31 @@ where s => unreachable!("unhandled media type: {}", s), } } + + async fn send_http_put(&self, req: DapRequest) -> Result { + match req + .media_type + .expect("tried to send request without media type") + { + constants::MEDIA_TYPE_AGG_INIT_REQ | constants::DRAFT02_MEDIA_TYPE_AGG_INIT_REQ => { + Ok(self + .peer + .as_ref() + .expect("peer not configured") + .http_post_aggregate(&req) + .await + .expect("peer aborted unexpectedly")) + } + s => unreachable!("unhandled media type: {}", s), + } + } } /// Information associated to a certain helper state for a given task ID and aggregate job ID. #[derive(Clone, Eq, Hash, PartialEq, Deserialize, Serialize)] pub(crate) struct HelperStateInfo { - task_id: Id, - agg_job_id: Id, + task_id: TaskId, + agg_job_id_owned: MetaAggregationJobIdOwned, } /// Stores the reports received from Clients. @@ -861,8 +916,8 @@ pub(crate) struct ReportStore { /// Stores the state of the collect job. pub(crate) enum CollectJobState { - Pending(CollectReq), - Processed(CollectResp), + Pending(CollectionReq), + Processed(Collection), } /// LeaderState keeps track of the following: @@ -870,9 +925,9 @@ pub(crate) enum CollectJobState { /// * The state of the collect job associated to the Collect ID. 
#[derive(Default)] pub(crate) struct LeaderState { - collect_ids: VecDeque, - collect_jobs: HashMap, - batch_queue: VecDeque<(Id, u64)>, // Batch ID, batch size + collect_ids: VecDeque, + collect_jobs: HashMap, + batch_queue: VecDeque<(BatchId, u64)>, // Batch ID, batch size } /// AggStore keeps track of the following: @@ -897,7 +952,7 @@ pub(crate) struct AggStore { // // and // -// something_draft03 +// something_draft04 // // that called something(version) with the appropriate version. // @@ -921,7 +976,7 @@ macro_rules! test_versions { ($($fname:ident),*) => { $( test_version! { $fname, Draft02 } - test_version! { $fname, Draft03 } + test_version! { $fname, Draft04 } )* }; } @@ -943,7 +998,7 @@ macro_rules! async_test_versions { ($($fname:ident),*) => { $( async_test_version! { $fname, Draft02 } - async_test_version! { $fname, Draft03 } + async_test_version! { $fname, Draft04 } )* }; } diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs index 2237c0cad..b594f549a 100644 --- a/daphne/src/vdaf/mod.rs +++ b/daphne/src/vdaf/mod.rs @@ -7,10 +7,10 @@ use crate::{ hpke::HpkeDecrypter, messages::{ - encode_u32_bytes, AggregateContinueReq, AggregateInitializeReq, AggregateResp, - BatchSelector, Extension, HpkeCiphertext, HpkeConfig, Id, PartialBatchSelector, - PlaintextInputShare, Report, ReportId, ReportMetadata, ReportShare, Time, Transition, - TransitionFailure, TransitionVar, + encode_u32_bytes, AggregationJobContinueReq, AggregationJobInitReq, AggregationJobResp, + BatchSelector, Extension, HpkeCiphertext, HpkeConfig, PartialBatchSelector, + PlaintextInputShare, Report, ReportId, ReportMetadata, ReportShare, TaskId, Time, + Transition, TransitionFailure, TransitionVar, }, metrics::DaphneMetrics, vdaf::{ @@ -25,7 +25,7 @@ use crate::{ }, DapAbort, DapAggregateResult, DapAggregateShare, DapError, DapHelperState, DapHelperTransition, DapLeaderState, DapLeaderTransition, DapLeaderUncommitted, DapMeasurement, DapOutputShare, - DapTaskConfig, DapVersion, VdafConfig, + DapTaskConfig, DapVersion, MetaAggregationJobId, VdafConfig, }; use prio::{ codec::{CodecError, Decode, Encode, ParameterizedEncode}, @@ -40,9 +40,9 @@ use serde::{Deserialize, Serialize}; use std::{collections::HashSet, convert::TryInto}; const CTX_INPUT_SHARE_DRAFT02: &[u8] = b"dap-02 input share"; -const CTX_INPUT_SHARE_DRAFT03: &[u8] = b"dap-03 input share"; +const CTX_INPUT_SHARE_DRAFT04: &[u8] = b"dap-04 input share"; const CTX_AGG_SHARE_DRAFT02: &[u8] = b"dap-02 aggregate share"; -const CTX_AGG_SHARE_DRAFT03: &[u8] = b"dap-03 aggregate share"; +const CTX_AGG_SHARE_DRAFT04: &[u8] = b"dap-04 aggregate share"; const CTX_ROLE_COLLECTOR: u8 = 0; const CTX_ROLE_CLIENT: u8 = 1; const CTX_ROLE_LEADER: u8 = 2; @@ -98,9 +98,9 @@ pub(crate) enum VdafAggregateShare { impl Encode for VdafAggregateShare { fn encode(&self, bytes: &mut Vec) { match self { - VdafAggregateShare::Field64(agg_share) => bytes.append(&mut agg_share.into()), - VdafAggregateShare::Field128(agg_share) => bytes.append(&mut agg_share.into()), - VdafAggregateShare::FieldPrio2(agg_share) => bytes.append(&mut agg_share.into()), + VdafAggregateShare::Field64(agg_share) => agg_share.encode(bytes), + VdafAggregateShare::Field128(agg_share) => agg_share.encode(bytes), + VdafAggregateShare::FieldPrio2(agg_share) => agg_share.encode(bytes), } } } @@ -168,18 +168,21 @@ impl VdafConfig { &self, hpke_config_list: &[HpkeConfig], time: Time, - task_id: &Id, + task_id: &TaskId, measurement: DapMeasurement, extensions: Vec, version: DapVersion, ) -> Result { - let 
(public_share, input_shares) = self.produce_input_shares(measurement)?; + let mut rng = thread_rng(); + let report_id = ReportId(rng.gen()); + let (public_share, input_shares) = self.produce_input_shares(measurement, &report_id.0)?; self.produce_report_with_extensions_for_shares( public_share, input_shares, hpke_config_list, time, task_id, + &report_id, extensions, version, ) @@ -193,17 +196,17 @@ impl VdafConfig { mut input_shares: Vec>, hpke_config_list: &[HpkeConfig], time: Time, - task_id: &Id, + task_id: &TaskId, + report_id: &ReportId, extensions: Vec, version: DapVersion, ) -> Result { - let mut rng = thread_rng(); let report_extensions = match version { DapVersion::Draft02 => extensions.clone(), _ => vec![], }; let metadata = ReportMetadata { - id: ReportId(rng.gen()), + id: report_id.clone(), time, extensions: report_extensions, }; @@ -226,7 +229,7 @@ impl VdafConfig { let input_share_text = match version { DapVersion::Draft02 => CTX_INPUT_SHARE_DRAFT02, - DapVersion::Draft03 => CTX_INPUT_SHARE_DRAFT03, + DapVersion::Draft04 => CTX_INPUT_SHARE_DRAFT04, _ => return Err(unimplemented_version()), }; let n: usize = input_share_text.len(); @@ -264,8 +267,8 @@ impl VdafConfig { } Ok(Report { - task_id: task_id.clone(), - metadata, + draft02_task_id: task_id.for_request_payload(&version), + report_metadata: metadata, public_share, encrypted_input_shares, }) @@ -275,13 +278,12 @@ impl VdafConfig { pub(crate) fn produce_input_shares( &self, measurement: DapMeasurement, + nonce: &[u8; 16], ) -> Result<(Vec, Vec>), DapError> { - let public_share = Vec::new(); - let input_shares = match self { - Self::Prio3(prio3_config) => prio3_shard(prio3_config, measurement)?, - Self::Prio2 { dimension } => prio2_shard(*dimension, measurement)?, - }; - Ok((public_share, input_shares)) + match self { + Self::Prio3(prio3_config) => Ok(prio3_shard(prio3_config, measurement, nonce)?), + Self::Prio2 { dimension } => Ok(prio2_shard(*dimension, measurement, nonce)?), + } } /// Generate a report for a measurement. This method is run by the Client. @@ -305,7 +307,7 @@ impl VdafConfig { &self, hpke_config_list: &[HpkeConfig], time: Time, - task_id: &Id, + task_id: &TaskId, measurement: DapMeasurement, version: DapVersion, ) -> Result { @@ -340,7 +342,7 @@ impl VdafConfig { &self, decrypter: &impl HpkeDecrypter<'_>, is_leader: bool, - task_id: &Id, + task_id: &TaskId, task_config: &DapTaskConfig, metadata: &ReportMetadata, public_share: &[u8], @@ -350,13 +352,9 @@ impl VdafConfig { return Err(DapError::Transition(TransitionFailure::TaskExpired)); } - if !public_share.is_empty() { - return Err(DapError::Transition(TransitionFailure::VdafPrepError)); - } - let input_share_text = match task_config.version { DapVersion::Draft02 => CTX_INPUT_SHARE_DRAFT02, - DapVersion::Draft03 => CTX_INPUT_SHARE_DRAFT03, + DapVersion::Draft04 => CTX_INPUT_SHARE_DRAFT04, _ => return Err(unimplemented_version()), }; let n: usize = input_share_text.len(); @@ -397,7 +395,8 @@ impl VdafConfig { prio3_config, verify_key, agg_id, - metadata.id.as_ref(), + &metadata.id.0, + public_share, &input_share.payload, )?) } @@ -406,7 +405,8 @@ impl VdafConfig { *dimension, verify_key, agg_id, - metadata.id.as_ref(), + &metadata.id.0, + public_share, &input_share.payload, )?) } @@ -433,33 +433,27 @@ impl VdafConfig { /// /// * `version` is the DapVersion to use. 
#[allow(clippy::too_many_arguments)] - pub(crate) async fn produce_agg_init_req( + pub(crate) async fn produce_agg_job_init_req( &self, decrypter: &impl HpkeDecrypter<'_>, - task_id: &Id, + task_id: &TaskId, task_config: &DapTaskConfig, - agg_job_id: &Id, + agg_job_id: &MetaAggregationJobId<'_>, part_batch_sel: &PartialBatchSelector, reports: Vec, metrics: &DaphneMetrics, - ) -> Result, DapAbort> { + ) -> Result, DapAbort> { let mut processed = HashSet::with_capacity(reports.len()); let mut states = Vec::with_capacity(reports.len()); let mut seq = Vec::with_capacity(reports.len()); for report in reports.into_iter() { - if processed.contains(&report.metadata.id) { + if processed.contains(&report.report_metadata.id) { return Err(DapError::fatal( "tried to process report sequence with non-unique report IDs", ) .into()); } - processed.insert(report.metadata.id.clone()); - - if &report.task_id != task_id || report.encrypted_input_shares.len() != 2 { - return Err( - DapError::fatal("tried to process report with incorrect task ID").into(), - ); - } + processed.insert(report.report_metadata.id.clone()); let (leader_share, helper_share) = { let mut it = report.encrypted_input_shares.into_iter(); @@ -472,7 +466,7 @@ impl VdafConfig { true, // is_leader task_id, task_config, - &report.metadata, + &report.report_metadata, &report.public_share, &leader_share, ) @@ -482,11 +476,11 @@ impl VdafConfig { states.push(( step, message, - report.metadata.time, - report.metadata.id.clone(), + report.report_metadata.time, + report.report_metadata.id.clone(), )); seq.push(ReportShare { - metadata: report.metadata, + report_metadata: report.report_metadata, public_share: report.public_share, encrypted_input_share: helper_share, }); @@ -508,9 +502,9 @@ impl VdafConfig { Ok(DapLeaderTransition::Continue( DapLeaderState { seq: states }, - AggregateInitializeReq { - task_id: task_id.clone(), - agg_job_id: agg_job_id.clone(), + AggregationJobInitReq { + draft02_task_id: task_id.for_request_payload(&task_config.version), + draft02_agg_job_id: agg_job_id.for_request_payload(), agg_param: Vec::default(), part_batch_sel: part_batch_sel.clone(), report_shares: seq, @@ -536,33 +530,34 @@ impl VdafConfig { /// /// * `task_id` indicates the DAP task for which the reports are being processed. /// - /// * `agg_init_req` is the request sent by the Leader. + /// * `agg_job_init_req` is the request sent by the Leader. /// /// * `version` is the DapVersion to use. 
- pub(crate) async fn handle_agg_init_req( + pub(crate) async fn handle_agg_job_init_req( &self, decrypter: &impl HpkeDecrypter<'_>, + task_id: &TaskId, task_config: &DapTaskConfig, - agg_init_req: &AggregateInitializeReq, + agg_job_init_req: &AggregationJobInitReq, metrics: &DaphneMetrics, - ) -> Result, DapAbort> { - let num_reports = agg_init_req.report_shares.len(); + ) -> Result, DapAbort> { + let num_reports = agg_job_init_req.report_shares.len(); let mut processed = HashSet::with_capacity(num_reports); let mut states = Vec::with_capacity(num_reports); let mut transitions = Vec::with_capacity(num_reports); - for report_share in agg_init_req.report_shares.iter() { - if processed.contains(&report_share.metadata.id) { + for report_share in agg_job_init_req.report_shares.iter() { + if processed.contains(&report_share.report_metadata.id) { return Err(DapAbort::UnrecognizedMessage); } - processed.insert(report_share.metadata.id.clone()); + processed.insert(report_share.report_metadata.id.clone()); let var = match self .consume_report_share( decrypter, false, // is_leader - &agg_init_req.task_id, + task_id, task_config, - &report_share.metadata, + &report_share.report_metadata, &report_share.public_share, &report_share.encrypted_input_share, ) @@ -575,8 +570,8 @@ impl VdafConfig { }; states.push(( step, - report_share.metadata.time, - report_share.metadata.id.clone(), + report_share.report_metadata.time, + report_share.report_metadata.id.clone(), )); TransitionVar::Continued(message_data) } @@ -593,17 +588,17 @@ impl VdafConfig { }; transitions.push(Transition { - report_id: report_share.metadata.id.clone(), + report_id: report_share.report_metadata.id.clone(), var, }); } Ok(DapHelperTransition::Continue( DapHelperState { - part_batch_sel: agg_init_req.part_batch_sel.clone(), + part_batch_sel: agg_job_init_req.part_batch_sel.clone(), seq: states, }, - AggregateResp { transitions }, + AggregationJobResp { transitions }, )) } @@ -618,25 +613,28 @@ impl VdafConfig { /// /// * `state` is the Leader's current state. /// - /// * `agg_resp` is the previous aggregate response sent by the Helper. - pub(crate) fn handle_agg_resp( + /// * `agg_job_resp` is the previous aggregate response sent by the Helper. + pub(crate) fn handle_agg_job_resp( &self, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, state: DapLeaderState, - agg_resp: AggregateResp, + agg_job_resp: AggregationJobResp, + version: DapVersion, metrics: &DaphneMetrics, - ) -> Result, DapAbort> { - if agg_resp.transitions.len() != state.seq.len() { + ) -> Result, DapAbort> { + if agg_job_resp.transitions.len() != state.seq.len() { return Err(DapAbort::UnrecognizedMessage); } let mut seq = Vec::with_capacity(state.seq.len()); let mut states = Vec::with_capacity(state.seq.len()); - for (helper, (leader_step, leader_message, leader_time, leader_report_id)) in - agg_resp.transitions.into_iter().zip(state.seq.into_iter()) + for (helper, (leader_step, leader_message, leader_time, leader_report_id)) in agg_job_resp + .transitions + .into_iter() + .zip(state.seq.into_iter()) { - // TODO spec: Consider removing the report ID from the AggregateResp. + // TODO spec: Consider removing the report ID from the AggregationJobResp. 
if helper.report_id != leader_report_id { return Err(DapAbort::UnrecognizedMessage); } @@ -711,9 +709,14 @@ impl VdafConfig { Ok(DapLeaderTransition::Uncommitted( DapLeaderUncommitted { seq: states }, - AggregateContinueReq { - task_id: task_id.clone(), - agg_job_id: agg_job_id.clone(), + AggregationJobContinueReq { + draft02_task_id: task_id.for_request_payload(&version), + draft02_agg_job_id: agg_job_id.for_request_payload(), + round: if version == DapVersion::Draft02 { + None + } else { + Some(1) + }, transitions: seq, }, )) @@ -729,18 +732,27 @@ impl VdafConfig { /// * `state` is the helper's current state. /// /// * `agg_cont_req` is the aggregate request sent by the Leader. - pub(crate) fn handle_agg_cont_req( + pub(crate) fn handle_agg_job_cont_req( &self, state: DapHelperState, - agg_cont_req: &AggregateContinueReq, + agg_cont_req: &AggregationJobContinueReq, metrics: &DaphneMetrics, - ) -> Result, DapAbort> { + ) -> Result, DapAbort> { + if let Some(round) = agg_cont_req.round { + if round == 0 { + return Err(DapAbort::UnrecognizedMessage); + } + // TODO(bhalleycf) For now, there is only ever one round, and we don't try to do + // aggregation-round-skew-recovery. + if round != 1 { + return Err(DapAbort::RoundMismatch); + } + } let mut processed = HashSet::with_capacity(state.seq.len()); let mut recognized = HashSet::with_capacity(state.seq.len()); for (_, _, report_id) in state.seq.iter() { recognized.insert(report_id.clone()); } - let num_reports = state.seq.len(); let mut transitions = Vec::with_capacity(num_reports); let mut out_shares = Vec::with_capacity(num_reports); @@ -818,7 +830,7 @@ impl VdafConfig { Ok(DapHelperTransition::Finish( out_shares, - AggregateResp { transitions }, + AggregationJobResp { transitions }, )) } @@ -834,24 +846,24 @@ impl VdafConfig { /// * `uncommited` is the Leader's current state, i.e., the set of output shares output from /// the previous round that have not yet been commmitted to. /// - /// * `agg_resp` is the previous aggregate response sent by the Helper. - pub(crate) fn handle_final_agg_resp( + /// * `agg_job_resp` is the previous aggregate response sent by the Helper. + pub(crate) fn handle_final_agg_job_resp( &self, uncommitted: DapLeaderUncommitted, - agg_resp: AggregateResp, + agg_job_resp: AggregationJobResp, metrics: &DaphneMetrics, ) -> Result, DapAbort> { - if agg_resp.transitions.len() != uncommitted.seq.len() { + if agg_job_resp.transitions.len() != uncommitted.seq.len() { return Err(DapAbort::UnrecognizedMessage); } let mut out_shares = Vec::with_capacity(uncommitted.seq.len()); - for (helper, (out_share, leader_report_id)) in agg_resp + for (helper, (out_share, leader_report_id)) in agg_job_resp .transitions .into_iter() .zip(uncommitted.seq.into_iter()) { - // TODO spec: Consider removing the report ID from the AggregateResp. + // TODO spec: Consider removing the report ID from the AggregationJobResp. 
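// Illustrative sketch, not part of the patch: the round handling added to
// handle_agg_job_cont_req above, pulled out as a free function. The integer width of
// `round` is an assumption here; draft02 simply omits the field, and with the current
// single-round preparation only round 1 is acceptable.
fn check_continue_round(round: Option<u16>) -> Result<(), DapAbort> {
    match round {
        None => Ok(()),                                // draft02: no round field on the wire
        Some(0) => Err(DapAbort::UnrecognizedMessage), // round numbering starts at 1
        Some(1) => Ok(()),                             // the only round currently run
        Some(_) => Err(DapAbort::RoundMismatch),       // no round-skew recovery yet
    }
}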
if helper.report_id != leader_report_id { return Err(DapAbort::UnrecognizedMessage); } @@ -893,7 +905,7 @@ impl VdafConfig { pub(crate) fn produce_leader_encrypted_agg_share( &self, hpke_config: &HpkeConfig, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, agg_share: &DapAggregateShare, version: DapVersion, @@ -908,7 +920,7 @@ impl VdafConfig { pub(crate) fn produce_helper_encrypted_agg_share( &self, hpke_config: &HpkeConfig, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, agg_share: &DapAggregateShare, version: DapVersion, @@ -936,7 +948,7 @@ impl VdafConfig { pub async fn consume_encrypted_agg_shares( &self, decrypter: &impl HpkeDecrypter<'_>, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, report_count: u64, encrypted_agg_shares: Vec, @@ -944,7 +956,7 @@ impl VdafConfig { ) -> Result { let agg_share_text = match version { DapVersion::Draft02 => CTX_AGG_SHARE_DRAFT02, - DapVersion::Draft03 => CTX_AGG_SHARE_DRAFT03, + DapVersion::Draft04 => CTX_AGG_SHARE_DRAFT04, _ => return Err(unimplemented_version()), }; let n: usize = agg_share_text.len(); @@ -993,7 +1005,7 @@ impl VdafConfig { fn produce_encrypted_agg_share( is_leader: bool, hpke_config: &HpkeConfig, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, agg_share: &DapAggregateShare, version: DapVersion, @@ -1006,7 +1018,7 @@ fn produce_encrypted_agg_share( let agg_share_text = match version { DapVersion::Draft02 => CTX_AGG_SHARE_DRAFT02, - DapVersion::Draft03 => CTX_AGG_SHARE_DRAFT03, + DapVersion::Draft04 => CTX_AGG_SHARE_DRAFT04, _ => return Err(unimplemented_version_abort()), }; let n: usize = agg_share_text.len(); diff --git a/daphne/src/vdaf/mod_test.rs b/daphne/src/vdaf/mod_test.rs index dc648008f..e52247434 100644 --- a/daphne/src/vdaf/mod_test.rs +++ b/daphne/src/vdaf/mod_test.rs @@ -6,22 +6,26 @@ use crate::{ async_test_versions, hpke::HpkeReceiverConfig, messages::{ - AggregateContinueReq, AggregateInitializeReq, AggregateResp, BatchSelector, HpkeAeadId, - HpkeCiphertext, HpkeConfig, HpkeKdfId, HpkeKemId, Id, Interval, PartialBatchSelector, - Report, ReportId, ReportShare, Time, Transition, TransitionFailure, TransitionVar, + AggregationJobContinueReq, AggregationJobInitReq, AggregationJobResp, BatchSelector, + HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeKdfId, HpkeKemId, Interval, + PartialBatchSelector, Report, ReportId, ReportShare, TaskId, Time, Transition, + TransitionFailure, TransitionVar, }, metrics::DaphneMetrics, test_version, test_versions, DapAbort, DapAggregateResult, DapAggregateShare, DapError, DapHelperState, DapHelperTransition, DapLeaderState, DapLeaderTransition, DapLeaderUncommitted, - DapMeasurement, DapOutputShare, DapQueryConfig, DapTaskConfig, DapVersion, Prio3Config, - VdafAggregateShare, VdafConfig, VdafMessage, VdafState, + DapMeasurement, DapOutputShare, DapQueryConfig, DapTaskConfig, DapVersion, + MetaAggregationJobId, Prio3Config, VdafAggregateShare, VdafConfig, VdafMessage, VdafState, }; use assert_matches::assert_matches; use hpke_rs::HpkePublicKey; use paste::paste; -use prio::vdaf::{ - prio3::Prio3, Aggregatable, Aggregator as VdafAggregator, Collector as VdafCollector, - PrepareTransition, +use prio::{ + field::Field64, + vdaf::{ + prio3::Prio3, Aggregatable, AggregateShare, Aggregator as VdafAggregator, + Collector as VdafCollector, OutputShare, PrepareTransition, + }, }; use rand::prelude::*; use std::{fmt::Debug, time::SystemTime}; @@ -90,7 +94,7 @@ async fn roundtrip_report(version: DapVersion) { true, // is_leader &t.task_id, 
&t.task_config, - &report.metadata, + &report.report_metadata, &report.public_share, &report.encrypted_input_shares[0], ) @@ -103,7 +107,7 @@ async fn roundtrip_report(version: DapVersion) { false, // is_leader &t.task_id, &t.task_config, - &report.metadata, + &report.report_metadata, &report.public_share, &report.encrypted_input_shares[1], ) @@ -117,7 +121,7 @@ async fn roundtrip_report(version: DapVersion) { VdafMessage::Prio3ShareField64(leader_share), VdafMessage::Prio3ShareField64(helper_share), ) => { - let vdaf = Prio3::new_aes128_count(2).unwrap(); + let vdaf = Prio3::new_count(2).unwrap(); let message = vdaf .prepare_preprocess([leader_share, helper_share]) .unwrap(); @@ -175,7 +179,7 @@ fn roundtrip_report_unsupported_hpke_suite(version: DapVersion) { test_versions! { roundtrip_report_unsupported_hpke_suite } -async fn produce_agg_init_req(version: DapVersion) { +async fn produce_agg_job_init_req(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![ DapMeasurement::U64(1), @@ -183,29 +187,35 @@ async fn produce_agg_init_req(version: DapVersion) { DapMeasurement::U64(0), ]); - let (leader_state, agg_init_req) = t - .produce_agg_init_req(reports.clone()) + let (leader_state, agg_job_init_req) = t + .produce_agg_job_init_req(reports.clone()) .await .unwrap_continue(); assert_eq!(leader_state.seq.len(), 3); - assert_eq!(agg_init_req.task_id, t.task_id); - assert_eq!(agg_init_req.agg_param.len(), 0); - assert_eq!(agg_init_req.report_shares.len(), 3); - for (report_shares, report) in agg_init_req.report_shares.iter().zip(reports.iter()) { - assert_eq!(report_shares.metadata.id, report.metadata.id); + assert_eq!( + agg_job_init_req.draft02_task_id, + t.task_id.for_request_payload(&version) + ); + assert_eq!(agg_job_init_req.agg_param.len(), 0); + assert_eq!(agg_job_init_req.report_shares.len(), 3); + for (report_shares, report) in agg_job_init_req.report_shares.iter().zip(reports.iter()) { + assert_eq!(report_shares.report_metadata.id, report.report_metadata.id); } - let (helper_state, agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (helper_state, agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); assert_eq!(helper_state.seq.len(), 3); - assert_eq!(agg_resp.transitions.len(), 3); - for (sub, report) in agg_resp.transitions.iter().zip(reports.iter()) { - assert_eq!(sub.report_id, report.metadata.id); + assert_eq!(agg_job_resp.transitions.len(), 3); + for (sub, report) in agg_job_resp.transitions.iter().zip(reports.iter()) { + assert_eq!(sub.report_id, report.report_metadata.id); } } -async_test_versions! { produce_agg_init_req } +async_test_versions! { produce_agg_job_init_req } -async fn produce_agg_init_req_skip_hpke_decrypt_err(version: DapVersion) { +async fn produce_agg_job_init_req_skip_hpke_decrypt_err(version: DapVersion) { let t = Test::new(TEST_VDAF, version); let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); @@ -213,7 +223,7 @@ async fn produce_agg_init_req_skip_hpke_decrypt_err(version: DapVersion) { reports[0].encrypted_input_shares[0].payload[0] ^= 1; assert_matches!( - t.produce_agg_init_req(reports).await, + t.produce_agg_job_init_req(reports).await, DapLeaderTransition::Skip ); @@ -222,9 +232,9 @@ async fn produce_agg_init_req_skip_hpke_decrypt_err(version: DapVersion) { }); } -async_test_versions! { produce_agg_init_req_skip_hpke_decrypt_err } +async_test_versions! 
{ produce_agg_job_init_req_skip_hpke_decrypt_err } -async fn produce_agg_init_req_skip_hpke_unknown_config_id(version: DapVersion) { +async fn produce_agg_job_init_req_skip_hpke_unknown_config_id(version: DapVersion) { let t = Test::new(TEST_VDAF, version); let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); @@ -232,7 +242,7 @@ async fn produce_agg_init_req_skip_hpke_unknown_config_id(version: DapVersion) { reports[0].encrypted_input_shares[0].config_id ^= 1; assert_matches!( - t.produce_agg_init_req(reports).await, + t.produce_agg_job_init_req(reports).await, DapLeaderTransition::Skip ); @@ -241,9 +251,9 @@ async fn produce_agg_init_req_skip_hpke_unknown_config_id(version: DapVersion) { }); } -async_test_versions! { produce_agg_init_req_skip_hpke_unknown_config_id } +async_test_versions! { produce_agg_job_init_req_skip_hpke_unknown_config_id } -async fn produce_agg_init_req_skip_vdaf_prep_error(version: DapVersion) { +async fn produce_agg_job_init_req_skip_vdaf_prep_error(version: DapVersion) { let t = Test::new(TEST_VDAF, version); let reports = vec![ t.produce_invalid_report_public_share_decode_failure(DapMeasurement::U64(1), version), @@ -251,7 +261,7 @@ async fn produce_agg_init_req_skip_vdaf_prep_error(version: DapVersion) { ]; assert_matches!( - t.produce_agg_init_req(reports).await, + t.produce_agg_job_init_req(reports).await, DapLeaderTransition::Skip ); @@ -260,9 +270,9 @@ async fn produce_agg_init_req_skip_vdaf_prep_error(version: DapVersion) { }); } -async_test_versions! { produce_agg_init_req_skip_vdaf_prep_error } +async_test_versions! { produce_agg_job_init_req_skip_vdaf_prep_error } -async fn handle_agg_init_req_hpke_decrypt_err(version: DapVersion) { +async fn handle_agg_job_init_req_hpke_decrypt_err(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); @@ -270,14 +280,14 @@ async fn handle_agg_init_req_hpke_decrypt_err(version: DapVersion) { reports[0].encrypted_input_shares[1].payload[0] ^= 1; let (_, agg_req) = t - .produce_agg_init_req(reports.clone()) + .produce_agg_job_init_req(reports.clone()) .await .unwrap_continue(); - let (_, agg_resp) = t.handle_agg_init_req(agg_req).await.unwrap_continue(); + let (_, agg_job_resp) = t.handle_agg_job_init_req(agg_req).await.unwrap_continue(); - assert_eq!(agg_resp.transitions.len(), 1); + assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( - agg_resp.transitions[0].var, + agg_job_resp.transitions[0].var, TransitionVar::Failed(TransitionFailure::HpkeDecryptError) ); @@ -286,9 +296,9 @@ async fn handle_agg_init_req_hpke_decrypt_err(version: DapVersion) { }); } -async_test_versions! { handle_agg_init_req_hpke_decrypt_err } +async_test_versions! 
{ handle_agg_job_init_req_hpke_decrypt_err } -async fn handle_agg_init_req_hpke_unknown_config_id(version: DapVersion) { +async fn handle_agg_job_init_req_hpke_unknown_config_id(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let mut reports = t.produce_reports(vec![DapMeasurement::U64(1)]); @@ -296,14 +306,14 @@ async fn handle_agg_init_req_hpke_unknown_config_id(version: DapVersion) { reports[0].encrypted_input_shares[1].config_id ^= 1; let (_, agg_req) = t - .produce_agg_init_req(reports.clone()) + .produce_agg_job_init_req(reports.clone()) .await .unwrap_continue(); - let (_, agg_resp) = t.handle_agg_init_req(agg_req).await.unwrap_continue(); + let (_, agg_job_resp) = t.handle_agg_job_init_req(agg_req).await.unwrap_continue(); - assert_eq!(agg_resp.transitions.len(), 1); + assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( - agg_resp.transitions[0].var, + agg_job_resp.transitions[0].var, TransitionVar::Failed(TransitionFailure::HpkeUnknownConfigId) ); @@ -312,43 +322,43 @@ async fn handle_agg_init_req_hpke_unknown_config_id(version: DapVersion) { }); } -async_test_versions! { handle_agg_init_req_hpke_unknown_config_id } +async_test_versions! { handle_agg_job_init_req_hpke_unknown_config_id } -async fn handle_agg_init_req_vdaf_prep_error(version: DapVersion) { +async fn handle_agg_job_init_req_vdaf_prep_error(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let report0 = t.produce_invalid_report_public_share_decode_failure(DapMeasurement::U64(1), version); let report1 = t.produce_invalid_report_input_share_decode_failure(DapMeasurement::U64(1), version); - let agg_req = AggregateInitializeReq { - task_id: t.task_id.clone(), - agg_job_id: t.agg_job_id.clone(), + let agg_req = AggregationJobInitReq { + draft02_task_id: t.task_id.for_request_payload(&version), + draft02_agg_job_id: t.agg_job_id.for_request_payload(), agg_param: Vec::new(), part_batch_sel: PartialBatchSelector::TimeInterval, report_shares: vec![ ReportShare { - metadata: report0.metadata, + report_metadata: report0.report_metadata, public_share: report0.public_share, encrypted_input_share: report0.encrypted_input_shares[1].clone(), }, ReportShare { - metadata: report1.metadata, + report_metadata: report1.report_metadata, public_share: report1.public_share, encrypted_input_share: report1.encrypted_input_shares[1].clone(), }, ], }; - let (_, agg_resp) = t.handle_agg_init_req(agg_req).await.unwrap_continue(); + let (_, agg_job_resp) = t.handle_agg_job_init_req(agg_req).await.unwrap_continue(); - assert_eq!(agg_resp.transitions.len(), 2); + assert_eq!(agg_job_resp.transitions.len(), 2); assert_matches!( - agg_resp.transitions[0].var, + agg_job_resp.transitions[0].var, TransitionVar::Failed(TransitionFailure::VdafPrepError) ); assert_matches!( - agg_resp.transitions[1].var, + agg_job_resp.transitions[1].var, TransitionVar::Failed(TransitionFailure::VdafPrepError) ); @@ -357,84 +367,100 @@ async fn handle_agg_init_req_vdaf_prep_error(version: DapVersion) { }); } -async_test_versions! { handle_agg_init_req_vdaf_prep_error } +async_test_versions! 
{ handle_agg_job_init_req_vdaf_prep_error } -async fn agg_resp_abort_transition_out_of_order(version: DapVersion) { +async fn agg_job_resp_abort_transition_out_of_order(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (_, mut agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (_, mut agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); // Helper sends transitions out of order. - let tmp = agg_resp.transitions[0].clone(); - agg_resp.transitions[0] = agg_resp.transitions[1].clone(); - agg_resp.transitions[1] = tmp; + let tmp = agg_job_resp.transitions[0].clone(); + agg_job_resp.transitions[0] = agg_job_resp.transitions[1].clone(); + agg_job_resp.transitions[1] = tmp; assert_matches!( - t.handle_agg_resp_expect_err(leader_state, agg_resp), + t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp), DapAbort::UnrecognizedMessage ); } -async_test_versions! { agg_resp_abort_transition_out_of_order } +async_test_versions! { agg_job_resp_abort_transition_out_of_order } -async fn agg_resp_abort_report_id_repeated(version: DapVersion) { +async fn agg_job_resp_abort_report_id_repeated(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (_, mut agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (_, mut agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); // Helper sends a transition twice. - let repeated_transition = agg_resp.transitions[0].clone(); - agg_resp.transitions.push(repeated_transition); + let repeated_transition = agg_job_resp.transitions[0].clone(); + agg_job_resp.transitions.push(repeated_transition); assert_matches!( - t.handle_agg_resp_expect_err(leader_state, agg_resp), + t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp), DapAbort::UnrecognizedMessage ); } -async_test_versions! { agg_resp_abort_report_id_repeated } +async_test_versions! { agg_job_resp_abort_report_id_repeated } -async fn agg_resp_abort_unrecognized_report_id(version: DapVersion) { +async fn agg_job_resp_abort_unrecognized_report_id(version: DapVersion) { let mut rng = thread_rng(); let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (_, mut agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (_, mut agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); // Helper sent a transition with an unrecognized report ID. 
- agg_resp.transitions.push(Transition { + agg_job_resp.transitions.push(Transition { report_id: ReportId(rng.gen()), var: TransitionVar::Continued(b"whatever".to_vec()), }); assert_matches!( - t.handle_agg_resp_expect_err(leader_state, agg_resp), + t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp), DapAbort::UnrecognizedMessage ); } -async_test_versions! { agg_resp_abort_unrecognized_report_id } +async_test_versions! { agg_job_resp_abort_unrecognized_report_id } -async fn agg_resp_abort_invalid_transition(version: DapVersion) { +async fn agg_job_resp_abort_invalid_transition(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1)]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (_, mut agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (_, mut agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); // Helper sent a transition with an unrecognized report ID. - agg_resp.transitions[0].var = TransitionVar::Finished; + agg_job_resp.transitions[0].var = TransitionVar::Finished; assert_matches!( - t.handle_agg_resp_expect_err(leader_state, agg_resp), + t.handle_agg_job_resp_expect_err(leader_state, agg_job_resp), DapAbort::UnrecognizedMessage ); } -async_test_versions! { agg_resp_abort_invalid_transition } +async_test_versions! { agg_job_resp_abort_invalid_transition } -async fn agg_cont_req(version: DapVersion) { +async fn agg_job_cont_req(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![ DapMeasurement::U64(1), @@ -443,20 +469,24 @@ async fn agg_cont_req(version: DapVersion) { DapMeasurement::U64(0), DapMeasurement::U64(1), ]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (helper_state, agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (helper_state, agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); - let (leader_uncommitted, agg_cont_req) = t - .handle_agg_resp(leader_state, agg_resp) + let (leader_uncommitted, agg_job_cont_req) = t + .handle_agg_job_resp(leader_state, agg_job_resp) .unwrap_uncommitted(); - let (helper_out_shares, agg_resp) = t - .handle_agg_cont_req(helper_state, &agg_cont_req) + let (helper_out_shares, agg_job_resp) = t + .handle_agg_job_cont_req(helper_state, &agg_job_cont_req) .unwrap_finish(); assert_eq!(helper_out_shares.len(), 5); - assert_eq!(agg_resp.transitions.len(), 5); + assert_eq!(agg_job_resp.transitions.len(), 5); - let leader_out_shares = t.handle_final_agg_resp(leader_uncommitted, agg_resp); + let leader_out_shares = t.handle_final_agg_job_resp(leader_uncommitted, agg_job_resp); assert_eq!(leader_out_shares.len(), 5); let num_measurements = leader_out_shares.len(); @@ -484,7 +514,7 @@ async fn agg_cont_req(version: DapVersion) { }) .unwrap(); - let vdaf = Prio3::new_aes128_count(2).unwrap(); + let vdaf = Prio3::new_count(2).unwrap(); assert_eq!( vdaf.unshard(&(), [leader_agg_share, helper_agg_share], num_measurements,) .unwrap(), @@ -492,9 +522,9 @@ async fn agg_cont_req(version: DapVersion) { ); } -async_test_versions! { agg_cont_req } +async_test_versions! 
{ agg_job_cont_req } -async fn agg_cont_req_skip_vdaf_prep_error(version: DapVersion) { +async fn agg_job_cont_req_skip_vdaf_prep_error(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let mut reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); reports.insert( @@ -502,29 +532,30 @@ async fn agg_cont_req_skip_vdaf_prep_error(version: DapVersion) { t.produce_invalid_report_vdaf_prep_failure(DapMeasurement::U64(1), version), ); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (helper_state, agg_resp) = t - .handle_agg_init_req(agg_init_req.clone()) + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (helper_state, agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req.clone()) .await .unwrap_continue(); - let (_, agg_cont_req) = t - .handle_agg_resp(leader_state, agg_resp) + let (_, agg_job_cont_req) = t + .handle_agg_job_resp(leader_state, agg_job_resp) .unwrap_uncommitted(); - let (helper_output_shares, agg_resp) = t - .handle_agg_cont_req(helper_state, &agg_cont_req) + let (helper_output_shares, agg_job_resp) = t + .handle_agg_job_cont_req(helper_state, &agg_job_cont_req) .unwrap_finish(); assert_eq!(2, helper_output_shares.len()); - assert_eq!(2, agg_resp.transitions.len()); + assert_eq!(2, agg_job_resp.transitions.len()); assert_eq!( - agg_resp.transitions[0].report_id, - agg_init_req.report_shares[0].metadata.id + agg_job_resp.transitions[0].report_id, + agg_job_init_req.report_shares[0].report_metadata.id ); assert_eq!( - agg_resp.transitions[1].report_id, - agg_init_req.report_shares[2].metadata.id + agg_job_resp.transitions[1].report_id, + agg_job_init_req.report_shares[2].report_metadata.id ); assert_metrics_include!(t.prometheus_registry, { @@ -532,20 +563,24 @@ async fn agg_cont_req_skip_vdaf_prep_error(version: DapVersion) { }); } -async_test_versions! { agg_cont_req_skip_vdaf_prep_error } +async_test_versions! { agg_job_cont_req_skip_vdaf_prep_error } async fn agg_cont_abort_unrecognized_report_id(version: DapVersion) { let mut rng = thread_rng(); let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (helper_state, agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (helper_state, agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); - let (_, mut agg_cont_req) = t - .handle_agg_resp(leader_state, agg_resp) + let (_, mut agg_job_cont_req) = t + .handle_agg_job_resp(leader_state, agg_job_resp) .unwrap_uncommitted(); // Leader sends a Transition with an unrecognized report_id. - agg_cont_req.transitions.insert( + agg_job_cont_req.transitions.insert( 1, Transition { report_id: ReportId(rng.gen()), @@ -554,67 +589,83 @@ async fn agg_cont_abort_unrecognized_report_id(version: DapVersion) { ); assert_matches!( - t.handle_agg_cont_req_expect_err(helper_state, &agg_cont_req), + t.handle_agg_job_cont_req_expect_err(helper_state, &agg_job_cont_req), DapAbort::UnrecognizedMessage ); } async_test_versions! 
{ agg_cont_abort_unrecognized_report_id } -async fn agg_cont_req_abort_transition_out_of_order(version: DapVersion) { +async fn agg_job_cont_req_abort_transition_out_of_order(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (helper_state, agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (helper_state, agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); - let (_, mut agg_cont_req) = t - .handle_agg_resp(leader_state, agg_resp) + let (_, mut agg_job_cont_req) = t + .handle_agg_job_resp(leader_state, agg_job_resp) .unwrap_uncommitted(); // Leader sends transitions out of order. - let tmp = agg_cont_req.transitions[0].clone(); - agg_cont_req.transitions[0] = agg_cont_req.transitions[1].clone(); - agg_cont_req.transitions[1] = tmp; + let tmp = agg_job_cont_req.transitions[0].clone(); + agg_job_cont_req.transitions[0] = agg_job_cont_req.transitions[1].clone(); + agg_job_cont_req.transitions[1] = tmp; assert_matches!( - t.handle_agg_cont_req_expect_err(helper_state, &agg_cont_req), + t.handle_agg_job_cont_req_expect_err(helper_state, &agg_job_cont_req), DapAbort::UnrecognizedMessage ); } -async_test_versions! { agg_cont_req_abort_transition_out_of_order } +async_test_versions! { agg_job_cont_req_abort_transition_out_of_order } -async fn agg_cont_req_abort_report_id_repeated(version: DapVersion) { +async fn agg_job_cont_req_abort_report_id_repeated(version: DapVersion) { let mut t = Test::new(TEST_VDAF, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); - let (leader_state, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (helper_state, agg_resp) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (helper_state, agg_job_resp) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); - let (_, mut agg_cont_req) = t - .handle_agg_resp(leader_state, agg_resp) + let (_, mut agg_job_cont_req) = t + .handle_agg_job_resp(leader_state, agg_job_resp) .unwrap_uncommitted(); // Leader sends a transition twice. - let repeated_transition = agg_cont_req.transitions[0].clone(); - agg_cont_req.transitions.push(repeated_transition); + let repeated_transition = agg_job_cont_req.transitions[0].clone(); + agg_job_cont_req.transitions.push(repeated_transition); assert_matches!( - t.handle_agg_cont_req_expect_err(helper_state, &agg_cont_req), + t.handle_agg_job_cont_req_expect_err(helper_state, &agg_job_cont_req), DapAbort::UnrecognizedMessage ); } -async_test_versions! { agg_cont_req_abort_report_id_repeated } +async_test_versions! 
{ agg_job_cont_req_abort_report_id_repeated } async fn encrypted_agg_share(version: DapVersion) { let t = Test::new(TEST_VDAF, version); let leader_agg_share = DapAggregateShare { report_count: 50, + min_time: 1637359200, + max_time: 1637359200, checksum: [0; 32], - data: Some(VdafAggregateShare::Field64(vec![23.into()].into())), + data: Some(VdafAggregateShare::Field64(AggregateShare::from( + OutputShare::from(vec![Field64::from(23)]), + ))), }; let helper_agg_share = DapAggregateShare { report_count: 50, + min_time: 1637359200, + max_time: 1637359200, checksum: [0; 32], - data: Some(VdafAggregateShare::Field64(vec![9.into()].into())), + data: Some(VdafAggregateShare::Field64(AggregateShare::from( + OutputShare::from(vec![Field64::from(9)]), + ))), }; let batch_selector = BatchSelector::TimeInterval { @@ -649,8 +700,11 @@ async fn helper_state_serialization(version: DapVersion) { DapMeasurement::U64(0), DapMeasurement::U64(1), ]); - let (_, agg_init_req) = t.produce_agg_init_req(reports).await.unwrap_continue(); - let (want, _) = t.handle_agg_init_req(agg_init_req).await.unwrap_continue(); + let (_, agg_job_init_req) = t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (want, _) = t + .handle_agg_job_init_req(agg_job_init_req) + .await + .unwrap_continue(); let got = DapHelperState::get_decoded(TEST_VDAF, &want.get_encoded(TEST_VDAF).unwrap()).unwrap(); @@ -663,8 +717,8 @@ async_test_versions! { helper_state_serialization } pub(crate) struct Test { now: Time, - task_id: Id, - agg_job_id: Id, + task_id: TaskId, + agg_job_id: MetaAggregationJobId<'static>, task_config: DapTaskConfig, leader_hpke_receiver_config: HpkeReceiverConfig, helper_hpke_receiver_config: HpkeReceiverConfig, @@ -682,8 +736,8 @@ impl Test { .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs(); - let task_id = Id(rng.gen()); - let agg_job_id = Id(rng.gen()); + let task_id = TaskId(rng.gen()); + let agg_job_id = MetaAggregationJobId::gen_for_version(&version); let vdaf_verify_key = vdaf.gen_verify_key(); let leader_hpke_receiver_config = HpkeReceiverConfig::gen(rng.gen(), HpkeKemId::X25519HkdfSha256).unwrap(); @@ -752,10 +806,11 @@ impl Test { measurement: DapMeasurement, version: DapVersion, ) -> Report { + let report_id = ReportId(thread_rng().gen()); let (invalid_public_share, mut invalid_input_shares) = self .task_config .vdaf - .produce_input_shares(measurement) + .produce_input_shares(measurement, &report_id.0) .unwrap(); invalid_input_shares[1][0] ^= 1; // The first bit is incorrect! 
self.task_config @@ -766,6 +821,7 @@ impl Test { &self.client_hpke_config_list, self.now, &self.task_id, + &report_id, Vec::new(), // extensions version, ) @@ -778,10 +834,11 @@ impl Test { measurement: DapMeasurement, version: DapVersion, ) -> Report { + let report_id = ReportId(thread_rng().gen()); let (mut invalid_public_share, invalid_input_shares) = self .task_config .vdaf - .produce_input_shares(measurement) + .produce_input_shares(measurement, &report_id.0) .unwrap(); invalid_public_share.push(1); // Add spurious byte at the end self.task_config @@ -792,6 +849,7 @@ impl Test { &self.client_hpke_config_list, self.now, &self.task_id, + &report_id, Vec::new(), // extensions version, ) @@ -804,10 +862,11 @@ impl Test { measurement: DapMeasurement, version: DapVersion, ) -> Report { + let report_id = ReportId(thread_rng().gen()); let (invalid_public_share, mut invalid_input_shares) = self .task_config .vdaf - .produce_input_shares(measurement) + .produce_input_shares(measurement, &report_id.0) .unwrap(); invalid_input_shares[0].push(1); // Add a spurious byte to the Leader's share invalid_input_shares[1].push(1); // Add a spurious byte to the Helper's share @@ -819,19 +878,20 @@ impl Test { &self.client_hpke_config_list, self.now, &self.task_id, + &report_id, Vec::new(), // extensions version, ) .unwrap() } - async fn produce_agg_init_req( + async fn produce_agg_job_init_req( &self, reports: Vec, - ) -> DapLeaderTransition { + ) -> DapLeaderTransition { self.task_config .vdaf - .produce_agg_init_req( + .produce_agg_job_init_req( &self.leader_hpke_receiver_config, &self.task_id, &self.task_config, @@ -844,86 +904,89 @@ impl Test { .unwrap() } - async fn handle_agg_init_req( + async fn handle_agg_job_init_req( &mut self, - agg_init_req: AggregateInitializeReq, - ) -> DapHelperTransition { + agg_job_init_req: AggregationJobInitReq, + ) -> DapHelperTransition { self.task_config .vdaf - .handle_agg_init_req( + .handle_agg_job_init_req( &self.helper_hpke_receiver_config, + &self.task_id, &self.task_config, - &agg_init_req, + &agg_job_init_req, &self.helper_metrics, ) .await .unwrap() } - fn handle_agg_resp( + fn handle_agg_job_resp( &self, leader_state: DapLeaderState, - agg_resp: AggregateResp, - ) -> DapLeaderTransition { + agg_job_resp: AggregationJobResp, + ) -> DapLeaderTransition { self.task_config .vdaf - .handle_agg_resp( + .handle_agg_job_resp( &self.task_id, &self.agg_job_id, leader_state, - agg_resp, + agg_job_resp, + self.task_config.version, &self.leader_metrics, ) .unwrap() } - fn handle_agg_resp_expect_err( + fn handle_agg_job_resp_expect_err( &self, leader_state: DapLeaderState, - agg_resp: AggregateResp, + agg_job_resp: AggregationJobResp, ) -> DapAbort { self.task_config .vdaf - .handle_agg_resp( + .handle_agg_job_resp( &self.task_id, &self.agg_job_id, leader_state, - agg_resp, + agg_job_resp, + self.task_config.version, &self.leader_metrics, ) - .expect_err("handle_agg_resp() succeeded; expected failure") + .expect_err("handle_agg_job_resp() succeeded; expected failure") } - fn handle_agg_cont_req( + fn handle_agg_job_cont_req( &self, helper_state: DapHelperState, - agg_cont_req: &AggregateContinueReq, - ) -> DapHelperTransition { + agg_job_cont_req: &AggregationJobContinueReq, + ) -> DapHelperTransition { self.task_config .vdaf - .handle_agg_cont_req(helper_state, agg_cont_req, &self.helper_metrics) + .handle_agg_job_cont_req(helper_state, agg_job_cont_req, &self.helper_metrics) .unwrap() } - fn handle_agg_cont_req_expect_err( + fn handle_agg_job_cont_req_expect_err( &self, 
helper_state: DapHelperState, - agg_cont_req: &AggregateContinueReq, + agg_job_cont_req: &AggregationJobContinueReq, ) -> DapAbort { self.task_config .vdaf - .handle_agg_cont_req(helper_state, agg_cont_req, &self.helper_metrics) - .expect_err("handle_agg_cont_req() succeeded; expected failure") + .handle_agg_job_cont_req(helper_state, agg_job_cont_req, &self.helper_metrics) + .expect_err("handle_agg_job_cont_req() succeeded; expected failure") } - fn handle_final_agg_resp( + fn handle_final_agg_job_resp( &self, leader_uncommitted: DapLeaderUncommitted, - agg_resp: AggregateResp, + agg_job_resp: AggregationJobResp, ) -> Vec { self.task_config .vdaf - .handle_final_agg_resp(leader_uncommitted, agg_resp, &self.leader_metrics) + .handle_final_agg_job_resp(leader_uncommitted, agg_job_resp, &self.leader_metrics) .unwrap() } @@ -996,8 +1059,14 @@ impl Test { let reports = self.produce_reports(measurements); // Aggregators: Preparation - let (leader_state, agg_init) = self.produce_agg_init_req(reports).await.unwrap_continue(); - let (helper_state, agg_resp) = self.handle_agg_init_req(agg_init).await.unwrap_continue(); + let (leader_state, agg_init) = self + .produce_agg_job_init_req(reports) + .await + .unwrap_continue(); + let (helper_state, agg_job_resp) = self + .handle_agg_job_init_req(agg_init) + .await + .unwrap_continue(); let got = DapHelperState::get_decoded( &self.task_config.vdaf, &helper_state @@ -1008,12 +1077,12 @@ impl Test { assert_eq!(got, helper_state); let (uncommitted, agg_cont) = self - .handle_agg_resp(leader_state, agg_resp) + .handle_agg_job_resp(leader_state, agg_job_resp) .unwrap_uncommitted(); - let (helper_out_shares, agg_resp) = self - .handle_agg_cont_req(helper_state, &agg_cont) + let (helper_out_shares, agg_job_resp) = self + .handle_agg_job_cont_req(helper_state, &agg_cont) .unwrap_finish(); - let leader_out_shares = self.handle_final_agg_resp(uncommitted, agg_resp); + let leader_out_shares = self.handle_final_agg_job_resp(uncommitted, agg_job_resp); let report_count = u64::try_from(leader_out_shares.len()).unwrap(); // Leader: Aggregation diff --git a/daphne/src/vdaf/prio2.rs b/daphne/src/vdaf/prio2.rs index b716dd825..cc83e37dd 100644 --- a/daphne/src/vdaf/prio2.rs +++ b/daphne/src/vdaf/prio2.rs @@ -8,56 +8,61 @@ use crate::{ vdaf::VdafError, DapAggregateResult, DapMeasurement, VdafAggregateShare, VdafMessage, VdafState, }; use prio::{ - codec::{CodecError, Decode, Encode, ParameterizedDecode}, + codec::{Decode, Encode, ParameterizedDecode}, field::FieldPrio2, vdaf::{ prio2::{Prio2, Prio2PrepareShare, Prio2PrepareState}, AggregateShare, Aggregator, Client, Collector, PrepareTransition, Share, Vdaf, }, }; -use std::{convert::TryFrom, io::Cursor}; +use std::io::Cursor; /// Split the given measurement into a sequence of encoded input shares. 
pub(crate) fn prio2_shard( - dimension: u32, + dimension: usize, measurement: DapMeasurement, -) -> Result>, VdafError> { - let vdaf = Prio2::new(dimension as usize)?; - let (_public_share, input_shares) = match measurement { - DapMeasurement::U32Vec(ref data) => vdaf.shard(data)?, + nonce: &[u8; 16], +) -> Result<(Vec, Vec>), VdafError> { + let vdaf = Prio2::new(dimension)?; + let (public_share, input_shares) = match measurement { + DapMeasurement::U32Vec(ref data) => vdaf.shard(data, nonce)?, _ => panic!("prio2_shard: unexpected measurement type"), }; - Ok(input_shares - .iter() - .map(|input_share| input_share.get_encoded()) - .collect()) + Ok(( + public_share.get_encoded(), + input_shares + .iter() + .map(|input_share| input_share.get_encoded()) + .collect(), + )) } /// Consume an input share and return the corresponding VDAF step and message. pub(crate) fn prio2_prepare_init( - dimension: u32, + dimension: usize, verify_key: &[u8; 32], agg_id: usize, - nonce_data: &[u8], + nonce: &[u8; 16], + public_share_data: &[u8], input_share_data: &[u8], ) -> Result<(VdafState, VdafMessage), VdafError> { - let vdaf = Prio2::new(dimension as usize)?; + let vdaf = Prio2::new(dimension)?; + <()>::get_decoded_with_param(&vdaf, public_share_data)?; let input_share: Share = Share::get_decoded_with_param(&(&vdaf, agg_id), input_share_data)?; - let (state, share) = - vdaf.prepare_init(verify_key, agg_id, &(), nonce_data, &(), &input_share)?; + let (state, share) = vdaf.prepare_init(verify_key, agg_id, &(), nonce, &(), &input_share)?; Ok((VdafState::Prio2(state), VdafMessage::Prio2Share(share))) } /// Consume the verifier shares and return the output share and serialized outbound message. pub(crate) fn prio2_leader_prepare_finish( - dimension: u32, + dimension: usize, leader_state: VdafState, leader_share: VdafMessage, helper_share_data: &[u8], ) -> Result<(VdafAggregateShare, Vec), VdafError> { - let vdaf = Prio2::new(dimension as usize)?; + let vdaf = Prio2::new(dimension)?; let (out_share, outbound) = match (leader_state, leader_share) { (VdafState::Prio2(state), VdafMessage::Prio2Share(share)) => { let helper_share = @@ -78,11 +83,11 @@ pub(crate) fn prio2_leader_prepare_finish( /// Consume the peer's prepare message and return an output share. pub(crate) fn prio2_helper_prepare_finish( - dimension: u32, + dimension: usize, helper_state: VdafState, leader_message_data: &[u8], ) -> Result { - let vdaf = Prio2::new(dimension as usize)?; + let vdaf = Prio2::new(dimension)?; <()>::get_decoded(leader_message_data)?; let out_share = match helper_state { VdafState::Prio2(state) => match vdaf.prepare_step(state, ())? { @@ -99,11 +104,11 @@ pub(crate) fn prio2_helper_prepare_finish( /// Parse a prio2 prepare message from the front of `reader` whose type is compatible with `param`. pub(crate) fn prio2_decode_prepare_state( - dimension: u32, + dimension: usize, agg_id: usize, bytes: &mut Cursor<&[u8]>, ) -> Result { - let vdaf = Prio2::new(dimension as usize)?; + let vdaf = Prio2::new(dimension)?; Ok(VdafState::Prio2(Prio2PrepareState::decode_with_param( &(&vdaf, agg_id), bytes, @@ -120,15 +125,14 @@ pub(crate) fn prio2_encode_prepare_message(message: &VdafMessage) -> Vec { /// Interpret `encoded_agg_shares` as a sequence of encoded aggregate shares and unshard them. 
pub(crate) fn prio2_unshard>>( - dimension: u32, + dimension: usize, num_measurements: usize, encoded_agg_shares: M, ) -> Result { - let vdaf = Prio2::new(dimension as usize)?; + let vdaf = Prio2::new(dimension)?; let mut agg_shares = Vec::with_capacity(vdaf.num_aggregators()); for encoded in encoded_agg_shares.into_iter() { - let agg_share = AggregateShare::try_from(encoded.as_ref()) - .map_err(|e| CodecError::Other(Box::new(e)))?; + let agg_share = AggregateShare::get_decoded_with_param(&(&vdaf, &()), encoded.as_ref())?; agg_shares.push(agg_share) } let agg_res = vdaf.unshard(&(), agg_shares, num_measurements)?; diff --git a/daphne/src/vdaf/prio3.rs b/daphne/src/vdaf/prio3.rs index 3d8f83fce..7279cfabc 100644 --- a/daphne/src/vdaf/prio3.rs +++ b/daphne/src/vdaf/prio3.rs @@ -8,15 +8,16 @@ use crate::{ VdafMessage, VdafState, }; use prio::{ - codec::{CodecError, Encode, ParameterizedDecode}, + codec::{Encode, ParameterizedDecode}, vdaf::{ prio3::{ Prio3, Prio3InputShare, Prio3PrepareMessage, Prio3PrepareShare, Prio3PrepareState, + Prio3PublicShare, }, AggregateShare, Aggregator, Client, Collector, PrepareTransition, Vdaf, }, }; -use std::{convert::TryFrom, io::Cursor}; +use std::io::Cursor; const ERR_EXPECT_FINISH: &str = "unexpected transition (continued)"; const ERR_FIELD_TYPE: &str = "unexpected field type for step or message"; @@ -24,16 +25,19 @@ const ERR_FIELD_TYPE: &str = "unexpected field type for step or message"; macro_rules! shard { ( $vdaf:ident, - $measurement:expr + $measurement:expr, + $nonce:expr ) => {{ // Split measurement into input shares. - let (_public_share, input_shares) = $vdaf.shard($measurement)?; + let (public_share, input_shares) = $vdaf.shard($measurement, $nonce)?; - // Encode input shares. - input_shares - .iter() - .map(|input_share| input_share.get_encoded()) - .collect() + ( + public_share.get_encoded(), + input_shares + .iter() + .map(|input_share| input_share.get_encoded()) + .collect(), + ) }}; } @@ -41,19 +45,20 @@ macro_rules! shard { pub(crate) fn prio3_shard( config: &Prio3Config, measurement: DapMeasurement, -) -> Result>, VdafError> { + nonce: &[u8; 16], +) -> Result<(Vec, Vec>), VdafError> { match (&config, measurement) { (Prio3Config::Count, DapMeasurement::U64(measurement)) => { - let vdaf = Prio3::new_aes128_count(2)?; - Ok(shard!(vdaf, &measurement)) + let vdaf = Prio3::new_count(2)?; + Ok(shard!(vdaf, &measurement, nonce)) } (Prio3Config::Histogram { buckets }, DapMeasurement::U64(measurement)) => { - let vdaf = Prio3::new_aes128_histogram(2, buckets)?; - Ok(shard!(vdaf, &(measurement as u128))) + let vdaf = Prio3::new_histogram(2, buckets)?; + Ok(shard!(vdaf, &(measurement as u128), nonce)) } (Prio3Config::Sum { bits }, DapMeasurement::U64(measurement)) => { - let vdaf = Prio3::new_aes128_sum(2, *bits)?; - Ok(shard!(vdaf, &(measurement as u128))) + let vdaf = Prio3::new_sum(2, *bits)?; + Ok(shard!(vdaf, &(measurement as u128), nonce)) } _ => panic!("prio3_shard: unexpected VDAF config"), } @@ -64,15 +69,26 @@ macro_rules! prep_init { $vdaf:ident, $verify_key:expr, $agg_id:expr, - $nonce_data:expr, + $nonce:expr, + $public_share_data:expr, $input_share_data:expr ) => {{ + // Parse the public share. + let public_share = Prio3PublicShare::get_decoded_with_param(&$vdaf, $public_share_data)?; + // Parse the input share. let input_share = Prio3InputShare::get_decoded_with_param(&(&$vdaf, $agg_id), $input_share_data)?; // Run the prepare-init algorithm, returning the initial state. 
- $vdaf.prepare_init($verify_key, $agg_id, &(), $nonce_data, &(), &input_share)? + $vdaf.prepare_init( + $verify_key, + $agg_id, + &(), + $nonce, + &public_share, + &input_share, + )? }}; } @@ -81,29 +97,51 @@ pub(crate) fn prio3_prepare_init( config: &Prio3Config, verify_key: &[u8; 16], agg_id: usize, - nonce_data: &[u8], + nonce: &[u8; 16], + public_share_data: &[u8], input_share_data: &[u8], ) -> Result<(VdafState, VdafMessage), VdafError> { match &config { Prio3Config::Count => { - let vdaf = Prio3::new_aes128_count(2)?; - let (state, share) = prep_init!(vdaf, verify_key, agg_id, nonce_data, input_share_data); + let vdaf = Prio3::new_count(2)?; + let (state, share) = prep_init!( + vdaf, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data + ); Ok(( VdafState::Prio3Field64(state), VdafMessage::Prio3ShareField64(share), )) } Prio3Config::Histogram { buckets } => { - let vdaf = Prio3::new_aes128_histogram(2, buckets)?; - let (state, share) = prep_init!(vdaf, verify_key, agg_id, nonce_data, input_share_data); + let vdaf = Prio3::new_histogram(2, buckets)?; + let (state, share) = prep_init!( + vdaf, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data + ); Ok(( VdafState::Prio3Field128(state), VdafMessage::Prio3ShareField128(share), )) } Prio3Config::Sum { bits } => { - let vdaf = Prio3::new_aes128_sum(2, *bits)?; - let (state, share) = prep_init!(vdaf, verify_key, agg_id, nonce_data, input_share_data); + let vdaf = Prio3::new_sum(2, *bits)?; + let (state, share) = prep_init!( + vdaf, + verify_key, + agg_id, + nonce, + public_share_data, + input_share_data + ); Ok(( VdafState::Prio3Field128(state), VdafMessage::Prio3ShareField128(share), @@ -150,7 +188,7 @@ pub(crate) fn prio3_leader_prepare_finish( VdafState::Prio3Field64(state), VdafMessage::Prio3ShareField64(share), ) => { - let vdaf = Prio3::new_aes128_count(2)?; + let vdaf = Prio3::new_count(2)?; let (out_share, outbound) = leader_prep_fin!(vdaf, state, share, helper_share_data); let agg_share = VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?); (agg_share, outbound) @@ -160,7 +198,7 @@ pub(crate) fn prio3_leader_prepare_finish( VdafState::Prio3Field128(state), VdafMessage::Prio3ShareField128(share), ) => { - let vdaf = Prio3::new_aes128_histogram(2, buckets)?; + let vdaf = Prio3::new_histogram(2, buckets)?; let (out_share, outbound) = leader_prep_fin!(vdaf, state, share, helper_share_data); let agg_share = VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?); (agg_share, outbound) @@ -170,7 +208,7 @@ pub(crate) fn prio3_leader_prepare_finish( VdafState::Prio3Field128(state), VdafMessage::Prio3ShareField128(share), ) => { - let vdaf = Prio3::new_aes128_sum(2, *bits)?; + let vdaf = Prio3::new_sum(2, *bits)?; let (out_share, outbound) = leader_prep_fin!(vdaf, state, share, helper_share_data); let agg_share = VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?); (agg_share, outbound) @@ -210,17 +248,17 @@ pub(crate) fn prio3_helper_prepare_finish( ) -> Result { let agg_share = match (&config, state) { (Prio3Config::Count, VdafState::Prio3Field64(state)) => { - let vdaf = Prio3::new_aes128_count(2)?; + let vdaf = Prio3::new_count(2)?; let out_share = helper_prep_fin!(vdaf, state, peer_message_data); VdafAggregateShare::Field64(vdaf.aggregate(&(), [out_share])?) 
} (Prio3Config::Histogram { buckets }, VdafState::Prio3Field128(state)) => { - let vdaf = Prio3::new_aes128_histogram(2, buckets)?; + let vdaf = Prio3::new_histogram(2, buckets)?; let out_share = helper_prep_fin!(vdaf, state, peer_message_data); VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?) } (Prio3Config::Sum { bits }, VdafState::Prio3Field128(state)) => { - let vdaf = Prio3::new_aes128_sum(2, *bits)?; + let vdaf = Prio3::new_sum(2, *bits)?; let out_share = helper_prep_fin!(vdaf, state, peer_message_data); VdafAggregateShare::Field128(vdaf.aggregate(&(), [out_share])?) } @@ -258,19 +296,19 @@ pub(crate) fn prio3_decode_prepare_state( ) -> Result { match config { Prio3Config::Count => { - let vdaf = Prio3::new_aes128_count(2)?; + let vdaf = Prio3::new_count(2)?; Ok(VdafState::Prio3Field64( Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes)?, )) } Prio3Config::Histogram { buckets } => { - let vdaf = Prio3::new_aes128_histogram(2, buckets)?; + let vdaf = Prio3::new_histogram(2, buckets)?; Ok(VdafState::Prio3Field128( Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes)?, )) } Prio3Config::Sum { bits } => { - let vdaf = Prio3::new_aes128_sum(2, *bits)?; + let vdaf = Prio3::new_sum(2, *bits)?; Ok(VdafState::Prio3Field128( Prio3PrepareState::decode_with_param(&(&vdaf, agg_id), bytes)?, )) @@ -295,8 +333,7 @@ macro_rules! unshard { ) => {{ let mut agg_shares = Vec::with_capacity($vdaf.num_aggregators()); for data in $agg_shares.into_iter() { - let agg_share = AggregateShare::try_from(data.as_ref()) - .map_err(|e| CodecError::Other(Box::new(e)))?; + let agg_share = AggregateShare::get_decoded_with_param(&(&$vdaf, &()), data.as_ref())?; agg_shares.push(agg_share) } $vdaf.unshard(&(), agg_shares, $num_measurements) @@ -311,17 +348,17 @@ pub(crate) fn prio3_unshard>>( ) -> Result { match &config { Prio3Config::Count => { - let vdaf = Prio3::new_aes128_count(2)?; + let vdaf = Prio3::new_count(2)?; let agg_res = unshard!(vdaf, num_measurements, agg_shares)?; Ok(DapAggregateResult::U64(agg_res)) } Prio3Config::Histogram { buckets } => { - let vdaf = Prio3::new_aes128_histogram(2, buckets)?; + let vdaf = Prio3::new_histogram(2, buckets)?; let agg_res = unshard!(vdaf, num_measurements, agg_shares)?; Ok(DapAggregateResult::U128Vec(agg_res)) } Prio3Config::Sum { bits } => { - let vdaf = Prio3::new_aes128_sum(2, *bits)?; + let vdaf = Prio3::new_sum(2, *bits)?; let agg_res = unshard!(vdaf, num_measurements, agg_shares)?; Ok(DapAggregateResult::U128(agg_res)) } diff --git a/daphne/src/vdaf/prio3_test.rs b/daphne/src/vdaf/prio3_test.rs index ad3376733..20bf78431 100644 --- a/daphne/src/vdaf/prio3_test.rs +++ b/daphne/src/vdaf/prio3_test.rs @@ -53,18 +53,31 @@ fn test_prepare( ) -> Result<(), VdafError> { let mut rng = thread_rng(); let verify_key = rng.gen(); - let nonce = b"this is a good nonce"; + let nonce = [0; 16]; // Shard - let encoded_input_shares = prio3_shard(config, measurement).unwrap(); + let (encoded_public_share, encoded_input_shares) = + prio3_shard(config, measurement, &nonce).unwrap(); assert_eq!(encoded_input_shares.len(), 2); // Prepare - let (leader_state, leader_share) = - prio3_prepare_init(config, &verify_key, 0, nonce, &encoded_input_shares[0])?; + let (leader_state, leader_share) = prio3_prepare_init( + config, + &verify_key, + 0, + &nonce, + &encoded_public_share, + &encoded_input_shares[0], + )?; - let (helper_state, helper_share) = - prio3_prepare_init(config, &verify_key, 1, nonce, &encoded_input_shares[1])?; + let (helper_state, helper_share) 
= prio3_prepare_init( + config, + &verify_key, + 1, + &nonce, + &encoded_public_share, + &encoded_input_shares[1], + )?; let helper_share_data = prio3_encode_prepare_message(&helper_share); diff --git a/daphne_worker/Cargo.toml b/daphne_worker/Cargo.toml index bab97f47a..bccbcfa9b 100644 --- a/daphne_worker/Cargo.toml +++ b/daphne_worker/Cargo.toml @@ -27,7 +27,7 @@ getrandom = { version = "0.2.8", features = ["js"] } # Required for prio hex = { version = "0.4.3", features = ["serde"] } matchit = "0.7.0" paste = "1.0.12" -prio = "0.10.0" +prio = "0.12.0" prometheus = "0.13.3" rand = "0.8.5" reqwest-wasm = { version = "0.11.16", features = ["json"] } diff --git a/daphne_worker/src/config.rs b/daphne_worker/src/config.rs index 7dc8b55c0..f589080ed 100644 --- a/daphne_worker/src/config.rs +++ b/daphne_worker/src/config.rs @@ -22,14 +22,17 @@ use daphne::{ auth::BearerToken, constants, hpke::HpkeReceiverConfig, - messages::{decode_base64url_vec, HpkeConfig, Id, ReportMetadata}, - DapAbort, DapError, DapGlobalConfig, DapQueryConfig, DapRequest, DapTaskConfig, DapVersion, - Prio3Config, VdafConfig, + messages::{ + decode_base64url_vec, AggregationJobId, BatchId, CollectionJobId, HpkeConfig, + ReportMetadata, TaskId, + }, + DapAbort, DapError, DapGlobalConfig, DapQueryConfig, DapRequest, DapResource, DapResponse, + DapTaskConfig, DapVersion, Prio3Config, VdafConfig, }; use matchit::Router; use prio::{ codec::Decode, - vdaf::prg::{Prg, PrgAes128, Seed, SeedStream}, + vdaf::prg::{Prg, PrgSha3, Seed, SeedStream}, }; use prometheus::{Encoder, Registry}; use serde::{Deserialize, Serialize}; @@ -40,7 +43,7 @@ use std::{ sync::{Arc, RwLock, RwLockReadGuard}, time::Duration, }; -use tracing::{debug, error, trace}; +use tracing::{debug, error, info, trace}; use worker::{kv::KvStore, *}; pub(crate) const KV_KEY_PREFIX_HPKE_RECEIVER_CONFIG: &str = "hpke_receiver_config"; @@ -51,6 +54,9 @@ pub(crate) const KV_BINDING_DAP_CONFIG: &str = "DAP_CONFIG"; const DAP_BASE_URL: &str = "DAP_BASE_URL"; +const INT_ERR_PEER_ABORT: &str = "request aborted by peer"; +const INT_ERR_PEER_RESP_MISSING_MEDIA_TYPE: &str = "peer response is missing media type"; + /// Long-lived parameters for tasks using draft-wang-ppm-dap-taskprov-02 ("taskprov"). pub(crate) struct TaskprovConfig { /// HPKE collector configuration for all taskprov tasks. @@ -87,7 +93,7 @@ pub(crate) struct DaphneWorkerConfig { pub(crate) deployment: DaphneWorkerDeployment, /// Leader: Key used to derive collection job IDs. This field is not configured by the Helper. - pub(crate) collect_id_key: Option>, + pub(crate) collection_job_id_key: Option>, /// Sharding key, used to compute the ReportsPending or ReportsProcessed shard to map a report /// to (based on the report ID). @@ -151,15 +157,16 @@ impl DaphneWorkerConfig { None }; - let collect_id_key = if is_leader { - let collect_id_key_hex = env - .secret("DAP_COLLECT_ID_KEY") - .map_err(|e| format!("failed to load DAP_COLLECT_ID_KEY: {e}"))? + const DAP_COLLECTION_JOB_ID_KEY: &str = "DAP_COLLECTION_JOB_ID_KEY"; + let collection_job_id_key = if is_leader { + let collection_job_id_key_hex = env + .secret(DAP_COLLECTION_JOB_ID_KEY) + .map_err(|e| format!("failed to load {DAP_COLLECTION_JOB_ID_KEY}: {e}"))? .to_string(); - let collect_id_key = - Seed::get_decoded(&hex::decode(collect_id_key_hex).map_err(int_err)?) + let collection_job_id_key = + Seed::get_decoded(&hex::decode(collection_job_id_key_hex).map_err(int_err)?) 
.map_err(int_err)?; - Some(collect_id_key) + Some(collection_job_id_key) } else { None }; @@ -315,7 +322,7 @@ impl DaphneWorkerConfig { Ok(Self { global, deployment, - collect_id_key, + collection_job_id_key, report_shard_key, report_shard_count, base_url, @@ -337,7 +344,12 @@ impl DaphneWorkerConfig { metadata: &ReportMetadata, ) -> String { let mut shard_seed = [0; 8]; - PrgAes128::seed_stream(&self.report_shard_key, metadata.id.as_ref()).fill(&mut shard_seed); + PrgSha3::seed_stream( + &self.report_shard_key, + b"report shard", + metadata.id.as_ref(), + ) + .fill(&mut shard_seed); let shard = u64::from_be_bytes(shard_seed) % self.report_shard_count; let epoch = metadata.time - (metadata.time % self.global.report_storage_epoch_duration); durable_name_report_store(&task_config.version, task_id_hex, epoch, shard) @@ -357,13 +369,13 @@ pub(crate) struct DaphneWorkerIsolateState { hpke_receiver_configs: Arc>>, /// Laeder bearer token per task. - leader_bearer_tokens: Arc>>, + leader_bearer_tokens: Arc>>, /// Collector bearer token per task. - collector_bearer_tokens: Arc>>, + collector_bearer_tokens: Arc>>, /// Task list. - tasks: Arc>>, + tasks: Arc>>, } impl DaphneWorkerIsolateState { @@ -576,7 +588,7 @@ impl<'srv> DaphneWorker<'srv> { /// Retrieve from KV the Leader's bearer token for the given task. pub(crate) async fn get_leader_bearer_token<'a>( &'a self, - task_id: &'a Id, + task_id: &'a TaskId, ) -> Result> { self.kv_get_cached( &self.isolate_state().leader_bearer_tokens, @@ -589,7 +601,7 @@ impl<'srv> DaphneWorker<'srv> { /// Set a leader bearer token for the given task. pub(crate) async fn set_leader_bearer_token( &self, - task_id: &Id, + task_id: &TaskId, token: &BearerToken, ) -> Result> { self.kv_set_if_not_exists(KV_KEY_PREFIX_BEARER_TOKEN_LEADER, task_id, token.clone()) @@ -599,7 +611,7 @@ impl<'srv> DaphneWorker<'srv> { /// Retrieve from KV the Collector's bearer token for the given task. pub(crate) async fn get_collector_bearer_token<'a>( &'a self, - task_id: &'a Id, + task_id: &'a TaskId, ) -> Result> { self.kv_get_cached( &self.isolate_state().collector_bearer_tokens, @@ -612,7 +624,7 @@ impl<'srv> DaphneWorker<'srv> { /// Retrieve from KV the configuration for the given task. pub(crate) async fn get_task_config<'req>( &'srv self, - task_id: Cow<'req, Id>, + task_id: Cow<'req, TaskId>, ) -> Result>> where 'srv: 'req, @@ -628,7 +640,7 @@ impl<'srv> DaphneWorker<'srv> { /// Define a task in KV pub(crate) async fn set_task_config( &self, - task_id: &Id, + task_id: &TaskId, task_config: &DapTaskConfig, ) -> Result> { self.kv_set_if_not_exists(KV_KEY_PREFIX_TASK_CONFIG, task_id, task_config.clone()) @@ -639,7 +651,7 @@ impl<'srv> DaphneWorker<'srv> { /// indicated task is not recognized. pub(crate) async fn try_get_task_config<'req>( &'srv self, - task_id: &'req Id, + task_id: &'req TaskId, ) -> std::result::Result, DapError> where 'srv: 'req, @@ -685,8 +697,8 @@ impl<'srv> DaphneWorker<'srv> { /// applicable to fixed-size tasks. pub(crate) async fn internal_current_batch( &self, - task_id: &Id, - ) -> std::result::Result { + task_id: &TaskId, + ) -> std::result::Result { let task_config = self.try_get_task_config(task_id).await?; if !matches!(task_config.as_ref().query, DapQueryConfig::FixedSize { .. }) { return Err(DapError::fatal("query type mismatch")); @@ -752,9 +764,8 @@ impl<'srv> DaphneWorker<'srv> { cmd: InternalTestAddTask, ) -> Result<()> { // Task ID. 
- let task_id_data = decode_base64url_vec(cmd.task_id.as_bytes()) + let task_id = TaskId::try_from_base64url(&cmd.task_id) .ok_or_else(|| int_err("task ID is not valid URL-safe base64"))?; - let task_id = Id::get_decoded(&task_id_data).map_err(int_err)?; // VDAF config. let vdaf = match (cmd.vdaf.typ.as_ref(), cmd.vdaf.bits) { @@ -871,10 +882,13 @@ impl<'srv> DaphneWorker<'srv> { Ok(DapVersion::from(version)) } - pub(crate) async fn worker_request_to_dap( + pub(crate) async fn worker_request_to_dap( &self, mut req: Request, + ctx: &RouteContext, ) -> Result> { + let version = self.extract_version_parameter(&req)?; + // Determine the authorization method used by the sender. let bearer_token = req.headers().get("DAP-Auth-Token")?.map(BearerToken::from); let mut tls_client_auth = req.cf().tls_client_auth(); @@ -903,29 +917,63 @@ impl<'srv> DaphneWorker<'srv> { }; let content_type = req.headers().get("Content-Type")?; + let media_type = content_type.and_then(|s| constants::media_type_for(&s)); - let media_type = match content_type { - Some(s) => constants::media_type_for(&s), - None => None, - }; - - let version = self.extract_version_parameter(&req)?; let payload = req.bytes().await?; - // Parse the task ID from the front of the request payload and use it to look up the - // expected bearer token. - // - // TODO(cjpatton) Add regression tests that ensure each protocol message is prefixed by the - // task ID. - // - // TODO spec: Consider moving the task ID out of the payload. Right now we're parsing it - // twice so that we have a reference to the task ID before parsing the entire message. - let mut r = Cursor::new(payload.as_ref()); - let task_id = Id::decode(&mut r).ok(); + let (task_id, resource) = match version { + DapVersion::Draft02 => { + // Parse the task ID from the front of the request payload and use it to look up the + // expected bearer token. + // + // TODO(cjpatton) Add regression tests that ensure each protocol message is prefixed by the + // task ID. + // + // TODO spec: Consider moving the task ID out of the payload. Right now we're parsing it + // twice so that we have a reference to the task ID before parsing the entire message. + let mut r = Cursor::new(payload.as_ref()); + (TaskId::decode(&mut r).ok(), DapResource::Undefined) + } + DapVersion::Draft04 => { + let task_id = ctx.param("task_id").and_then(TaskId::try_from_base64url); + let resource = match media_type { + Some(constants::MEDIA_TYPE_AGG_CONT_REQ) + | Some(constants::MEDIA_TYPE_AGG_INIT_REQ) => { + if let Some(agg_job_id) = ctx + .param("agg_job_id") + .and_then(AggregationJobId::try_from_base64url) + { + DapResource::AggregationJob(agg_job_id) + } else { + // Missing or invalid agg job ID. This should be handled as a bad + // request (undefined resource) by the caller. + DapResource::Undefined + } + } + Some(constants::MEDIA_TYPE_COLLECT_REQ) => { + if let Some(collect_job_id) = ctx + .param("collect_job_id") + .and_then(CollectionJobId::try_from_base64url) + { + DapResource::CollectionJob(collect_job_id) + } else { + // Missing or invalid agg job ID. This should be handled as a bad + // request (undefined resource) by the caller. 
+ DapResource::Undefined + } + } + _ => DapResource::Undefined, + }; + + (task_id, resource) + } + DapVersion::Unknown => unreachable!("unhandled version {version:?}"), + }; Ok(DapRequest { version, task_id, + resource, payload, url: req.url()?, media_type, @@ -940,6 +988,74 @@ impl<'srv> DaphneWorker<'srv> { pub(crate) fn greatest_valid_report_time(&self, now: u64) -> u64 { now.saturating_add(self.config().global.report_storage_max_future_time_skew) } + + // Generic HTTP POST/PUT + pub(crate) async fn send_http( + &self, + req: DapRequest, + is_put: bool, + ) -> std::result::Result { + let (payload, url) = (req.payload, req.url); + + let mut headers = reqwest_wasm::header::HeaderMap::new(); + if let Some(content_type) = req.media_type { + headers.insert( + reqwest_wasm::header::CONTENT_TYPE, + reqwest_wasm::header::HeaderValue::from_str(content_type) + .map_err(|e| DapError::Fatal(e.to_string()))?, + ); + } + + if let Some(DaphneWorkerAuth::BearerToken(bearer_token)) = req.sender_auth { + headers.insert( + reqwest_wasm::header::HeaderName::from_static("dap-auth-token"), + reqwest_wasm::header::HeaderValue::from_str(bearer_token.as_ref()) + .map_err(|e| DapError::Fatal(e.to_string()))?, + ); + } + + let client = &self.isolate_state().client; + let reqwest_req = if is_put { + client.put(url.as_str()) + } else { + client.post(url.as_str()) + } + .body(payload) + .headers(headers); + + let start = Date::now().as_millis(); + let reqwest_resp = reqwest_req + .send() + .await + .map_err(|e| DapError::Fatal(e.to_string()))?; + let end = Date::now().as_millis(); + info!("request to {} completed in {}ms", url, end - start); + let status = reqwest_resp.status(); + if status == 200 { + // Translate the reqwest response into a Worker response. + let content_type = reqwest_resp + .headers() + .get(reqwest_wasm::header::CONTENT_TYPE) + .ok_or_else(|| DapError::fatal(INT_ERR_PEER_RESP_MISSING_MEDIA_TYPE))? + .to_str() + .map_err(|e| DapError::Fatal(e.to_string()))?; + let media_type = constants::media_type_for(content_type); + + let payload = reqwest_resp + .bytes() + .await + .map_err(|e| DapError::Fatal(e.to_string()))? + .to_vec(); + + Ok(DapResponse { + payload, + media_type, + }) + } else { + error!("{}: request failed: {:?}", url, reqwest_resp); + Err(DapError::fatal(INT_ERR_PEER_ABORT)) + } + } } /// RwLockReadGuard'ed object, used to catch items fetched from KV. 
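// ------------------------------------------------------------------------------------
// Editor's sketch (not part of the patch): the draft04 branch of `worker_request_to_dap`
// above reads the task ID and the job ID from URL route parameters ("task_id",
// "agg_job_id", "collect_job_id") rather than from the request payload. The concrete
// router patterns are not shown in this hunk, so the path shapes assumed below follow
// the DAP-04 resource URIs and may not match Daphne's actual routes exactly.
#[derive(Debug, PartialEq, Eq)]
enum SketchResource {
    AggregationJob(String),
    CollectionJob(String),
    Undefined,
}

// Maps e.g. "/tasks/<tid>/aggregation_jobs/<jid>" to (Some(tid), AggregationJob(jid))
// and "/tasks/<tid>/collection_jobs/<jid>" to (Some(tid), CollectionJob(jid)).
fn sketch_parse_dap04_path(path: &str) -> (Option<String>, SketchResource) {
    let segments: Vec<&str> = path.trim_matches('/').split('/').collect();
    match segments.as_slice() {
        ["tasks", task_id, "aggregation_jobs", agg_job_id] => (
            Some((*task_id).to_string()),
            SketchResource::AggregationJob((*agg_job_id).to_string()),
        ),
        ["tasks", task_id, "collection_jobs", collection_job_id] => (
            Some((*task_id).to_string()),
            SketchResource::CollectionJob((*collection_job_id).to_string()),
        ),
        // Task-level endpoints (e.g. report upload) carry no job resource.
        ["tasks", task_id, ..] => (Some((*task_id).to_string()), SketchResource::Undefined),
        _ => (None, SketchResource::Undefined),
    }
}
// ------------------------------------------------------------------------------------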
@@ -966,7 +1082,7 @@ impl AsRef for GuardedHpkeReceiverConfig<'_> { } } -pub(crate) type GuardedBearerToken<'a> = Guarded<'a, Id, BearerToken>; +pub(crate) type GuardedBearerToken<'a> = Guarded<'a, TaskId, BearerToken>; impl AsRef for GuardedBearerToken<'_> { fn as_ref(&self) -> &BearerToken { @@ -974,7 +1090,7 @@ impl AsRef for GuardedBearerToken<'_> { } } -pub(crate) type GuardedDapTaskConfig<'a> = Guarded<'a, Id, DapTaskConfig>; +pub(crate) type GuardedDapTaskConfig<'a> = Guarded<'a, TaskId, DapTaskConfig>; impl AsRef for GuardedDapTaskConfig<'_> { fn as_ref(&self) -> &DapTaskConfig { diff --git a/daphne_worker/src/dap.rs b/daphne_worker/src/dap.rs index d862794c8..c0fd82a6f 100644 --- a/daphne_worker/src/dap.rs +++ b/daphne_worker/src/dap.rs @@ -27,11 +27,13 @@ use crate::{ BatchCount, DURABLE_LEADER_BATCH_QUEUE_ASSIGN, DURABLE_LEADER_BATCH_QUEUE_REMOVE, }, leader_col_job_queue::{ - DURABLE_LEADER_COL_JOB_QUEUE_FINISH, DURABLE_LEADER_COL_JOB_QUEUE_GET, - DURABLE_LEADER_COL_JOB_QUEUE_GET_RESULT, DURABLE_LEADER_COL_JOB_QUEUE_PUT, + CollectQueueRequest, DURABLE_LEADER_COL_JOB_QUEUE_FINISH, + DURABLE_LEADER_COL_JOB_QUEUE_GET, DURABLE_LEADER_COL_JOB_QUEUE_GET_RESULT, + DURABLE_LEADER_COL_JOB_QUEUE_PUT, }, reports_pending::{ - ReportsPendingResult, DURABLE_REPORTS_PENDING_GET, DURABLE_REPORTS_PENDING_PUT, + PendingReport, ReportsPendingResult, DURABLE_REPORTS_PENDING_GET, + DURABLE_REPORTS_PENDING_PUT, }, reports_processed::DURABLE_REPORTS_PROCESSED_MARK_AGGREGATED, BINDING_DAP_AGGREGATE_STORE, BINDING_DAP_HELPER_STATE_STORE, @@ -44,17 +46,18 @@ use crate::{ use async_trait::async_trait; use daphne::{ auth::{BearerToken, BearerTokenProvider}, - constants::{media_type_for, sender_for_media_type}, + constants::sender_for_media_type, hpke::HpkeDecrypter, messages::{ - BatchSelector, CollectReq, CollectResp, HpkeCiphertext, Id, PartialBatchSelector, Report, - ReportId, ReportMetadata, TransitionFailure, + BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq, HpkeCiphertext, + PartialBatchSelector, Report, ReportId, ReportMetadata, TaskId, TransitionFailure, }, metrics::DaphneMetrics, roles::{early_metadata_check, DapAggregator, DapAuthorizedSender, DapHelper, DapLeader}, taskprov::{bad_request, get_taskprov_task_config}, DapAggregateShare, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig, DapHelperState, DapOutputShare, DapQueryConfig, DapRequest, DapResponse, DapSender, DapTaskConfig, DapVersion, + MetaAggregationJobId, }; use futures::future::try_join_all; use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode}; @@ -62,12 +65,9 @@ use std::{ borrow::Cow, collections::{HashMap, HashSet}, }; -use tracing::{debug, error, info, warn}; +use tracing::{debug, warn}; use worker::*; -const INT_ERR_PEER_ABORT: &str = "request aborted by peer"; -const INT_ERR_PEER_RESP_MISSING_MEDIA_TYPE: &str = "peer response is missing media type"; - pub(crate) fn dap_response_to_worker(resp: DapResponse) -> Result { let mut headers = Headers::new(); if let Some(media_type) = resp.media_type { @@ -84,7 +84,7 @@ impl<'srv> HpkeDecrypter<'srv> for DaphneWorker<'srv> { async fn get_hpke_config_for( &'srv self, version: DapVersion, - _task_id: Option<&Id>, + _task_id: Option<&TaskId>, ) -> std::result::Result, DapError> { let kv_store = self.kv().map_err(dap_err)?; let keys = kv_store @@ -155,7 +155,7 @@ impl<'srv> HpkeDecrypter<'srv> for DaphneWorker<'srv> { async fn can_hpke_decrypt( &self, - task_id: &Id, + task_id: &TaskId, config_id: u8, ) -> std::result::Result { let version 
= self.try_get_task_config(task_id).await?.as_ref().version; @@ -171,7 +171,7 @@ impl<'srv> HpkeDecrypter<'srv> for DaphneWorker<'srv> { async fn hpke_decrypt( &self, - task_id: &Id, + task_id: &TaskId, info: &[u8], aad: &[u8], ciphertext: &HpkeCiphertext, @@ -203,14 +203,14 @@ impl<'srv> BearerTokenProvider<'srv> for DaphneWorker<'srv> { async fn get_leader_bearer_token_for( &'srv self, - task_id: &'srv Id, + task_id: &'srv TaskId, ) -> std::result::Result, DapError> { self.get_leader_bearer_token(task_id).await.map_err(dap_err) } async fn get_collector_bearer_token_for( &'srv self, - task_id: &'srv Id, + task_id: &'srv TaskId, ) -> std::result::Result, DapError> { self.get_collector_bearer_token(task_id) .await @@ -245,7 +245,7 @@ impl<'srv> BearerTokenProvider<'srv> for DaphneWorker<'srv> { impl DapAuthorizedSender for DaphneWorker<'_> { async fn authorize( &self, - task_id: &Id, + task_id: &TaskId, media_type: &'static str, _payload: &[u8], ) -> std::result::Result { @@ -366,7 +366,7 @@ where async fn get_task_config_considering_taskprov( &'srv self, version: DapVersion, - task_id: Cow<'req, Id>, + task_id: Cow<'req, TaskId>, metadata: Option<&ReportMetadata>, ) -> std::result::Result>, DapError> { let found = self @@ -450,7 +450,7 @@ where async fn is_batch_overlapping( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> std::result::Result { let task_config = self.try_get_task_config(task_id).await?; @@ -483,8 +483,8 @@ where async fn batch_exists( &self, - task_id: &Id, - batch_id: &Id, + task_id: &TaskId, + batch_id: &BatchId, ) -> std::result::Result { let task_config = self.try_get_task_config(task_id).await?; @@ -507,7 +507,7 @@ where async fn put_out_shares( &self, - task_id: &Id, + task_id: &TaskId, part_batch_sel: &PartialBatchSelector, out_shares: Vec, ) -> std::result::Result<(), DapError> { @@ -534,7 +534,7 @@ where async fn get_agg_share( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> std::result::Result { let task_config = self.try_get_task_config(task_id).await?; @@ -561,7 +561,7 @@ where async fn check_early_reject<'b>( &self, - task_id: &Id, + task_id: &TaskId, part_batch_sel: &'b PartialBatchSelector, report_meta: impl Iterator, ) -> std::result::Result, DapError> { @@ -661,7 +661,7 @@ where async fn mark_collected( &self, - task_id: &Id, + task_id: &TaskId, batch_sel: &BatchSelector, ) -> std::result::Result<(), DapError> { let task_config = self.try_get_task_config(task_id).await?; @@ -683,7 +683,7 @@ where Ok(()) } - async fn current_batch(&self, task_id: &Id) -> std::result::Result { + async fn current_batch(&self, task_id: &TaskId) -> std::result::Result { self.internal_current_batch(task_id).await } @@ -692,16 +692,6 @@ where } } -fn task_id_from_report(report: &[u8]) -> std::result::Result { - // The task id MUST BE the first 32 bytes of the serialized report; if this - // needs to change in the future, then we must change the DO serialization - // format to contain a version. 
- let id = Id(report[..32] - .try_into() - .map_err(|_| DapError::fatal("serialize report is too short"))?); - Ok(id) -} - #[async_trait(?Send)] impl<'srv, 'req> DapLeader<'srv, 'req, DaphneWorkerAuth> for DaphneWorker<'srv> where @@ -709,10 +699,19 @@ where { type ReportSelector = DaphneWorkerReportSelector; - async fn put_report(&self, report: &Report) -> std::result::Result<(), DapError> { - let task_config = self.try_get_task_config(&report.task_id).await?; - let task_id_hex = report.task_id.to_hex(); - let report_hex = hex::encode(report.get_encoded_with_param(&task_config.as_ref().version)); + async fn put_report( + &self, + report: &Report, + task_id: &TaskId, + ) -> std::result::Result<(), DapError> { + let task_config = self.try_get_task_config(task_id).await?; + let task_id_hex = task_id.to_hex(); + let version = task_config.as_ref().version; + let pending_report = PendingReport { + version, + task_id: task_id.clone(), + report_hex: hex::encode(report.get_encoded_with_param(&version)), + }; let res: ReportsPendingResult = self .durable() .post( @@ -721,9 +720,9 @@ where self.config().durable_name_report_store( task_config.as_ref(), &task_id_hex, - &report.metadata, + &report.report_metadata, ), - &report_hex, + &pending_report, ) .await .map_err(dap_err)?; @@ -744,7 +743,7 @@ where async fn get_reports( &self, report_sel: &DaphneWorkerReportSelector, - ) -> std::result::Result>>, DapError> + ) -> std::result::Result>>, DapError> { let durable = self.durable(); // Read at most `report_sel.max_buckets` buckets from the agg job queue. The result is ordered @@ -766,9 +765,9 @@ where // by task. // // TODO Figure out if we can safely handle each instance in parallel. - let mut reports_per_task: HashMap> = HashMap::new(); + let mut reports_per_task: HashMap> = HashMap::new(); for reports_pending_id_hex in res.into_iter() { - let reports_from_durable: Vec = durable + let reports_from_durable: Vec = durable .post_by_id_hex( BINDING_DAP_REPORTS_PENDING, DURABLE_REPORTS_PENDING_GET, @@ -778,23 +777,26 @@ where .await .map_err(dap_err)?; - for report_hex in reports_from_durable { - let report_bytes = hex::decode(&report_hex).map_err(|_| { + for pending_report in reports_from_durable { + let report_bytes = hex::decode(&pending_report.report_hex).map_err(|_| { DapError::fatal("response from ReportsPending is not valid hex") })?; - let task_id = task_id_from_report(&report_bytes)?; - let task_config = self.try_get_task_config(&task_id).await?; - let report = - Report::get_decoded_with_param(&task_config.as_ref().version, &report_bytes)?; - if let Some(reports) = reports_per_task.get_mut(&report.task_id) { + + let version = self + .try_get_task_config(&pending_report.task_id) + .await? 
+ .as_ref() + .version; + let report = Report::get_decoded_with_param(&version, &report_bytes)?; + if let Some(reports) = reports_per_task.get_mut(&pending_report.task_id) { reports.push(report); } else { - reports_per_task.insert(report.task_id.clone(), vec![report]); + reports_per_task.insert(pending_report.task_id.clone(), vec![report]); } } } - let mut reports_per_task_part: HashMap>> = + let mut reports_per_task_part: HashMap>> = HashMap::new(); for (task_id, mut reports) in reports_per_task.into_iter() { let task_config = self @@ -857,19 +859,25 @@ where async fn init_collect_job( &self, - collect_req: &CollectReq, + task_id: &TaskId, + collect_job_id: &Option, + collect_req: &CollectionReq, ) -> std::result::Result { - let task_config = self.try_get_task_config(&collect_req.task_id).await?; - + let task_config = self.try_get_task_config(task_id).await?; // Try to put the request into collection job queue. If the request is overlapping // with past requests, then abort. - let collect_id: Id = self + let collect_queue_req = CollectQueueRequest { + collect_req: collect_req.clone(), + task_id: task_id.clone(), + collect_job_id: collect_job_id.clone(), + }; + let collect_id: CollectionJobId = self .durable() .post( BINDING_DAP_LEADER_COL_JOB_QUEUE, DURABLE_LEADER_COL_JOB_QUEUE_PUT, durable_name_queue(0), - &collect_req, + &collect_queue_req, ) .await .map_err(dap_err)?; @@ -877,10 +885,11 @@ where let url = task_config.as_ref().leader_url.clone(); + // Note that we always return the draft02 URI, but draft04 and later ignore it. let collect_uri = url .join(&format!( "collect/task/{}/req/{}", - collect_req.task_id.to_base64url(), + task_id.to_base64url(), collect_id.to_base64url(), )) .map_err(|e| DapError::Fatal(e.to_string()))?; @@ -890,8 +899,8 @@ where async fn poll_collect_job( &self, - _task_id: &Id, - collect_id: &Id, + task_id: &TaskId, + collect_id: &CollectionJobId, ) -> std::result::Result { let res: DapCollectJob = self .durable() @@ -899,7 +908,7 @@ where BINDING_DAP_LEADER_COL_JOB_QUEUE, DURABLE_LEADER_COL_JOB_QUEUE_GET_RESULT, durable_name_queue(0), - &collect_id, + (&task_id, &collect_id), ) .await .map_err(dap_err)?; @@ -908,8 +917,8 @@ where async fn get_pending_collect_jobs( &self, - ) -> std::result::Result, DapError> { - let res: Vec<(Id, CollectReq)> = self + ) -> std::result::Result, DapError> { + let res: Vec<(TaskId, CollectionJobId, CollectionReq)> = self .durable() .get( BINDING_DAP_LEADER_COL_JOB_QUEUE, @@ -923,9 +932,9 @@ where async fn finish_collect_job( &self, - task_id: &Id, - collect_id: &Id, - collect_resp: &CollectResp, + task_id: &TaskId, + collect_id: &CollectionJobId, + collect_resp: &Collection, ) -> std::result::Result<(), DapError> { let task_config = self.try_get_task_config(task_id).await?; let durable = self.durable(); @@ -948,7 +957,7 @@ where BINDING_DAP_LEADER_COL_JOB_QUEUE, DURABLE_LEADER_COL_JOB_QUEUE_FINISH, durable_name_queue(0), - (collect_id, collect_resp), + (task_id, collect_id, collect_resp), ) .await .map_err(dap_err)?; @@ -959,64 +968,14 @@ where &self, req: DapRequest, ) -> std::result::Result { - let (payload, url) = (req.payload, req.url); - - let mut headers = reqwest_wasm::header::HeaderMap::new(); - if let Some(content_type) = req.media_type { - headers.insert( - reqwest_wasm::header::CONTENT_TYPE, - reqwest_wasm::header::HeaderValue::from_str(content_type) - .map_err(|e| DapError::Fatal(e.to_string()))?, - ); - } - - if let Some(DaphneWorkerAuth::BearerToken(bearer_token)) = req.sender_auth { - headers.insert( - 
reqwest_wasm::header::HeaderName::from_static("dap-auth-token"), - reqwest_wasm::header::HeaderValue::from_str(bearer_token.as_ref()) - .map_err(|e| DapError::Fatal(e.to_string()))?, - ); - } - - let reqwest_req = self - .isolate_state() - .client - .post(url.as_str()) - .body(payload) - .headers(headers); - - let start = Date::now().as_millis(); - let reqwest_resp = reqwest_req - .send() - .await - .map_err(|e| DapError::Fatal(e.to_string()))?; - let end = Date::now().as_millis(); - info!("request to {} completed in {}ms", url, end - start); - let status = reqwest_resp.status(); - if status == 200 { - // Translate the reqwest response into a Worker response. - let content_type = reqwest_resp - .headers() - .get(reqwest_wasm::header::CONTENT_TYPE) - .ok_or_else(|| DapError::fatal(INT_ERR_PEER_RESP_MISSING_MEDIA_TYPE))? - .to_str() - .map_err(|e| DapError::Fatal(e.to_string()))?; - let media_type = media_type_for(content_type); - - let payload = reqwest_resp - .bytes() - .await - .map_err(|e| DapError::Fatal(e.to_string()))? - .to_vec(); + self.send_http(req, false).await + } - Ok(DapResponse { - payload, - media_type, - }) - } else { - error!("{}: request failed: {:?}", url, reqwest_resp); - Err(DapError::fatal(INT_ERR_PEER_ABORT)) - } + async fn send_http_put( + &self, + req: DapRequest, + ) -> std::result::Result { + self.send_http(req, true).await } } @@ -1027,8 +986,8 @@ where { async fn put_helper_state( &self, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, helper_state: &DapHelperState, ) -> std::result::Result<(), DapError> { let task_config = self.try_get_task_config(task_id).await?; @@ -1047,8 +1006,8 @@ where async fn get_helper_state( &self, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, ) -> std::result::Result, DapError> { let task_config = self.try_get_task_config(task_id).await?; let res: Option = self diff --git a/daphne_worker/src/durable/helper_state_store.rs b/daphne_worker/src/durable/helper_state_store.rs index 1c3711610..2474e0a06 100644 --- a/daphne_worker/src/durable/helper_state_store.rs +++ b/daphne_worker/src/durable/helper_state_store.rs @@ -2,14 +2,14 @@ // SPDX-License-Identifier: BSD-3-Clause use crate::{config::DaphneWorkerConfig, durable::state_get, initialize_tracing, int_err}; -use daphne::{messages::Id, DapVersion}; +use daphne::{messages::TaskId, DapVersion, MetaAggregationJobId}; use tracing::{trace, warn}; use worker::*; pub(crate) fn durable_helper_state_name( version: &DapVersion, - task_id: &Id, - agg_job_id: &Id, + task_id: &TaskId, + agg_job_id: &MetaAggregationJobId, ) -> String { format!( "{}/task/{}/agg_job/{}", diff --git a/daphne_worker/src/durable/leader_batch_queue.rs b/daphne_worker/src/durable/leader_batch_queue.rs index 2fea21ce1..879647f7a 100644 --- a/daphne_worker/src/durable/leader_batch_queue.rs +++ b/daphne_worker/src/durable/leader_batch_queue.rs @@ -6,7 +6,7 @@ use crate::{ durable::{state_get, DurableOrdered, BINDING_DAP_LEADER_BATCH_QUEUE}, initialize_tracing, int_err, }; -use daphne::messages::Id; +use daphne::messages::BatchId; use rand::prelude::*; use serde::{Deserialize, Serialize}; use tracing::debug; @@ -22,14 +22,14 @@ const PENDING_PREFIX: &str = "pending"; #[derive(Clone, Deserialize, Serialize)] pub(crate) struct BatchCount { - pub(crate) batch_id: Id, + pub(crate) batch_id: BatchId, pub(crate) report_count: usize, } #[derive(Clone, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub(crate) enum 
LeaderBatchQueueResult { - Ok(Id), + Ok(BatchId), EmptyQueue, } @@ -67,7 +67,7 @@ impl LeaderBatchQueue { let queued = DurableOrdered::new_strictly_ordered( &self.state, BatchCount { - batch_id: Id(rng.gen()), + batch_id: BatchId(rng.gen()), report_count: 0, }, PENDING_PREFIX, diff --git a/daphne_worker/src/durable/leader_col_job_queue.rs b/daphne_worker/src/durable/leader_col_job_queue.rs index 01b4ce3b9..abcf5bb79 100644 --- a/daphne_worker/src/durable/leader_col_job_queue.rs +++ b/daphne_worker/src/durable/leader_col_job_queue.rs @@ -7,13 +7,14 @@ use crate::{ initialize_tracing, int_err, }; use daphne::{ - messages::{CollectReq, CollectResp, Id}, + messages::{Collection, CollectionJobId, CollectionReq, TaskId}, DapCollectJob, DapVersion, }; use prio::{ codec::ParameterizedEncode, - vdaf::prg::{Prg, PrgAes128, SeedStream}, + vdaf::prg::{Prg, PrgSha3, SeedStream}, }; +use serde::{Deserialize, Serialize}; use worker::*; const PENDING_PREFIX: &str = "pending"; @@ -26,6 +27,14 @@ pub(crate) const DURABLE_LEADER_COL_JOB_QUEUE_FINISH: &str = pub(crate) const DURABLE_LEADER_COL_JOB_QUEUE_GET_RESULT: &str = "/internal/do/leader_col_job_queue/get_result"; +#[derive(Clone, Deserialize, Serialize, Debug)] +#[serde(rename_all = "snake_case")] +pub(crate) struct CollectQueueRequest { + pub collect_req: CollectionReq, + pub task_id: TaskId, + pub collect_job_id: Option, +} + /// Durable Object (DO) for storing the Leader's state for a given task. /// /// This object implements the following API endpoints: @@ -39,10 +48,10 @@ pub(crate) const DURABLE_LEADER_COL_JOB_QUEUE_GET_RESULT: &str = /// The schema for data stored in instances of this DO is as follows: /// /// ```text -/// [Pending Lookup ID] pending/id/ -> String (reference to queue element) +/// [Pending Lookup ID] pending/id/ -> String (reference to queue element) /// [Pending queue] pending/next_ordinal -> u64 -/// [Pending queue] pending/item/order/ -> (Id, CollectReq) -/// [Processed] processed/ -> CollectResp +/// [Pending queue] pending/item/order/ -> (CollectionJobId, CollectReq) +/// [Processed] processed/ -> CollectResp /// ``` /// /// Note that the queue ordinal format is inherited from [`DurableOrdered::new_strictly_ordered`]. @@ -81,55 +90,61 @@ impl DurableObject for LeaderCollectionJobQueue { // Input: `collect_req: CollectReq` // Output: `Id` (collect job ID) (DURABLE_LEADER_COL_JOB_QUEUE_PUT, Method::Post) => { - let collect_req: CollectReq = req.json().await?; - - // Compute the collect job ID, used to derive the collect URI for this request. - // This value is computed by applying a pseudorandom function to the request. This - // has two desirable properties. First, it makes the collect URI unpredictable, - // which prevents clients from enumerating collect URIs. Second, it provides a - // stable map from requests to URIs, which prevents us from processing the same - // collect request more than once. - // - // We are serializing the collect_req into binary, and for now we assume the - // version is always Draf03 since that works for both Draft02 and Draft03, but - // if this structure changes further, then version information will need to be - // added to this request. 
- let collect_req_bytes = collect_req.get_encoded_with_param(&DapVersion::Draft03); - let mut collect_id_bytes = [0; 32]; - PrgAes128::seed_stream( - self.config.collect_id_key.as_ref().unwrap(), - &collect_req_bytes, - ) - .fill(&mut collect_id_bytes); - let collect_id = Id(collect_id_bytes); - let collect_id_hex = collect_id.to_hex(); + let collect_queue_req: CollectQueueRequest = req.json().await?; + let collection_job_id: CollectionJobId = + if let Some(cid) = &collect_queue_req.collect_job_id { + cid.clone() + } else { + // draft02 legacy: Compute the collect job ID, used to derive the collect + // URI for this request. This value is computed by applying a pseudorandom + // function to the request. This has two desirable properties. First, it + // makes the collect URI unpredictable, which prevents clients from + // enumerating collect URIs. Second, it provides a stable map from requests + // to URIs, which prevents us from processing the same collect request more + // than once. + let collect_req_bytes = collect_queue_req + .collect_req + .get_encoded_with_param(&DapVersion::Draft02); + let mut collection_job_id_bytes = [0; 16]; + PrgSha3::seed_stream( + self.config.collection_job_id_key.as_ref().unwrap(), + b"collection job id", + &collect_req_bytes, + ) + .fill(&mut collection_job_id_bytes); + CollectionJobId(collection_job_id_bytes) + }; // If the the request is new, then put it in the job queue. - let pending_key = format!("pending/id/{collect_id_hex}"); + let pending_key = pending_key(&collect_queue_req.task_id, &collection_job_id); + let processed_key = processed_key(&collect_queue_req.task_id, &collection_job_id); let pending: bool = state_get_or_default(&self.state, &pending_key).await?; - let processed: Option = - state_get(&self.state, &format!("{PROCESSED_PREFIX}/{collect_id_hex}")).await?; + let processed: Option = state_get(&self.state, &processed_key).await?; if processed.is_none() && !pending { let queued = DurableOrdered::new_strictly_ordered( &self.state, - (collect_id, collect_req), + ( + collect_queue_req.task_id, + collection_job_id.clone(), + collect_queue_req.collect_req, + ), PENDING_PREFIX, ) .await?; queued.put(&self.state).await?; self.state .storage() - .put(&lookup_key(&collect_id_hex), &queued.key()) + .put(&pending_key, &queued.key()) .await?; } - Response::from_json(&collect_id_hex) + Response::from_json(&collection_job_id.to_hex()) } // Get the list of pending collection jobs (oldest jobs first). // // Output: `Vec<(Id, CollectReq)>` (DURABLE_LEADER_COL_JOB_QUEUE_GET, Method::Get) => { - let queue: Vec<(Id, CollectReq)> = + let queue: Vec<(TaskId, CollectionJobId, CollectionReq)> = DurableOrdered::get_all(&self.state, PENDING_PREFIX) .await? .into_iter() @@ -140,12 +155,15 @@ impl DurableObject for LeaderCollectionJobQueue { // Remove a collection job from the pending queue and store the CollectResp. 
// - // Input: `(collect_id, collect_resp): (Id, CollectResp)` + // Input: `(collection_job_id, collect_resp): (Id, CollectResp)` (DURABLE_LEADER_COL_JOB_QUEUE_FINISH, Method::Post) => { - let (collect_id, collect_resp): (Id, CollectResp) = req.json().await?; - let collect_id_hex = collect_id.to_hex(); - let processed_key = format!("{PROCESSED_PREFIX}/{collect_id_hex}"); - let processed: Option = state_get(&self.state, &processed_key).await?; + let (task_id, collection_job_id, collect_resp): ( + TaskId, + CollectionJobId, + Collection, + ) = req.json().await?; + let processed_key = processed_key(&task_id, &collection_job_id); + let processed: Option = state_get(&self.state, &processed_key).await?; if processed.is_some() { return Err(int_err( "LeaderCollectionJobQueue: tried to overwrite collect response", @@ -153,15 +171,13 @@ impl DurableObject for LeaderCollectionJobQueue { } // Remove the collection job from the pending queue. - let pending_lookup_key = lookup_key(&collect_id_hex); - if let Some(lookup_val) = - state_get::(&self.state, &pending_lookup_key).await? - { + let pending_key = pending_key(&task_id, &collection_job_id); + if let Some(lookup_val) = state_get::(&self.state, &pending_key).await? { self.state.storage().delete(&lookup_val).await?; } let mut storage = self.state.storage(); - let f = storage.delete(&pending_lookup_key); + let f = storage.delete(&pending_key); // Store the CollectResp. self.state @@ -176,20 +192,19 @@ impl DurableObject for LeaderCollectionJobQueue { // Check if a collection job is complete. // - // Input: `collect_id: Id` + // Input: `collection_job_id: Id` // Output: `DapCollectionJob` (DURABLE_LEADER_COL_JOB_QUEUE_GET_RESULT, Method::Post) => { - let collect_id: Id = req.json().await?; - let collect_id_hex = collect_id.to_hex(); - let pending_lookup_key = lookup_key(&collect_id_hex); - let pending = state_get::(&self.state, &pending_lookup_key) + let (task_id, collection_job_id): (TaskId, CollectionJobId) = req.json().await?; + let pending_key = pending_key(&task_id, &collection_job_id); + let pending = state_get::(&self.state, &pending_key) .await? 
.is_some(); - let processed_key = format!("{PROCESSED_PREFIX}/{collect_id_hex}"); - let processed: Option = state_get(&self.state, &processed_key).await?; + let processed_key = processed_key(&task_id, &collection_job_id); + let processed: Option = state_get(&self.state, &processed_key).await?; if let Some(collect_resp) = processed { if pending { - self.state.storage().delete(&pending_lookup_key).await?; + self.state.storage().delete(&pending_key).await?; } Response::from_json(&DapCollectJob::Done(collect_resp)) } else if pending { @@ -208,6 +223,18 @@ impl DurableObject for LeaderCollectionJobQueue { } } -fn lookup_key(collect_id_hex: &str) -> String { - format!("{PENDING_PREFIX}/id/{collect_id_hex}") +fn pending_key(task_id: &TaskId, collection_job_id: &CollectionJobId) -> String { + format!( + "{PENDING_PREFIX}/tasks/{}/collection_jobs/{}", + task_id.to_base64url(), + collection_job_id.to_base64url() + ) +} + +fn processed_key(task_id: &TaskId, collection_job_id: &CollectionJobId) -> String { + format!( + "{PROCESSED_PREFIX}/tasks/{}/collection_jobs/{}", + task_id.to_base64url(), + collection_job_id.to_base64url() + ) } diff --git a/daphne_worker/src/durable/mod.rs b/daphne_worker/src/durable/mod.rs index e176ff35b..2366b15ef 100644 --- a/daphne_worker/src/durable/mod.rs +++ b/daphne_worker/src/durable/mod.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause use crate::{int_err, now}; -use daphne::{messages::Id, DapBatchBucket, DapVersion}; +use daphne::{messages::TaskId, DapBatchBucket, DapVersion}; use rand::prelude::*; use serde::{Deserialize, Serialize}; use worker::*; @@ -256,20 +256,6 @@ fn durable_name_bucket(bucket: &DapBatchBucket<'_>) -> String { } } -pub(crate) fn report_id_hex_from_report(report_hex: &str) -> Option<&str> { - // task_id (32 bytes) - if report_hex.len() < 64 { - return None; - } - let report_hex = &report_hex[64..]; - - // metadata.id - if report_hex.len() < 32 { - return None; - } - Some(&report_hex[..32]) -} - /// Reference to a DO instance, used by the garbage collector. #[derive(Deserialize, Serialize)] pub(crate) struct DurableReference { @@ -281,7 +267,7 @@ pub(crate) struct DurableReference { /// If applicable, the DAP task ID to which the DO instance is associated. #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) task_id: Option, + pub(crate) task_id: Option, } /// An element of a queue stored in a DO instance. diff --git a/daphne_worker/src/durable/mod_test.rs b/daphne_worker/src/durable/mod_test.rs index eb7b34a71..b2cbdcac4 100644 --- a/daphne_worker/src/durable/mod_test.rs +++ b/daphne_worker/src/durable/mod_test.rs @@ -3,10 +3,10 @@ use crate::durable::{ durable_name_agg_store, durable_name_queue, durable_name_report_store, - report_id_hex_from_report, + reports_pending::PendingReport, }; use daphne::{ - messages::{Id, Report, ReportId, ReportMetadata}, + messages::{BatchId, Report, ReportId, ReportMetadata, TaskId}, test_version, test_versions, DapBatchBucket, DapVersion, }; use paste::paste; @@ -16,8 +16,8 @@ use rand::prelude::*; #[test] fn durable_name() { let time = 1664850074; - let id1 = Id([17; 32]); - let id2 = Id([34; 32]); + let id1 = TaskId([17; 32]); + let id2 = BatchId([34; 32]); let shard = 1234; assert_eq!(durable_name_queue(shard), "queue/1234"); @@ -38,14 +38,15 @@ fn durable_name() { ); } -// Test that the `report_id_from_report()` method properly extracts the report ID from the +// Test that the `PendingReport.report_id_hex()` method properly extracts the report ID from the // hex-encoded report. 
This helps ensure that changes to the `Report` wire format don't cause any // regressions to `ReportStore`. fn parse_report_id_hex_from_report(version: DapVersion) { let mut rng = thread_rng(); + let task_id = TaskId([17; 32]); let report = Report { - task_id: Id(rng.gen()), - metadata: ReportMetadata { + draft02_task_id: task_id.for_request_payload(&version), + report_metadata: ReportMetadata { id: ReportId(rng.gen()), time: rng.gen(), extensions: Vec::default(), @@ -54,12 +55,23 @@ fn parse_report_id_hex_from_report(version: DapVersion) { encrypted_input_shares: Vec::default(), }; - let report_hex = hex::encode(report.get_encoded_with_param(&version)); - let key = report_id_hex_from_report(&report_hex).unwrap(); - assert_eq!( - ReportId::get_decoded_with_param(&version, &hex::decode(key).unwrap()).unwrap(), - report.metadata.id - ); + let pending_report = PendingReport { + task_id, + version, + report_hex: hex::encode(report.get_encoded_with_param(&version)), + }; + + let got = ReportId::get_decoded_with_param( + &version, + &hex::decode( + pending_report + .report_id_hex() + .expect("report_id_hex() failed"), + ) + .expect("hex::decode() failed"), + ) + .expect("ReportId::get_decoded_with_param() failed"); + assert_eq!(got, report.report_metadata.id); } test_versions! {parse_report_id_hex_from_report} diff --git a/daphne_worker/src/durable/reports_pending.rs b/daphne_worker/src/durable/reports_pending.rs index 4c121b683..2c18036fe 100644 --- a/daphne_worker/src/durable/reports_pending.rs +++ b/daphne_worker/src/durable/reports_pending.rs @@ -8,11 +8,12 @@ use crate::{ leader_agg_job_queue::{ DURABLE_LEADER_AGG_JOB_QUEUE_FINISH, DURABLE_LEADER_AGG_JOB_QUEUE_PUT, }, - report_id_hex_from_report, state_get, state_set_if_not_exists, DurableConnector, - DurableOrdered, BINDING_DAP_LEADER_AGG_JOB_QUEUE, BINDING_DAP_REPORTS_PENDING, + state_get, state_set_if_not_exists, DurableConnector, DurableOrdered, + BINDING_DAP_LEADER_AGG_JOB_QUEUE, BINDING_DAP_REPORTS_PENDING, }, initialize_tracing, int_err, }; +use daphne::{messages::TaskId, DapVersion}; use serde::{Deserialize, Serialize}; use tracing::debug; use worker::*; @@ -27,6 +28,30 @@ pub(crate) enum ReportsPendingResult { ErrReportExists, } +#[derive(Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct PendingReport { + pub(crate) task_id: TaskId, + pub(crate) version: DapVersion, + + /// Hex-encdoed, serialized report. + // + // TODO(cjpatton) Consider changing the type to `Report`. If I recall correctly, this triggers + // the serde-wasm-bindgen bug we saw in workers-rs 0.0.12, which should be fixed as of 0.0.15. + pub(crate) report_hex: String, +} + +impl PendingReport { + pub(crate) fn report_id_hex(&self) -> Option<&str> { + match self.version { + DapVersion::Draft02 if self.report_hex.len() >= 96 => Some(&self.report_hex[64..96]), + DapVersion::Draft04 if self.report_hex.len() >= 32 => Some(&self.report_hex[..32]), + DapVersion::Unknown => unreachable!("unhandled version {:?}", self.version), + _ => None, + } + } +} + /// Durable Object (DO) for storing reports waiting to be processed. /// /// The following API endpoints are defined: @@ -43,8 +68,8 @@ pub(crate) enum ReportsPendingResult { /// The schema for stored reports is as follows: /// /// ```text -/// [Pending report] pending/ -> String -/// [Aggregation job] agg_job -> DurableOrdered +/// [Pending report] pending/ -> PendingReport +/// [Aggregation job] agg_job -> DurableOrdered /// ``` /// /// where `` is the ID of the report. 
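The slicing in `report_id_hex()` above encodes the draft02 vs. draft04 wire layout: a draft02 `Report` begins with the 32-byte task ID (64 hex characters) followed by the 16-byte report ID (32 hex characters), while draft04 moves the task ID out of the payload so the report ID sits at the front. A minimal standalone sketch of that offset math, using a simplified `Version` enum in place of daphne's `DapVersion` (illustrative only):

```rust
#[derive(Clone, Copy)]
enum Version {
    Draft02,
    Draft04,
}

/// Extract the hex-encoded 16-byte report ID from a hex-encoded report.
/// draft02: 32-byte task ID (64 hex chars) comes first, then the report ID.
/// draft04: the report ID is the first field of the payload.
fn report_id_hex(version: Version, report_hex: &str) -> Option<&str> {
    match version {
        Version::Draft02 if report_hex.len() >= 96 => Some(&report_hex[64..96]),
        Version::Draft04 if report_hex.len() >= 32 => Some(&report_hex[..32]),
        _ => None,
    }
}

fn main() {
    // 64 hex chars of task ID, then 32 hex chars of report ID, then the rest.
    let draft02_report = format!("{}{}{}", "aa".repeat(32), "bb".repeat(16), "cc".repeat(8));
    let want = "bb".repeat(16);
    assert_eq!(report_id_hex(Version::Draft02, &draft02_report), Some(want.as_str()));
    assert_eq!(report_id_hex(Version::Draft04, &want), Some(want.as_str()));
}
```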
The value is the hex-encoded report. The @@ -82,7 +107,7 @@ impl DurableObject for ReportsPending { // Drain the requested number of reports from storage. // // Input: `reports_requested: usize` - // Output: `Vec` (hex-encoded reports) + // Output: `Vec` (DURABLE_REPORTS_PENDING_GET, Method::Post) => { let reports_requested: usize = req.json().await?; let opt = ListOptions::new() @@ -93,9 +118,9 @@ impl DurableObject for ReportsPending { let mut reports = Vec::with_capacity(reports_requested); let mut keys = Vec::with_capacity(reports_requested); while !item.done() { - let (key, report_hex): (String, String) = + let (key, pending_report): (String, PendingReport) = serde_wasm_bindgen::from_value(item.value()).map_err(int_err)?; - reports.push(report_hex); + reports.push(pending_report); keys.push(key); item = iter.next()?; } @@ -149,15 +174,15 @@ impl DurableObject for ReportsPending { // Store a report. // - // Input: `report_hex: String` (hex-encoded report) - // Output: `ReportsPendingResult` + // Input: `pending_report: PendingReport` + // Output: `ReportsPendingResult` (DURABLE_REPORTS_PENDING_PUT, Method::Post) => { - let report_hex: String = req.json().await?; - let report_id_hex = report_id_hex_from_report(&report_hex) - .ok_or_else(|| int_err("failed to parse report_id from report"))?; - + let pending_report: PendingReport = req.json().await?; + let report_id_hex = pending_report + .report_id_hex() + .ok_or_else(|| int_err("failed to parse report ID from report"))?; let key = format!("pending/{report_id_hex}"); - let exists = state_set_if_not_exists::(&self.state, &key, &report_hex) + let exists = state_set_if_not_exists(&self.state, &key, &pending_report) .await? .is_some(); if exists { diff --git a/daphne_worker/src/lib.rs b/daphne_worker/src/lib.rs index ec0a7f108..cc714a1b6 100644 --- a/daphne_worker/src/lib.rs +++ b/daphne_worker/src/lib.rs @@ -158,13 +158,13 @@ use crate::{ }; use daphne::{ auth::BearerToken, - constants, - messages::{decode_base64url, Duration, Id, Time}, + constants::{self, versioned_media_type_for}, + messages::{CollectionJobId, Duration, TaskId, Time}, roles::{DapAggregator, DapHelper, DapLeader}, DapAbort, DapCollectJob, DapError, DapResponse, }; use once_cell::sync::OnceCell; -use prio::codec::Encode; +use prio::codec::ParameterizedEncode; use serde::{Deserialize, Serialize}; use std::str; use tracing::{debug, error, info_span, Instrument}; @@ -180,20 +180,6 @@ pub struct DaphneWorkerReportSelector { pub max_reports: u64, } -macro_rules! parse_id { - ( - $option_str:expr - ) => { - match $option_str { - Some(id_base64url) => match decode_base64url(id_base64url.as_bytes()) { - Some(id_bytes) => Id(id_bytes), - None => return Response::error("Bad Request", 400), - }, - None => return Response::error("Bad Request", 400), - } - }; -} - /// HTTP request handler for Daphne-Worker. 
#[derive(Default)] pub struct DaphneWorkerRouter { @@ -249,7 +235,7 @@ impl DaphneWorkerRouter { let router = Router::with_data(&state) .get_async("/:version/hpke_config", |req, ctx| async move { let daph = ctx.data.handler(&ctx.env); - let req = daph.worker_request_to_dap(req).await?; + let req = daph.worker_request_to_dap(req, &ctx).await?; match daph .http_get_hpke_config(&req) .instrument(info_span!("hpke_config")) @@ -284,22 +270,11 @@ impl DaphneWorkerRouter { let router = match env.var("DAP_AGGREGATOR_ROLE")?.to_string().as_ref() { "leader" => { router - .post_async("/:version/upload", |req, ctx| async move { + .post_async("/v02/upload", put_report_into_task) // draft02 + .put_async("/:version/tasks/:task_id/reports", put_report_into_task) + .post_async("/v02/collect", |req, ctx| async move { let daph = ctx.data.handler(&ctx.env); - let req = daph.worker_request_to_dap(req).await?; - - match daph - .http_post_upload(&req) - .instrument(info_span!("upload")) - .await - { - Ok(()) => Response::empty(), - Err(e) => abort(e), - } - }) - .post_async("/:version/collect", |req, ctx| async move { - let daph = ctx.data.handler(&ctx.env); - let req = daph.worker_request_to_dap(req).await?; + let req = daph.worker_request_to_dap(req, &ctx).await?; match daph .http_post_collect(&req) @@ -316,22 +291,100 @@ impl DaphneWorkerRouter { } Err(e) => abort(e), } - }) + }) // draft02 .get_async( - "/:version/collect/task/:task_id/req/:collect_id", - |_req, ctx| async move { - let task_id = parse_id!(ctx.param("task_id")); - let collect_id = parse_id!(ctx.param("collect_id")); + "/v02/collect/task/:task_id/req/:collect_id", + |req, ctx| async move { + let task_id = + match ctx.param("task_id").and_then(TaskId::try_from_base64url) { + Some(id) => id, + None => { + return abort(DapAbort::BadRequest( + "missing task_id parameter".to_string(), + )) + } + }; + let collect_id = match ctx + .param("collect_id") + .and_then(CollectionJobId::try_from_base64url) + { + Some(id) => id, + None => { + return abort(DapAbort::BadRequest( + "missing collect_id parameter".to_string(), + )) + } + }; let daph = ctx.data.handler(&ctx.env); + let version = daph.extract_version_parameter(&req)?; match daph .poll_collect_job(&task_id, &collect_id) + .instrument(info_span!("poll_collect_job (draft02)")) + .await + { + Ok(DapCollectJob::Done(collect_resp)) => { + dap_response_to_worker(DapResponse { + media_type: versioned_media_type_for( + &version, + constants::MEDIA_TYPE_COLLECT_RESP, + ), + payload: collect_resp.get_encoded_with_param(&version), + }) + } + Ok(DapCollectJob::Pending) => { + Ok(Response::empty().unwrap().with_status(202)) + } + // TODO spec: Decide whether to define this behavior. 
+ Ok(DapCollectJob::Unknown) => { + abort(DapAbort::BadRequest("unknown collect id".into())) + } + Err(e) => abort(e.into()), + } + }, + ) // draft02 + .put_async( + "/:version/tasks/:task_id/collection_jobs/:collect_job_id", + |req, ctx| async move { + let daph = ctx.data.handler(&ctx.env); + let req = daph.worker_request_to_dap(req, &ctx).await?; + + match daph + .http_post_collect(&req) + .instrument(info_span!("collect (PUT)")) + .await + { + Ok(_) => Ok(Response::empty().unwrap().with_status(201)), + Err(e) => abort(e), + } + }, + ) + .post_async( + "/:version/tasks/:task_id/collection_jobs/:collect_job_id", + |req, ctx| async move { + let daph = ctx.data.handler(&ctx.env); + let version = daph.extract_version_parameter(&req)?; + let req = daph.worker_request_to_dap(req, &ctx).await?; + let task_id = match req.task_id() { + Ok(id) => id, + Err(e) => return abort(e), + }; + let collect_job_id = match req.collection_job_id() { + Ok(id) => id, + Err(e) => return abort(e), + }; + + match daph + .poll_collect_job(task_id, collect_job_id) .instrument(info_span!("poll_collect_job")) .await { Ok(DapCollectJob::Done(collect_resp)) => { dap_response_to_worker(DapResponse { - media_type: Some(constants::MEDIA_TYPE_COLLECT_RESP), - payload: collect_resp.get_encoded(), + media_type: versioned_media_type_for( + &version, + constants::MEDIA_TYPE_COLLECT_RESP, + ), + payload: collect_resp.get_encoded_with_param(&version), }) } Ok(DapCollectJob::Pending) => { @@ -368,7 +421,15 @@ impl DaphneWorkerRouter { // task. The task ID and batch ID are both encoded in URL-safe base64. // // TODO(cjpatton) Only enable this if `self.enable_internal_test` is set. - let task_id = parse_id!(ctx.param("task_id")); + let task_id = + match ctx.param("task_id").and_then(TaskId::try_from_base64url) { + Some(id) => id, + None => { + return abort(DapAbort::BadRequest( + "missing or malformed task ID".into(), + )) + } + }; let daph = ctx.data.handler(&ctx.env); match daph .internal_current_batch(&task_id) @@ -385,32 +446,20 @@ impl DaphneWorkerRouter { } "helper" => router - .post_async("/:version/aggregate", |req, ctx| async move { - let daph = ctx.data.handler(&ctx.env); - let req = daph.worker_request_to_dap(req).await?; - - match daph - .http_post_aggregate(&req) - .instrument(info_span!("aggregate")) - .await - { - Ok(resp) => dap_response_to_worker(resp), - Err(e) => abort(e), - } - }) - .post_async("/:version/aggregate_share", |req, ctx| async move { - let daph = ctx.data.handler(&ctx.env); - let req = daph.worker_request_to_dap(req).await?; - - match daph - .http_post_aggregate_share(&req) - .instrument(info_span!("aggregate_share")) - .await - { - Ok(resp) => dap_response_to_worker(resp), - Err(e) => abort(e), - } - }), + .post_async("/:version/aggregate", handle_agg_job) // draft02 + .post_async("/:version/aggregate_share", handle_agg_share_req) // draft02 + .put_async( + "/:version/tasks/:task_id/aggregation_jobs/:agg_job_id", + handle_agg_job, + ) + .post_async( + "/:version/tasks/:task_id/aggregation_jobs/:agg_job_id", + handle_agg_job, + ) + .post_async( + "/:version/tasks/:task_id/aggregate_shares", + handle_agg_share_req, + ), _ => return abort(DapError::fatal("unexpected role").into()), }; @@ -518,6 +567,57 @@ impl DaphneWorkerRouter { } } +async fn put_report_into_task( + req: Request, + ctx: RouteContext<&DaphneWorkerRequestState<'_>>, +) -> Result { + let daph = ctx.data.handler(&ctx.env); + let req = daph.worker_request_to_dap(req, &ctx).await?; + + match daph + .http_post_upload(&req) + 
.instrument(info_span!("upload")) + .await + { + Ok(()) => Response::empty(), + Err(e) => abort(e), + } +} + +async fn handle_agg_job( + req: Request, + ctx: RouteContext<&DaphneWorkerRequestState<'_>>, +) -> Result { + let daph = ctx.data.handler(&ctx.env); + let req = daph.worker_request_to_dap(req, &ctx).await?; + + match daph + .http_post_aggregate(&req) + .instrument(info_span!("aggregate")) + .await + { + Ok(resp) => dap_response_to_worker(resp), + Err(e) => abort(e), + } +} + +async fn handle_agg_share_req( + req: Request, + ctx: RouteContext<&DaphneWorkerRequestState<'_>>, +) -> Result { + let daph = ctx.data.handler(&ctx.env); + let req = daph.worker_request_to_dap(req, &ctx).await?; + + match daph + .http_post_aggregate_share(&req) + .instrument(info_span!("aggregate_share")) + .await + { + Ok(resp) => dap_response_to_worker(resp), + Err(e) => abort(e), + } +} + pub(crate) fn now() -> u64 { Date::now().as_millis() / 1000 } diff --git a/daphne_worker_test/Cargo.toml b/daphne_worker_test/Cargo.toml index bb25a296b..e0ce76082 100644 --- a/daphne_worker_test/Cargo.toml +++ b/daphne_worker_test/Cargo.toml @@ -4,8 +4,8 @@ name = "daphne-worker-test" version = "0.3.0" authors = [ - "Christopher Patton ", - "Armando Faz Hernandez ", + "Christopher Patton ", + "Armando Faz Hernandez ", ] edition = "2021" license = "BSD-3-Clause" @@ -40,7 +40,7 @@ hex = { version = "0.4.3", features = ["serde"] } hpke-rs = "0.1.0" lazy_static = "1.4.0" paste = "1.0.12" -prio = "0.10.0" +prio = "0.12.0" rand = "0.8.5" reqwest = { version = "0.11.14", features = ["json"] } ring = "0.16.20" diff --git a/daphne_worker_test/tests/e2e.rs b/daphne_worker_test/tests/e2e.rs index 586b1db1c..fa09dbb68 100644 --- a/daphne_worker_test/tests/e2e.rs +++ b/daphne_worker_test/tests/e2e.rs @@ -11,18 +11,19 @@ use daphne::{ taskprov::{ DpConfig, QueryConfig, QueryConfigVar, TaskConfig, UrlBytes, VdafConfig, VdafTypeVar, }, - BatchSelector, CollectReq, CollectResp, Extension, HpkeCiphertext, Id, Interval, Query, - Report, ReportId, ReportMetadata, + BatchSelector, Collection, CollectionReq, Extension, HpkeCiphertext, Interval, Query, + Report, ReportId, ReportMetadata, TaskId, }, taskprov::{compute_task_id, TaskprovVersion}, DapAggregateResult, DapMeasurement, DapTaskConfig, DapVersion, }; use daphne_worker::DaphneWorkerReportSelector; use paste::paste; -use prio::codec::{Decode, Encode, ParameterizedEncode}; +use prio::codec::{ParameterizedDecode, ParameterizedEncode}; use rand::prelude::*; use serde::Deserialize; use serde_json::json; +use std::cmp::{max, min}; use test_runner::{TestRunner, MIN_BATCH_SIZE, TIME_PRECISION}; use url::Url; @@ -89,7 +90,7 @@ async fn e2e_leader_endpoint_for_task(version: DapVersion, want_prefix: bool) { let expected = if want_prefix { format!("/{}/", version.as_ref()) } else { - String::from("/v03/") // Must match DAP_DEFAULT_VERSION + String::from("/v04/") // Must match DAP_DEFAULT_VERSION }; assert_eq!(res.endpoint.unwrap(), expected); } @@ -130,7 +131,7 @@ async fn e2e_helper_endpoint_for_task(version: DapVersion, want_prefix: bool) { let expected = if want_prefix { format!("/{}/", version.as_ref()) } else { - String::from("/v03/") // Must match DAP_DEFAULT_VERSION + String::from("/v04/") // Must match DAP_DEFAULT_VERSION }; assert_eq!(res.endpoint.unwrap(), expected); } @@ -184,7 +185,7 @@ async fn e2e_leader_upload(version: DapVersion) { let mut rng = thread_rng(); let client = t.http_client(); let hpke_config_list = t.get_hpke_configs(version, &client).await; - let path = "upload"; + let 
path = t.upload_path(); // Generate and upload a report. let report = t @@ -198,19 +199,19 @@ async fn e2e_leader_upload(version: DapVersion) { version, ) .unwrap(); - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - path, + &path, constants::MEDIA_TYPE_REPORT, report.get_encoded_with_param(&version), ) .await; // Try uploading the same report a second time (expect failure due to repeated ID. - t.leader_post_expect_abort( + t.leader_put_expect_abort( &client, None, // dap_auth_token - path, + &path, constants::MEDIA_TYPE_REPORT, report.get_encoded_with_param(&version), 400, @@ -219,17 +220,19 @@ async fn e2e_leader_upload(version: DapVersion) { .await; // Try uploading a report with the incorrect task ID. - t.leader_post_expect_abort( + let bad_id = TaskId(rng.gen()); + let bad_path = t.upload_path_for_task(&bad_id); + t.leader_put_expect_abort( &client, None, // dap_auth_token - path, + &bad_path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf .produce_report( &hpke_config_list, t.now, - &Id(rng.gen()), + &bad_id, DapMeasurement::U64(999), version, ) @@ -253,10 +256,10 @@ async fn e2e_leader_upload(version: DapVersion) { ) .unwrap(); report.encrypted_input_shares[0].config_id ^= 0xff; - t.leader_post_expect_abort( + t.leader_put_expect_abort( &client, None, // dap_auth_token - path, + &path, constants::MEDIA_TYPE_REPORT, report.get_encoded_with_param(&version), 400, @@ -265,10 +268,10 @@ async fn e2e_leader_upload(version: DapVersion) { .await; // Try uploading a malformed report. - t.leader_post_expect_abort( + t.leader_put_expect_abort( &client, None, // dap_auth_token - path, + &path, constants::MEDIA_TYPE_REPORT, b"junk data".to_vec(), 400, @@ -288,10 +291,10 @@ async fn e2e_leader_upload(version: DapVersion) { version, ) .unwrap(); - t.leader_post_expect_abort( + t.leader_put_expect_abort( &client, None, // dap_auth_token - path, + &path, constants::MEDIA_TYPE_REPORT, report.get_encoded_with_param(&version), 400, @@ -302,13 +305,17 @@ async fn e2e_leader_upload(version: DapVersion) { // Upload a fixed report. This is a sanity check to make sure that the test resets the Leader's // state each time the test is run. If it didn't, this would result in an error due to the // report ID being repeated. 
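The request constructed just below picks its HTTP method by version: draft02 clients still POST to `upload`, while draft04 clients PUT the report to `tasks/{task_id}/reports` (see `upload_path_for_task` later in this diff). A rough standalone sketch of that mapping, with a simplified `Version` enum and a base64url task ID string standing in for the daphne types:

```rust
#[derive(Clone, Copy)]
enum Version {
    Draft02,
    Draft04,
}

/// HTTP method the Client uses to upload a report.
fn upload_method(version: Version) -> &'static str {
    match version {
        Version::Draft02 => "POST",
        Version::Draft04 => "PUT",
    }
}

/// Upload path, relative to the version prefix (e.g. "/v04/").
fn upload_path(version: Version, task_id_base64url: &str) -> String {
    match version {
        Version::Draft02 => "upload".to_string(),
        Version::Draft04 => format!("tasks/{task_id_base64url}/reports"),
    }
}

fn main() {
    assert_eq!(upload_method(Version::Draft02), "POST");
    assert_eq!(upload_path(Version::Draft04, "abc123"), "tasks/abc123/reports");
}
```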
- let url = t.leader_url.join(path).unwrap(); - let resp = client - .post(url.as_str()) + let url = t.leader_url.join(&path).unwrap(); + let builder = match t.version { + DapVersion::Draft02 => client.post(url.as_str()), + DapVersion::Draft04 => client.put(url.as_str()), + _ => unreachable!("unhandled version {}", t.version), + }; + let resp = builder .body( Report { - task_id: t.task_id.clone(), - metadata: ReportMetadata { + draft02_task_id: t.task_id.for_request_payload(&version), + report_metadata: ReportMetadata { id: ReportId([1; 16]), time: t.now, extensions: Vec::default(), @@ -515,6 +522,7 @@ async fn e2e_leader_upload_taskprov() { async fn e2e_internal_leader_process(version: DapVersion) { let t = TestRunner::default_with_version(version).await; + let path = t.upload_path(); let client = t.http_client(); let hpke_config_list = t.get_hpke_configs(version, &client).await; @@ -530,9 +538,9 @@ async fn e2e_internal_leader_process(version: DapVersion) { let mut rng = thread_rng(); for _ in 0..report_sel.max_reports + 3 { let now = rng.gen_range(t.report_interval(&batch_interval)); - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -577,14 +585,15 @@ async fn e2e_leader_process_min_agg_rate(version: DapVersion) { let client = t.http_client(); let batch_interval = t.batch_interval(); let hpke_config_list = t.get_hpke_configs(version, &client).await; + let path = t.upload_path(); // The reports are uploaded in the background. let mut rng = thread_rng(); for _ in 0..7 { let now = rng.gen_range(t.report_interval(&batch_interval)); - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -625,14 +634,19 @@ async fn e2e_leader_collect_ok(version: DapVersion) { let client = t.http_client(); let hpke_config_list = t.get_hpke_configs(version, &client).await; + let path = t.upload_path(); // The reports are uploaded in the background. let mut rng = thread_rng(); + let mut time_min = u64::MAX; + let mut time_max = 0u64; for _ in 0..t.task_config.min_batch_size { let now = rng.gen_range(t.report_interval(&batch_interval)); - t.leader_post_expect_ok( + time_min = min(time_min, now); + time_max = max(time_max, now); + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -650,8 +664,8 @@ async fn e2e_leader_collect_ok(version: DapVersion) { } // Get the collect URI. - let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval: batch_interval.clone(), }, @@ -663,7 +677,7 @@ async fn e2e_leader_collect_ok(version: DapVersion) { println!("collect_uri: {}", collect_uri); // Poll the collect URI before the CollectResp is ready. - let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 202, "response: {:?}", resp); // The reports are aggregated in the background. @@ -690,10 +704,11 @@ async fn e2e_leader_collect_ok(version: DapVersion) { ); // Poll the collect URI. 
- let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 200); - let collect_resp = CollectResp::get_decoded(&resp.bytes().await.unwrap()).unwrap(); + let collection = + Collection::get_decoded_with_param(&t.version, &resp.bytes().await.unwrap()).unwrap(); let agg_res = t .task_config .vdaf @@ -703,8 +718,8 @@ async fn e2e_leader_collect_ok(version: DapVersion) { &BatchSelector::TimeInterval { batch_interval: batch_interval.clone(), }, - collect_resp.report_count, - collect_resp.encrypted_agg_shares.clone(), + collection.report_count, + collection.encrypted_agg_shares.clone(), version, ) .await @@ -714,11 +729,24 @@ async fn e2e_leader_collect_ok(version: DapVersion) { DapAggregateResult::U128(t.task_config.min_batch_size as u128) ); + if version != DapVersion::Draft02 { + // Check that the time interval for the reports is correct. + let interval = collection.interval.as_ref().unwrap(); + let low = t.task_config.quantized_time_lower_bound(time_min); + let high = t.task_config.quantized_time_upper_bound(time_max); + assert!(low < high); + assert_eq!(interval.start, low); + assert_eq!(interval.duration, high - low); + } + // Poll the collect URI once more. Expect the response to be the same as the first, per HTTP // GET semantics. - let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 200); - assert_eq!(resp.bytes().await.unwrap(), collect_resp.get_encoded()); + assert_eq!( + resp.bytes().await.unwrap(), + collection.get_encoded_with_param(&version) + ); // NOTE Our Leader doesn't check if a report is stale until it is ready to process it. As such, // It won't tell the Client at this point that its report is stale. Delaying this check allows @@ -750,14 +778,15 @@ async fn e2e_leader_collect_ok_interleaved(version: DapVersion) { let client = t.http_client(); let batch_interval = t.batch_interval(); let hpke_config_list = t.get_hpke_configs(version, &client).await; + let path = t.upload_path(); // The reports are uploaded in the background. let mut rng = thread_rng(); for _ in 0..t.task_config.min_batch_size { let now = rng.gen_range(t.report_interval(&batch_interval)); - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -787,8 +816,8 @@ async fn e2e_leader_collect_ok_interleaved(version: DapVersion) { ); // ... then the collect request is issued ... - let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval: batch_interval.clone(), }, @@ -813,14 +842,15 @@ async fn e2e_leader_collect_not_ready_min_batch_size(version: DapVersion) { let batch_interval = t.batch_interval(); let client = t.http_client(); let hpke_config_list = t.get_hpke_configs(version, &client).await; + let path = t.upload_path(); // A number of reports are uploaded, but not enough to meet the minimum batch requirement. let mut rng = thread_rng(); for _ in 0..t.task_config.min_batch_size - 1 { let now = rng.gen_range(t.report_interval(&batch_interval)); - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -838,8 +868,8 @@ async fn e2e_leader_collect_not_ready_min_batch_size(version: DapVersion) { } // Get the collect URI. 
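The collect URI is obtained differently per version: in draft02 the Collector POSTs the `CollectionReq` to `collect` and reads the job URI from the 303 `Location` header, while in draft04 it chooses a random `CollectionJobId` itself and PUTs the request to `tasks/{task_id}/collection_jobs/{collection_job_id}` (201 on success), then polls that same URI (GET in draft02, POST in draft04; 202 while pending, 200 with the encoded `Collection` once ready). A minimal sketch of the path and method selection, with base64url strings standing in for daphne's `TaskId` and `CollectionJobId`:

```rust
#[derive(Clone, Copy)]
enum Version {
    Draft02,
    Draft04,
}

/// Path the collection job creation request is sent to, and the method used.
fn create_request(version: Version, task_id: &str, collection_job_id: &str) -> (String, &'static str) {
    match version {
        // draft02: POST to `collect`; the job URI comes back in the Location header.
        Version::Draft02 => ("collect".to_string(), "POST"),
        // draft04: PUT to a client-chosen job URI; that same URI is then polled.
        Version::Draft04 => (
            format!("tasks/{task_id}/collection_jobs/{collection_job_id}"),
            "PUT",
        ),
    }
}

/// Method used when polling the collection job URI.
fn poll_method(version: Version) -> &'static str {
    match version {
        Version::Draft02 => "GET",
        Version::Draft04 => "POST",
    }
}

fn main() {
    let (path, method) = create_request(Version::Draft04, "task123", "job456");
    assert_eq!(path, "tasks/task123/collection_jobs/job456");
    assert_eq!(method, "PUT");
    assert_eq!(poll_method(Version::Draft02), "GET");
}
```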
- let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval: batch_interval.clone(), }, @@ -871,7 +901,7 @@ async fn e2e_leader_collect_not_ready_min_batch_size(version: DapVersion) { assert_eq!(agg_telem.reports_collected, 0); // Poll the collect URI before the CollectResp is ready. - let resp = client.get(collect_uri).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 202); } @@ -882,17 +912,21 @@ async fn e2e_leader_collect_abort_unknown_request(version: DapVersion) { let client = t.http_client(); // Poll collect URI for an unknown collect request. - let fake_id = Id([0; 32]); - let collect_uri = t - .leader_url - .join(&format!( - "collect/task/{}/req/{}", - fake_id.to_base64url(), - fake_id.to_base64url() - )) - .unwrap(); - let resp = client.get(collect_uri).send().await.unwrap(); - assert_eq!(resp.status(), 400); + let fake_task_id = TaskId([0; 32]); + let fake_collection_job_id = TaskId([0; 32]); + let url_suffix = if t.version == DapVersion::Draft02 { + format!("collect/task/{fake_task_id}/req/{fake_collection_job_id}") + } else { + format!("/tasks/{fake_task_id}/collection_jobs/{fake_collection_job_id}") + }; + let expected_status = if t.version == DapVersion::Draft02 { + 400 + } else { + 404 + }; + let collect_uri = t.leader_url.join(&url_suffix).unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; + assert_eq!(resp.status(), expected_status); } async_test_versions! { e2e_leader_collect_abort_unknown_request } @@ -901,15 +935,14 @@ async fn e2e_leader_collect_accept_global_config_max_batch_duration(version: Dap let t = TestRunner::default_with_version(version).await; let client = t.http_client(); let batch_interval = Interval { - start: t.now - - (t.now % t.task_config.time_precision) + start: t.task_config.quantized_time_lower_bound(t.now) - t.global_config.max_batch_duration / 2, duration: t.global_config.max_batch_duration, }; // Maximum allowed batch duration. - let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval }, agg_param: Vec::new(), }; @@ -924,11 +957,11 @@ async fn e2e_leader_collect_abort_invalid_batch_interval(version: DapVersion) { let t = TestRunner::default_with_version(version).await; let client = t.http_client(); let batch_interval = t.batch_interval(); - let path = "collect"; + let path = &t.collect_url_suffix(); // Start of batch interval does not align with time_precision. 
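Alignment here means the interval start must be a multiple of the task's `time_precision`; the request built below nudges the start by one second to trigger `batchInvalid`. A rough sketch of the quantization helpers these tests now lean on (`quantized_time_lower_bound` / `quantized_time_upper_bound`); the upper-bound behavior shown is an assumption, not taken from daphne's source:

```rust
/// Round a timestamp down to the nearest multiple of the task's time precision.
fn quantized_time_lower_bound(time_precision: u64, time: u64) -> u64 {
    time - (time % time_precision)
}

/// Assumed behavior: one precision step above the lower bound.
fn quantized_time_upper_bound(time_precision: u64, time: u64) -> u64 {
    quantized_time_lower_bound(time_precision, time) + time_precision
}

fn main() {
    let precision = 3600; // one hour, for illustration
    assert_eq!(quantized_time_lower_bound(precision, 7275), 7200);
    assert_eq!(quantized_time_upper_bound(precision, 7275), 10800);
    // A batch interval starting at 7201 is not aligned and should be rejected.
    assert_ne!(7201 % precision, 0);
}
```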
- let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval: Interval { start: batch_interval.start + 1, @@ -937,20 +970,33 @@ async fn e2e_leader_collect_abort_invalid_batch_interval(version: DapVersion) { }, agg_param: Vec::new(), }; - t.leader_post_expect_abort( - &client, - Some(&t.collector_bearer_token), - path, - constants::MEDIA_TYPE_COLLECT_REQ, - collect_req.get_encoded_with_param(&t.version), - 400, - "batchInvalid", - ) - .await; + if t.version == DapVersion::Draft02 { + t.leader_post_expect_abort( + &client, + Some(&t.collector_bearer_token), + path, + constants::MEDIA_TYPE_COLLECT_REQ, + collect_req.get_encoded_with_param(&t.version), + 400, + "batchInvalid", + ) + .await; + } else { + t.leader_put_expect_abort( + &client, + Some(&t.collector_bearer_token), + path, + constants::MEDIA_TYPE_COLLECT_REQ, + collect_req.get_encoded_with_param(&t.version), + 400, + "batchInvalid", + ) + .await; + } // Batch interval duration does not align wiht min_batch_duration. - let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval: Interval { start: batch_interval.start, @@ -959,16 +1005,29 @@ async fn e2e_leader_collect_abort_invalid_batch_interval(version: DapVersion) { }, agg_param: Vec::new(), }; - t.leader_post_expect_abort( - &client, - Some(&t.collector_bearer_token), - path, - constants::MEDIA_TYPE_COLLECT_REQ, - collect_req.get_encoded_with_param(&t.version), - 400, - "batchInvalid", - ) - .await; + if t.version == DapVersion::Draft02 { + t.leader_post_expect_abort( + &client, + Some(&t.collector_bearer_token), + path, + constants::MEDIA_TYPE_COLLECT_REQ, + collect_req.get_encoded_with_param(&t.version), + 400, + "batchInvalid", + ) + .await; + } else { + t.leader_put_expect_abort( + &client, + Some(&t.collector_bearer_token), + path, + constants::MEDIA_TYPE_COLLECT_REQ, + collect_req.get_encoded_with_param(&t.version), + 400, + "batchInvalid", + ) + .await; + } } async_test_versions! { e2e_leader_collect_abort_invalid_batch_interval } @@ -978,14 +1037,15 @@ async fn e2e_leader_collect_abort_overlapping_batch_interval(version: DapVersion let batch_interval = t.batch_interval(); let client = t.http_client(); let hpke_config_list = t.get_hpke_configs(version, &client).await; + let path = t.upload_path(); // The reports are uploaded in the background. let mut rng = thread_rng(); for _ in 0..t.task_config.min_batch_size { let now = rng.gen_range(t.report_interval(&batch_interval)); - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -1003,8 +1063,8 @@ async fn e2e_leader_collect_abort_overlapping_batch_interval(version: DapVersion } // Get the collect URI. - let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval: batch_interval.clone(), }, @@ -1042,8 +1102,8 @@ async fn e2e_leader_collect_abort_overlapping_batch_interval(version: DapVersion // NOTE: Since DURABLE_LEADER_COL_JOB_QUEUE_PUT has a mechanism to reject CollectReq // with the EXACT SAME content as previous requests, we need to tweak the request // a little bit. 
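The request rebuilt below tweaks the batch interval so the queue sees distinct content. For reference, the queue now keys its pending and processed entries by task and collection job ID, as in the leader_col_job_queue.rs hunk earlier in this diff; a sketch of that key layout with placeholder prefix values (the real constants live in leader_col_job_queue.rs):

```rust
const PENDING_PREFIX: &str = "pending";     // placeholder value
const PROCESSED_PREFIX: &str = "processed"; // placeholder value

fn pending_key(task_id_base64url: &str, collection_job_id_base64url: &str) -> String {
    format!("{PENDING_PREFIX}/tasks/{task_id_base64url}/collection_jobs/{collection_job_id_base64url}")
}

fn processed_key(task_id_base64url: &str, collection_job_id_base64url: &str) -> String {
    format!("{PROCESSED_PREFIX}/tasks/{task_id_base64url}/collection_jobs/{collection_job_id_base64url}")
}

fn main() {
    assert_eq!(
        pending_key("task123", "job456"),
        "pending/tasks/task123/collection_jobs/job456"
    );
    assert_eq!(
        processed_key("task123", "job456"),
        "processed/tasks/task123/collection_jobs/job456"
    );
}
```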
- let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: Query::TimeInterval { batch_interval: Interval { start: batch_interval.start, @@ -1052,16 +1112,30 @@ async fn e2e_leader_collect_abort_overlapping_batch_interval(version: DapVersion }, agg_param: Vec::new(), }; - t.leader_post_expect_abort( - &client, - Some(&t.collector_bearer_token), - "collect", - constants::MEDIA_TYPE_COLLECT_REQ, - collect_req.get_encoded_with_param(&t.version), - 400, - "batchOverlap", - ) - .await; + let path = t.collect_url_suffix(); + if t.version == DapVersion::Draft02 { + t.leader_post_expect_abort( + &client, + Some(&t.collector_bearer_token), + &path, + constants::MEDIA_TYPE_COLLECT_REQ, + collect_req.get_encoded_with_param(&t.version), + 400, + "batchOverlap", + ) + .await; + } else { + t.leader_put_expect_abort( + &client, + Some(&t.collector_bearer_token), + &path, + constants::MEDIA_TYPE_COLLECT_REQ, + collect_req.get_encoded_with_param(&t.version), + 400, + "batchOverlap", + ) + .await; + } } async_test_versions! { e2e_leader_collect_abort_overlapping_batch_interval } @@ -1075,6 +1149,7 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { return; } let t = TestRunner::fixed_size(version).await; + let path = t.upload_path(); let report_sel = DaphneWorkerReportSelector { max_agg_jobs: 100, // Needs to be sufficiently large to touch each bucket. max_reports: 100, @@ -1085,9 +1160,9 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { // Clients: Upload reports. for _ in 0..t.task_config.min_batch_size { - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -1123,8 +1198,8 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { let batch_id = t.internal_current_batch(&t.task_id).await; // Collector: Get the collect URI. - let collect_req = CollectReq { - task_id: t.task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: t.collect_task_id_field(), query: if use_current { Query::FixedSizeCurrentBatch } else { @@ -1140,7 +1215,7 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { println!("collect_uri: {}", collect_uri); // Collector: Poll the collect URI before the CollectResp is ready. - let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 202, "response: {:?}", resp); // ... Aggregators run processing loop. @@ -1153,10 +1228,11 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { ); // Collector: Poll the collect URI. 
- let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 200); - let collect_resp = CollectResp::get_decoded(&resp.bytes().await.unwrap()).unwrap(); + let collection = + Collection::get_decoded_with_param(&t.version, &resp.bytes().await.unwrap()).unwrap(); let agg_res = t .task_config .vdaf @@ -1166,8 +1242,8 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { &BatchSelector::FixedSizeByBatchId { batch_id: batch_id.clone(), }, - collect_resp.report_count, - collect_resp.encrypted_agg_shares.clone(), + collection.report_count, + collection.encrypted_agg_shares.clone(), version, ) .await @@ -1179,15 +1255,18 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { // Collector: Poll the collect URI once more. Expect the response to be the same as the first, // per HTTP GET semantics. - let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 200); - assert_eq!(resp.bytes().await.unwrap(), collect_resp.get_encoded()); + assert_eq!( + resp.bytes().await.unwrap(), + collection.get_encoded_with_param(&t.version) + ); // Clients: Upload reports. for _ in 0..2 { - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, t.task_config .vdaf @@ -1217,23 +1296,43 @@ async fn e2e_fixed_size(version: DapVersion, use_current: bool) { assert_ne!(batch_id, prev_batch_id); // Collector: Try CollectReq with out-dated batch ID. - t.leader_post_expect_abort( - &client, - Some(&t.collector_bearer_token), - "collect", - constants::MEDIA_TYPE_COLLECT_REQ, - CollectReq { - task_id: t.task_id.clone(), - query: Query::FixedSizeByBatchId { - batch_id: prev_batch_id.clone(), - }, - agg_param: Vec::new(), - } - .get_encoded_with_param(&t.version), - 400, - "batchOverlap", - ) - .await; + if t.version == DapVersion::Draft02 { + t.leader_post_expect_abort( + &client, + Some(&t.collector_bearer_token), + &t.collect_url_suffix(), + constants::MEDIA_TYPE_COLLECT_REQ, + CollectionReq { + draft02_task_id: t.collect_task_id_field(), + query: Query::FixedSizeByBatchId { + batch_id: prev_batch_id.clone(), + }, + agg_param: Vec::new(), + } + .get_encoded_with_param(&t.version), + 400, + "batchOverlap", + ) + .await; + } else { + t.leader_put_expect_abort( + &client, + Some(&t.collector_bearer_token), + &t.collect_url_suffix(), + constants::MEDIA_TYPE_COLLECT_REQ, + CollectionReq { + draft02_task_id: t.collect_task_id_field(), + query: Query::FixedSizeByBatchId { + batch_id: prev_batch_id.clone(), + }, + agg_param: Vec::new(), + } + .get_encoded_with_param(&t.version), + 400, + "batchOverlap", + ) + .await; + } } async fn e2e_fixed_size_no_current(version: DapVersion) { @@ -1288,6 +1387,7 @@ async fn e2e_leader_collect_taskprov_ok(version: DapVersion) { &t.taskprov_collector_hpke_receiver.config, ) .unwrap(); + let path = t.upload_path_for_task(&task_id); // The reports are uploaded in the background. let mut rng = thread_rng(); @@ -1296,9 +1396,9 @@ async fn e2e_leader_collect_taskprov_ok(version: DapVersion) { payload: payload.clone(), }]; let now = rng.gen_range(t.report_interval(&batch_interval)); - t.leader_post_expect_ok( + t.leader_put_expect_ok( &client, - "upload", + &path, constants::MEDIA_TYPE_REPORT, task_config .vdaf @@ -1317,8 +1417,8 @@ async fn e2e_leader_collect_taskprov_ok(version: DapVersion) { } // Get the collect URI. 
- let collect_req = CollectReq { - task_id: task_id.clone(), + let collect_req = CollectionReq { + draft02_task_id: Some(task_id.clone()), query: Query::TimeInterval { batch_interval: batch_interval.clone(), }, @@ -1334,7 +1434,7 @@ async fn e2e_leader_collect_taskprov_ok(version: DapVersion) { println!("collect_uri: {}", collect_uri); // Poll the collect URI before the CollectResp is ready. - let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 202, "response: {:?}", resp); // The reports are aggregated in the background. @@ -1361,10 +1461,11 @@ async fn e2e_leader_collect_taskprov_ok(version: DapVersion) { ); // Poll the collect URI. - let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 200); - let collect_resp = CollectResp::get_decoded(&resp.bytes().await.unwrap()).unwrap(); + let collection = + Collection::get_decoded_with_param(&t.version, &resp.bytes().await.unwrap()).unwrap(); let agg_res = task_config .vdaf .consume_encrypted_agg_shares( @@ -1373,8 +1474,8 @@ async fn e2e_leader_collect_taskprov_ok(version: DapVersion) { &BatchSelector::TimeInterval { batch_interval: batch_interval.clone(), }, - collect_resp.report_count, - collect_resp.encrypted_agg_shares.clone(), + collection.report_count, + collection.encrypted_agg_shares.clone(), version, ) .await @@ -1386,9 +1487,12 @@ async fn e2e_leader_collect_taskprov_ok(version: DapVersion) { // Poll the collect URI once more. Expect the response to be the same as the first, per HTTP // GET semantics. - let resp = client.get(collect_uri.as_str()).send().await.unwrap(); + let resp = t.poll_collection_url(&client, &collect_uri).await; assert_eq!(resp.status(), 200); - assert_eq!(resp.bytes().await.unwrap(), collect_resp.get_encoded()); + assert_eq!( + resp.bytes().await.unwrap(), + collection.get_encoded_with_param(&t.version) + ); } async_test_version! { e2e_leader_collect_taskprov_ok, Draft02 } diff --git a/daphne_worker_test/tests/janus.rs b/daphne_worker_test/tests/janus.rs index 1ab9bee26..645806486 100644 --- a/daphne_worker_test/tests/janus.rs +++ b/daphne_worker_test/tests/janus.rs @@ -103,7 +103,7 @@ async fn janus_helper() { assert_eq!(agg_telem.reports_aggregated, 0); // Get the collect URI. 
- let collect_req = daphne::messages::CollectReq { + let collect_req = daphne::messages::CollectionReq { task_id: t.task_id.clone(), query: daphne::messages::Query::TimeInterval { batch_interval: batch_interval.clone(), @@ -124,7 +124,7 @@ async fn janus_helper() { let decrypter: daphne::hpke::HpkeReceiverConfig = serde_json::from_str(COLLECTOR_HPKE_RECEIVER_CONFIG).unwrap(); let collect_resp = - daphne::messages::CollectResp::get_decoded(&resp.bytes().await.unwrap()).unwrap(); + daphne::messages::Collection::get_decoded(&resp.bytes().await.unwrap()).unwrap(); let agg_res = t .task_config .vdaf diff --git a/daphne_worker_test/tests/test_runner.rs b/daphne_worker_test/tests/test_runner.rs index f25285bbd..5fb43e14c 100644 --- a/daphne_worker_test/tests/test_runner.rs +++ b/daphne_worker_test/tests/test_runner.rs @@ -8,8 +8,8 @@ use daphne::{ constants::MEDIA_TYPE_COLLECT_REQ, hpke::HpkeReceiverConfig, messages::{ - decode_base64url, encode_base64url, Duration, HpkeAeadId, HpkeConfig, HpkeConfigList, - HpkeKdfId, HpkeKemId, Id, Interval, + encode_base64url, BatchId, CollectionJobId, Duration, HpkeAeadId, HpkeConfig, + HpkeConfigList, HpkeKdfId, HpkeKemId, Interval, TaskId, }, taskprov::TaskprovVersion, DapGlobalConfig, DapLeaderProcessTelemetry, DapQueryConfig, DapTaskConfig, DapVersion, @@ -44,7 +44,7 @@ struct InternalTestAddTaskResult { #[allow(dead_code)] pub struct TestRunner { pub global_config: DapGlobalConfig, - pub task_id: Id, + pub task_id: TaskId, pub task_config: DapTaskConfig, pub now: u64, pub leader_url: Url, @@ -84,13 +84,13 @@ impl TestRunner { .unwrap() .as_secs(); - let task_id = Id(rng.gen()); + let task_id = TaskId(rng.gen()); // When running in a local development environment, override the hostname of each // aggregator URL with 127.0.0.1. let version_path = match version { DapVersion::Draft02 => "v02", - DapVersion::Draft03 => "v03", + DapVersion::Draft04 => "v04", _ => panic!("unimplemented DapVersion"), }; let mut leader_url = Url::parse(&format!("http://leader:8787/{}/", version_path)).unwrap(); @@ -251,7 +251,7 @@ impl TestRunner { } pub fn batch_interval(&self) -> Interval { - let start = self.now - (self.now % self.task_config.time_precision); + let start = self.task_config.quantized_time_lower_bound(self.now); Interval { start, duration: self.task_config.time_precision * 2, @@ -377,13 +377,112 @@ impl TestRunner { ); } + /// Send a PUT request or, if draft02 is in use, a POST request. + pub async fn leader_put_expect_ok( + &self, + client: &reqwest::Client, + path: &str, + media_type: &str, + data: Vec, + ) { + // draft02 always POSTs + if self.version == DapVersion::Draft02 { + return self + .leader_post_expect_ok(client, path, media_type, data) + .await; + } + let url = self.leader_url.join(path).unwrap(); + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::CONTENT_TYPE, media_type.parse().unwrap()); + let resp = client + .put(url.as_str()) + .body(data) + .send() + .await + .expect("request failed"); + + assert_eq!( + reqwest::StatusCode::from_u16(200).unwrap(), + resp.status(), + "unexpected response status: {:?}", + resp.text().await.unwrap() + ); + } + + /// Send a PUT request or, if draft02 is in use, a POST request, and expect an abort. 
+ #[allow(clippy::too_many_arguments)] + pub async fn leader_put_expect_abort( + &self, + client: &reqwest::Client, + dap_auth_token: Option<&str>, + path: &str, + media_type: &str, + data: Vec, + expected_status: u16, + expected_err_type: &str, + ) { + // draft02 always POSTs + if self.version == DapVersion::Draft02 { + return self + .leader_post_expect_abort( + client, + dap_auth_token, + path, + media_type, + data, + expected_status, + expected_err_type, + ) + .await; + } + + let url = self.leader_url.join(path).unwrap(); + + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(reqwest::header::CONTENT_TYPE, media_type.parse().unwrap()); + if let Some(token) = dap_auth_token { + headers.insert( + reqwest::header::HeaderName::from_static("dap-auth-token"), + reqwest::header::HeaderValue::from_str(token).unwrap(), + ); + } + + let resp = client + .put(url.as_str()) + .body(data) + .headers(headers) + .send() + .await + .expect("request failed"); + + assert_eq!( + reqwest::StatusCode::from_u16(expected_status).unwrap(), + resp.status(), + "unexpected response status: {:?}", + resp.text().await.unwrap() + ); + + assert_eq!( + resp.headers().get("Content-Type").unwrap(), + "application/problem+json" + ); + + let problem_details: serde_json::Value = resp.json().await.unwrap(); + let got = problem_details.as_object().unwrap().get("type").unwrap(); + assert_eq!( + got, + &format!("urn:ietf:params:ppm:dap:error:{}", expected_err_type) + ); + } + pub async fn leader_post_collect_using_token( &self, client: &reqwest::Client, data: Vec, token: &str, ) -> Url { - let url = self.leader_url.join("collect").unwrap(); + let url_suffix = self.collect_url_suffix(); + let url = self.leader_url.join(&url_suffix).unwrap(); let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, @@ -393,22 +492,36 @@ impl TestRunner { reqwest::header::HeaderName::from_static("dap-auth-token"), reqwest::header::HeaderValue::from_str(token).unwrap(), ); - let resp = client - .post(url.as_str()) + let builder = if self.version == DapVersion::Draft02 { + client.post(url.as_str()) + } else { + client.put(url.as_str()) + }; + let resp = builder .body(data) .headers(headers) .send() .await .expect("request failed"); + let expected_status = if self.version == DapVersion::Draft02 { + 303 + } else { + 201 + }; + assert_eq!( resp.status(), - 303, + expected_status, "request failed: {:?}", resp.text().await.unwrap() ); - let collect_uri = resp.headers().get("Location").unwrap().to_str().unwrap(); - collect_uri.parse().unwrap() + if self.version == DapVersion::Draft02 { + let collect_uri = resp.headers().get("Location").unwrap().to_str().unwrap(); + collect_uri.parse().unwrap() + } else { + url + } } pub async fn leader_post_collect(&self, client: &reqwest::Client, data: Vec) -> Url { @@ -422,7 +535,7 @@ impl TestRunner { client: &reqwest::Client, report_sel: &DaphneWorkerReportSelector, ) -> DapLeaderProcessTelemetry { - // Replace path "/v02" with "/internal/process". + // Replace path "/v04" with "/internal/process". 
let mut url = self.leader_url.clone(); url.set_path("internal/process"); @@ -453,7 +566,7 @@ impl TestRunner { } else { self.helper_url.clone() }; - url.set_path(path); // Overwrites the version path (i.e., "/v02") + url.set_path(path); // Overwrites the version path (i.e., "/v04") let resp = client .post(url.clone()) .json(data) @@ -492,7 +605,7 @@ impl TestRunner { } #[allow(dead_code)] - pub async fn internal_current_batch(&self, task_id: &Id) -> Id { + pub async fn internal_current_batch(&self, task_id: &TaskId) -> BatchId { let client = self.http_client(); let mut url = self.leader_url.clone(); url.set_path(&format!( @@ -506,13 +619,64 @@ impl TestRunner { .expect("request failed"); if resp.status() == 200 { let batch_id_base64url = resp.text().await.unwrap(); - let batch_id = Id(decode_base64url(batch_id_base64url.as_bytes()) - .expect("Failed to parse URL-safe base64 batch ID")); - batch_id + BatchId::try_from_base64url(batch_id_base64url) + .expect("Failed to parse URL-safe base64 batch ID") } else { panic!("request to {} failed: response: {:?}", url, resp); } } + + pub fn upload_path_for_task(&self, id: &TaskId) -> String { + match self.version { + DapVersion::Draft02 => "upload".to_string(), + DapVersion::Draft04 => format!("tasks/{}/reports", id.to_base64url()), + _ => unreachable!("unknown version"), + } + } + + pub fn upload_path(&self) -> String { + self.upload_path_for_task(&self.task_id) + } + + pub fn collect_url_suffix(&self) -> String { + if self.version == DapVersion::Draft02 { + "collect".to_string() + } else { + let mut rng = thread_rng(); + let collect_job_id = CollectionJobId(rng.gen()); + format!( + "tasks/{}/collection_jobs/{}", + self.task_id.to_base64url(), + collect_job_id.to_base64url() + ) + } + } + + pub async fn poll_collection_url( + &self, + client: &reqwest::Client, + url: &Url, + ) -> reqwest::Response { + let builder = if self.version == DapVersion::Draft02 { + client.get(url.as_str()) + } else { + client.post(url.as_str()) + }; + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::CONTENT_TYPE, + reqwest::header::HeaderValue::from_static(MEDIA_TYPE_COLLECT_REQ), + ); + builder.headers(headers).send().await.unwrap() + } + + pub fn collect_task_id_field(&self) -> Option { + if self.version == DapVersion::Draft02 { + Some(self.task_id.clone()) + } else { + None + } + } } #[cfg(feature = "test_janus")] @@ -680,7 +844,7 @@ async fn post_internal_delete_all( base_url: &Url, batch_interval: &Interval, ) { - // Replace path "/v02" with "/internal/delete_all". + // Replace path "/v04" with "/internal/delete_all". let mut url = base_url.clone(); url.set_path("internal/delete_all"); diff --git a/daphne_worker_test/wrangler.toml b/daphne_worker_test/wrangler.toml index bb4e835f3..338a0f2c6 100644 --- a/daphne_worker_test/wrangler.toml +++ b/daphne_worker_test/wrangler.toml @@ -23,7 +23,7 @@ fallthrough = false DAP_AGGREGATOR_ROLE = "leader" DAP_BASE_URL = "http://127.0.0.1:8787/" DAP_ISSUE73_DISABLE_AGG_JOB_QUEUE_GARBAGE_COLLECTION = "true" -DAP_COLLECT_ID_KEY = "b416a85d280591d6da14e5b75a7d6e31" # SECRET +DAP_COLLECTION_JOB_ID_KEY = "b416a85d280591d6da14e5b75a7d6e31" # SECRET DAP_REPORT_SHARD_KEY = "61cd9685547370cfea76c2eb8d156ad9" # SECRET DAP_REPORT_SHARD_COUNT = "2" DAP_GLOBAL_CONFIG = """{ @@ -54,7 +54,7 @@ DAP_TASKPROV_LEADER_AUTH = """{ DAP_TASKPROV_COLLECTOR_AUTH = """{ "bearer_token": "I am the collector!" 
}""" # SECRET -DAP_DEFAULT_VERSION = "v03" +DAP_DEFAULT_VERSION = "v04" DAP_TRACING = "debug" [env.leader.durable_objects] @@ -131,7 +131,7 @@ DAP_TASKPROV_VDAF_VERIFY_KEY_INIT = "b029a72fa327931a5cb643dcadcaafa098fcbfac07d DAP_TASKPROV_LEADER_AUTH = """{ "bearer_token": "I am the leader!" }""" # SECRET -DAP_DEFAULT_VERSION = "v03" +DAP_DEFAULT_VERSION = "v04" DAP_TRACING = "debug" [env.helper.durable_objects]