From 80dabb10f0b3fec9e8487ba048ba94ad00de5879 Mon Sep 17 00:00:00 2001 From: fionser Date: Wed, 14 Aug 2024 22:54:14 +0800 Subject: [PATCH 1/2] fix #811 for Insecure PackedB2A --- libspu/mpc/cheetah/ot/basic_ot_prot.cc | 155 +++++++++------------- libspu/mpc/cheetah/ot/emp/ferret_test.cc | 56 ++++++++ libspu/mpc/cheetah/ot/yacl/ferret_test.cc | 65 +++++++++ 3 files changed, 182 insertions(+), 94 deletions(-) diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index f5a03549..34d05a1c 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -68,122 +68,89 @@ NdArrayRef BasicOTProtocols::B2A(const NdArrayRef &inp) { return PackedB2A(inp); } +// Convert the packed boolean shares to arithmetic share +// Input x in Z2k is the packed of b-bits for 1 <= b <= k. +// That is x0, x1, ..., x{b-1} +// Output y in Z2k such that y = \sum_i x{i}*2^i mod 2^k +// +// Ref: The ABY paper https://encrypto.de/papers/DSZ15.pdf Section E NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { const auto *share_t = inp.eltype().as(); auto field = inp.eltype().as()->field(); - const int64_t n = inp.numel(); - size_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits(); - if (n >= 8) { - // 8bits-align for a larger input - nbits = (nbits + 7) / 8 * 8; - } - SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); - - auto rand_bits = DISPATCH_ALL_FIELDS(field, [&]() { - if ((nbits & 7) or (n * inp.elsize()) & 7) { - // The SseTranspose requires the #rows and #columns is multiple of 8. - // Thus, we call the less efficient RandBits on margin cases. - return RandBits(field, {static_cast(n * nbits)}); - } + const int64_t ring_width = SizeOf(field) * 8; - // More efficient randbits that ultilize collapse COTs. - int64_t B = nbits; - auto r = ring_randbit(field, {n * B}).as(makeType(field, 1)); - const int64_t numl = r.numel(); + const int64_t n = inp.numel(); + const int64_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits(); + const int64_t numel = n * nbits; - NdArrayRef oup = ring_zeros(field, r.shape()); + NdArrayRef cot_oup = ring_zeros(field, {numel}); + NdArrayRef arith_oup = ring_zeros(field, inp.shape()); + DISPATCH_ALL_FIELDS(field, [&]() { using u2k = std::make_unsigned::type; - auto input = NdArrayView(r); - auto output = absl::MakeSpan(&oup.at(0), numl); - SPU_ENFORCE(oup.isCompact()); + auto input = NdArrayView(inp); + auto cot_output = absl::MakeSpan(&cot_oup.at(0), cot_oup.numel()); if (Rank() == 0) { - std::vector corr_data(numl); - // NOTE(lwj): Masking to make sure there is only single bit. - for (int64_t i = 0; i < numl; ++i) { - // corr=-2*xi - corr_data[i] = -((input[i] & 1) << 1); + std::vector corr_data(numel); + + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + auto msk = makeBitsMask(ring_width - k); + for (int64_t j = 0; j < n; ++j) { + // corr[k] = -2*x0_k + corr_data[i + j] = -2 * ((input[j] >> k) & 1); + corr_data[i + j] &= msk; + } } // Run the multiple COT in the collapse mode. - // That is, the i-th COT returns output of `nbits - i` bits. - ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), output, - /*bw*/ nbits, /*num_level*/ nbits); - ferret_sender_->Flush(); + // That is, the k-th COT returns output of `ring_width - k` bits. + // + // The k-th COT gives the arithmetic share of the k-th bit of the input + // according to x_0 ^ x_1 = x_0 + x_1 - 2 * x_0 * x_1 + ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), cot_output, + /*bw*/ ring_width, + /*num_level*/ nbits); - for (int64_t i = 0; i < numl; ++i) { - output[i] = (input[i] & 1) - output[i]; + ferret_sender_->Flush(); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + cot_output[i + j] = ((input[j] >> k) & 1) - cot_output[i + j]; + } } } else { - std::vector choices(numl); - for (int64_t i = 0; i < numl; ++i) { - choices[i] = static_cast(input[i] & 1); - } - ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), output, - nbits, nbits); - - for (int64_t i = 0; i < numl; ++i) { - output[i] = (input[i] & 1) + output[i]; + // choice[k] is the k-th bit x1_k + std::vector choices(numel); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + choices[i + j] = (input[j] >> k) & 1; + } } - } - - // oup.shape B x (n * T) - std::vector tmp(B * n * inp.elsize()); - // bit matrix transpose - SseTranspose(oup.data(), tmp.data(), B, n * inp.elsize()); + ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), cot_output, + ring_width, nbits); - std::copy_n(tmp.data(), tmp.size(), oup.data()); - return oup; - }); - - // convert the bit form to integer form - auto rand = [&](NdArrayRef _bits) { - SPU_ENFORCE(_bits.isCompact(), "need compact input"); - const int64_t n = _bits.numel() / nbits; - // init as all 0s. - auto iform = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, [&]() { - auto bits = NdArrayView(_bits); - auto digit = NdArrayView(iform); - for (int64_t i = 0; i < n; ++i) { - // LSB is bits[0]; MSB is bits[nbits - 1] - // We iterate the bits in reversed order - const size_t offset = i * nbits; - digit[i] = 0; - for (size_t j = nbits; j > 0; --j) { - digit[i] = (digit[i] << 1) | (bits[offset + j - 1] & 1); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + cot_output[i + j] = ((input[j] >> k) & 1) + cot_output[i + j]; } } - }); - return iform; - }(rand_bits); - - // open c = x ^ r - auto opened = OpenShare(ring_xor(inp, rand), ReduceOp::XOR, nbits, conn_); + } - // compute c + (1 - 2*c)* - NdArrayRef oup = ring_zeros(field, inp.shape()); - DISPATCH_ALL_FIELDS(field, [&]() { - using u2k = std::make_unsigned::type; - int rank = Rank(); - auto xr = NdArrayView(rand_bits); - auto xc = NdArrayView(opened); - auto xo = NdArrayView(oup); - - for (int64_t i = 0; i < n; ++i) { - const size_t offset = i * nbits; - u2k this_elt = xc[i]; - for (size_t j = 0; j < nbits; ++j, this_elt >>= 1) { - u2k c_ij = this_elt & 1; - ring2k_t one_bit = (1 - c_ij * 2) * xr[offset + j]; - if (rank == 0) { - one_bit += c_ij; - } - xo[i] += (one_bit << j); + // = \sum_k 2^k * + // where is the arithmetic share of the k-th bit + NdArrayView arith(arith_oup); + for (int64_t k = 0; k < nbits; ++k) { + int64_t i = k * n; + for (int64_t j = 0; j < n; ++j) { + arith[j] += (cot_output[i + j] << k); } } }); - return oup; + + return arith_oup; } // Math: diff --git a/libspu/mpc/cheetah/ot/emp/ferret_test.cc b/libspu/mpc/cheetah/ot/emp/ferret_test.cc index b878bdbe..e78394b0 100644 --- a/libspu/mpc/cheetah/ot/emp/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/emp/ferret_test.cc @@ -205,4 +205,60 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { }); } +TEST_P(FerretCOTTest, COT_Collapse) { + size_t kWorldSize = 2; + int64_t n = 8; + auto field = GetParam(); + + const auto bw = SizeOf(field) * 8; + const int level = bw; + + // generate random choices and correlation + const auto _correlation = ring_rand(field, {static_cast(n * level)}); + const auto N = _correlation.numel(); + + NdArrayRef oup1 = ring_zeros(field, _correlation.shape()); + NdArrayRef oup2 = ring_zeros(field, _correlation.shape()); + + std::vector choices(N, 1); + + DISPATCH_ALL_FIELDS(field, [&]() { + using u2k = std::make_unsigned::type; + + auto out1_span = absl::MakeSpan(&oup1.at(0), N); + auto out2_span = absl::MakeSpan(&oup2.at(0), N); + + NdArrayView correlation(_correlation); + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + + EmpFerretOt ferret(conn, rank == 0); + if (rank == 0) { + ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw, + level); + ferret.Flush(); + + } else { + ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw, + level); + } + }); + + // Sample-major order + // n || n || n || .... || n + // k=level||k=level - 1||k=level - 2|| .... + for (int64_t i = 0; i < N; i += n) { + const auto cur_bw = bw - (i / n); + const auto mask = makeMask(cur_bw); + for (int64_t j = 0; j < n; ++j) { + ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask; + ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask; + + ASSERT_EQ(c, e); + } + } + }); +} } // namespace spu::mpc::cheetah::test diff --git a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc index a1eea9e7..1d5822ca 100644 --- a/libspu/mpc/cheetah/ot/yacl/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/yacl/ferret_test.cc @@ -210,4 +210,69 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { }); } +template +T makeMask(int bw) { + if (bw == sizeof(T) * 8) { + return static_cast(-1); + } + return (static_cast(1) << bw) - 1; +} + +TEST_P(FerretCOTTest, COT_Collapse) { + size_t kWorldSize = 2; + int64_t n = 8; + auto field = std::get<0>(GetParam()); + auto use_ss = std::get<1>(GetParam()); + + const auto bw = SizeOf(field) * 8; + const int level = bw; + + // generate random choices and correlation + const auto _correlation = ring_rand(field, {static_cast(n * level)}); + const auto N = _correlation.numel(); + + NdArrayRef oup1 = ring_zeros(field, _correlation.shape()); + NdArrayRef oup2 = ring_zeros(field, _correlation.shape()); + + std::vector choices(N, 1); + + DISPATCH_ALL_FIELDS(field, [&]() { + using u2k = std::make_unsigned::type; + + auto out1_span = absl::MakeSpan(&oup1.at(0), N); + auto out2_span = absl::MakeSpan(&oup2.at(0), N); + + NdArrayView correlation(_correlation); + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + + YaclFerretOt ferret(conn, rank == 0, use_ss); + if (rank == 0) { + ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw, + level); + ferret.Flush(); + + } else { + ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw, + level); + } + }); + + // Sample-major order + // n || n || n || .... || n + // k=level||k=level - 1||k=level - 2|| .... + for (int64_t i = 0; i < N; i += n) { + const auto cur_bw = bw - (i / n); + const auto mask = makeMask(cur_bw); + for (int64_t j = 0; j < n; ++j) { + ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask; + ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask; + + ASSERT_EQ(c, e); + } + } + }); +} } // namespace spu::mpc::cheetah::test From ec264b4f461482c36790fbb40704f9d8ce70b6fa Mon Sep 17 00:00:00 2001 From: fionser Date: Wed, 14 Aug 2024 23:00:16 +0800 Subject: [PATCH 2/2] fix UT build --- libspu/mpc/cheetah/ot/emp/ferret_test.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/libspu/mpc/cheetah/ot/emp/ferret_test.cc b/libspu/mpc/cheetah/ot/emp/ferret_test.cc index e78394b0..35353964 100644 --- a/libspu/mpc/cheetah/ot/emp/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/emp/ferret_test.cc @@ -205,6 +205,14 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { }); } +template +T makeMask(int bw) { + if (bw == sizeof(T) * 8) { + return static_cast(-1); + } + return (static_cast(1) << bw) - 1; +} + TEST_P(FerretCOTTest, COT_Collapse) { size_t kWorldSize = 2; int64_t n = 8;