diff --git a/prio/prg.c b/prio/prg.c index 03845f7..78851fd 100644 --- a/prio/prg.c +++ b/prio/prg.c @@ -119,6 +119,32 @@ PRG_get_int(PRG prg, mp_int* out, const mp_int* max) return rand_int_rng(out, max, &PRG_get_bytes_internal, (void*)prg); } +SECStatus +PRG_get_int_range(PRG prg, mp_int* out, const mp_int* lower, const mp_int* max) +{ + SECStatus rv; + mp_int width; + MP_DIGITS(&width) = NULL; + MP_CHECKC(mp_init(&width)); + + // Compute + // width = max - lower + MP_CHECKC(mp_sub(max, lower, &width)); + + // Get an integer x in the range [0, width) + P_CHECKC(PRG_get_int(prg, out, &width)); + + // Set + // out = lower + x + // which is in the range [lower, width+lower), + // which is [lower, max). + MP_CHECKC(mp_add(lower, out, out)); + +cleanup: + mp_clear(&width); + return rv; +} + SECStatus PRG_get_array(PRG prg, MPArray dst, const mp_int* mod) { diff --git a/prio/prg.h b/prio/prg.h index 5dc6520..f1a3b30 100644 --- a/prio/prg.h +++ b/prio/prg.h @@ -35,6 +35,13 @@ SECStatus PRG_get_bytes(PRG prg, unsigned char* bytes, size_t len); */ SECStatus PRG_get_int(PRG prg, mp_int* out, const mp_int* max); +/* + * Use the PRG output to sample a big integer x in the range + * lower <= x < max. + */ +SECStatus PRG_get_int_range(PRG prg, mp_int* out, const mp_int* lower, + const mp_int* max); + /* * Use secret sharing to split the int src into two shares. * Use PRG to generate the value `shareB`. diff --git a/prio/server.c b/prio/server.c index a648c68..5bdb9a9 100644 --- a/prio/server.c +++ b/prio/server.c @@ -188,19 +188,25 @@ compute_shares(PrioVerifier v, const_PrioPacketClient p) const int n = v->s->cfg->num_data_fields + 1; const int N = next_power_of_two(n); mp_int eval_at; + mp_int lower; MP_DIGITS(&eval_at) = NULL; + MP_DIGITS(&lower) = NULL; MPArray points_f = NULL; MPArray points_g = NULL; MPArray points_h = NULL; MP_CHECKC(mp_init(&eval_at)); + MP_CHECKC(mp_init(&lower)); P_CHECKA(points_f = MPArray_new(N)); P_CHECKA(points_g = MPArray_new(N)); P_CHECKA(points_h = MPArray_new(2 * N)); - // Use PRG to generate random point - MP_CHECKC(PRG_get_int(v->s->prg, &eval_at, &v->s->cfg->modulus)); + // Use PRG to generate random point. Per Appendix D.2 of full version of + // Prio paper, this value must be in the range + // [n+1, modulus). + mp_set(&lower, n + 1); + P_CHECKC(PRG_get_int_range(v->s->prg, &eval_at, &lower, &v->s->cfg->modulus)); // Reduce value into the field we're using. This // doesn't yield exactly a uniformly random point, @@ -243,6 +249,7 @@ compute_shares(PrioVerifier v, const_PrioPacketClient p) MPArray_clear(points_g); MPArray_clear(points_h); mp_clear(&eval_at); + mp_clear(&lower); return rv; } diff --git a/ptest/client_test.c b/ptest/client_test.c index 0ae8622..fa58b1e 100644 --- a/ptest/client_test.c +++ b/ptest/client_test.c @@ -74,7 +74,6 @@ test_client_agg(int nclients, int nfields, bool config_is_okay) PT_CHECKC(Keypair_new(&skA, &pkA)); PT_CHECKC(Keypair_new(&skB, &pkB)); - printf("fields: %d\n", nfields); P_CHECKA(cfg = PrioConfig_new(nfields, pkA, pkB, batch_id, batch_id_len)); if (!config_is_okay) { PT_CHECKCB( diff --git a/ptest/prg_test.c b/ptest/prg_test.c index f4d2030..40f6170 100644 --- a/ptest/prg_test.c +++ b/ptest/prg_test.c @@ -352,3 +352,64 @@ mu_test__prg_share_arr(void) MPArray_clear(arr_share); PrioConfig_clear(cfg); } + +void +test_prg_range_once(int bot, int limit) +{ + SECStatus rv = SECSuccess; + PrioPRGSeed key; + mp_int lower; + mp_int max; + mp_int out; + PRG prg = NULL; + + MP_DIGITS(&lower) = NULL; + MP_DIGITS(&max) = NULL; + MP_DIGITS(&out) = NULL; + + PT_CHECKC(PrioPRGSeed_randomize(&key)); + PT_CHECKA(prg = PRG_new(key)); + + MPT_CHECKC(mp_init(&max)); + MPT_CHECKC(mp_init(&out)); + MPT_CHECKC(mp_init(&lower)); + + mp_set(&lower, bot); + mp_set(&max, limit); + + for (int i = 0; i < 100; i++) { + PT_CHECKC(PRG_get_int_range(prg, &out, &lower, &max)); + mu_check(mp_cmp_d(&out, limit) == -1); + mu_check(mp_cmp_d(&out, bot) > -1); + mu_check(mp_cmp_z(&out) > -1); + } + +cleanup: + mu_check(rv == SECSuccess); + mp_clear(&lower); + mp_clear(&max); + mp_clear(&out); + PRG_clear(prg); +} + +void +mu_test_prg_range__multiple_of_8(void) +{ + test_prg_range_once(128, 256); + test_prg_range_once(256, 256 * 256); +} + +void +mu_test_prg_range__near_multiple_of_8(void) +{ + test_prg_range_once(256, 256 + 1); + test_prg_range_once(256 * 256, 256 * 256 + 1); +} + +void +mu_test_prg_range__odd(void) +{ + test_prg_range_once(23, 39); + test_prg_range_once(7, 123); + test_prg_range_once(99000, 993123); +}