Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #811 for Insecure PackedB2A #819

Merged
merged 2 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 61 additions & 94 deletions libspu/mpc/cheetah/ot/basic_ot_prot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,122 +68,89 @@ NdArrayRef BasicOTProtocols::B2A(const NdArrayRef &inp) {
return PackedB2A(inp);
}

// Convert the packed boolean shares to arithmetic share
// Input x in Z2k is the packed of b-bits for 1 <= b <= k.
// That is x0, x1, ..., x{b-1}
// Output y in Z2k such that y = \sum_i x{i}*2^i mod 2^k
//
// Ref: The ABY paper https://encrypto.de/papers/DSZ15.pdf Section E
NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) {
const auto *share_t = inp.eltype().as<BShrTy>();
auto field = inp.eltype().as<Ring2k>()->field();
const int64_t n = inp.numel();
size_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits();
if (n >= 8) {
// 8bits-align for a larger input
nbits = (nbits + 7) / 8 * 8;
}
SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field));

auto rand_bits = DISPATCH_ALL_FIELDS(field, [&]() {
if ((nbits & 7) or (n * inp.elsize()) & 7) {
// The SseTranspose requires the #rows and #columns is multiple of 8.
// Thus, we call the less efficient RandBits on margin cases.
return RandBits(field, {static_cast<int64_t>(n * nbits)});
}
const int64_t ring_width = SizeOf(field) * 8;

// More efficient randbits that ultilize collapse COTs.
int64_t B = nbits;
auto r = ring_randbit(field, {n * B}).as(makeType<BShrTy>(field, 1));
const int64_t numl = r.numel();
const int64_t n = inp.numel();
const int64_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits();
const int64_t numel = n * nbits;

NdArrayRef oup = ring_zeros(field, r.shape());
NdArrayRef cot_oup = ring_zeros(field, {numel});
NdArrayRef arith_oup = ring_zeros(field, inp.shape());
DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;
auto input = NdArrayView<const u2k>(r);
auto output = absl::MakeSpan(&oup.at<u2k>(0), numl);
SPU_ENFORCE(oup.isCompact());
auto input = NdArrayView<const u2k>(inp);
auto cot_output = absl::MakeSpan(&cot_oup.at<u2k>(0), cot_oup.numel());

if (Rank() == 0) {
std::vector<u2k> corr_data(numl);
// NOTE(lwj): Masking to make sure there is only single bit.
for (int64_t i = 0; i < numl; ++i) {
// corr=-2*xi
corr_data[i] = -((input[i] & 1) << 1);
std::vector<u2k> corr_data(numel);

for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
auto msk = makeBitsMask<u2k>(ring_width - k);
for (int64_t j = 0; j < n; ++j) {
// corr[k] = -2*x0_k
corr_data[i + j] = -2 * ((input[j] >> k) & 1);
corr_data[i + j] &= msk;
}
}
// Run the multiple COT in the collapse mode.
// That is, the i-th COT returns output of `nbits - i` bits.
ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), output,
/*bw*/ nbits, /*num_level*/ nbits);
ferret_sender_->Flush();
// That is, the k-th COT returns output of `ring_width - k` bits.
//
// The k-th COT gives the arithmetic share of the k-th bit of the input
// according to x_0 ^ x_1 = x_0 + x_1 - 2 * x_0 * x_1
ferret_sender_->SendCAMCC_Collapse(absl::MakeSpan(corr_data), cot_output,
/*bw*/ ring_width,
/*num_level*/ nbits);

for (int64_t i = 0; i < numl; ++i) {
output[i] = (input[i] & 1) - output[i];
ferret_sender_->Flush();
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
cot_output[i + j] = ((input[j] >> k) & 1) - cot_output[i + j];
}
}
} else {
std::vector<uint8_t> choices(numl);
for (int64_t i = 0; i < numl; ++i) {
choices[i] = static_cast<uint8_t>(input[i] & 1);
}
ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), output,
nbits, nbits);

for (int64_t i = 0; i < numl; ++i) {
output[i] = (input[i] & 1) + output[i];
// choice[k] is the k-th bit x1_k
std::vector<uint8_t> choices(numel);
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
choices[i + j] = (input[j] >> k) & 1;
}
}
}

// oup.shape B x (n * T)
std::vector<uint8_t> tmp(B * n * inp.elsize());

// bit matrix transpose
SseTranspose(oup.data<uint8_t>(), tmp.data(), B, n * inp.elsize());
ferret_receiver_->RecvCAMCC_Collapse(absl::MakeSpan(choices), cot_output,
ring_width, nbits);

std::copy_n(tmp.data(), tmp.size(), oup.data<uint8_t>());
return oup;
});

// convert the bit form to integer form
auto rand = [&](NdArrayRef _bits) {
SPU_ENFORCE(_bits.isCompact(), "need compact input");
const int64_t n = _bits.numel() / nbits;
// init as all 0s.
auto iform = ring_zeros(field, inp.shape());
DISPATCH_ALL_FIELDS(field, [&]() {
auto bits = NdArrayView<const ring2k_t>(_bits);
auto digit = NdArrayView<ring2k_t>(iform);
for (int64_t i = 0; i < n; ++i) {
// LSB is bits[0]; MSB is bits[nbits - 1]
// We iterate the bits in reversed order
const size_t offset = i * nbits;
digit[i] = 0;
for (size_t j = nbits; j > 0; --j) {
digit[i] = (digit[i] << 1) | (bits[offset + j - 1] & 1);
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
cot_output[i + j] = ((input[j] >> k) & 1) + cot_output[i + j];
}
}
});
return iform;
}(rand_bits);

// open c = x ^ r
auto opened = OpenShare(ring_xor(inp, rand), ReduceOp::XOR, nbits, conn_);
}

// compute c + (1 - 2*c)*<r>
NdArrayRef oup = ring_zeros(field, inp.shape());
DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;
int rank = Rank();
auto xr = NdArrayView<const u2k>(rand_bits);
auto xc = NdArrayView<const u2k>(opened);
auto xo = NdArrayView<ring2k_t>(oup);

for (int64_t i = 0; i < n; ++i) {
const size_t offset = i * nbits;
u2k this_elt = xc[i];
for (size_t j = 0; j < nbits; ++j, this_elt >>= 1) {
u2k c_ij = this_elt & 1;
ring2k_t one_bit = (1 - c_ij * 2) * xr[offset + j];
if (rank == 0) {
one_bit += c_ij;
}
xo[i] += (one_bit << j);
// <x> = \sum_k 2^k * <x_k>
// where <x_k> is the arithmetic share of the k-th bit
NdArrayView<u2k> arith(arith_oup);
for (int64_t k = 0; k < nbits; ++k) {
int64_t i = k * n;
for (int64_t j = 0; j < n; ++j) {
arith[j] += (cot_output[i + j] << k);
}
}
});
return oup;

return arith_oup;
}

// Math:
Expand Down
64 changes: 64 additions & 0 deletions libspu/mpc/cheetah/ot/emp/ferret_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,68 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) {
});
}

template <typename T>
T makeMask(int bw) {
if (bw == sizeof(T) * 8) {
return static_cast<T>(-1);
}
return (static_cast<T>(1) << bw) - 1;
}

TEST_P(FerretCOTTest, COT_Collapse) {
size_t kWorldSize = 2;
int64_t n = 8;
auto field = GetParam();

const auto bw = SizeOf(field) * 8;
const int level = bw;

// generate random choices and correlation
const auto _correlation = ring_rand(field, {static_cast<int64_t>(n * level)});
const auto N = _correlation.numel();

NdArrayRef oup1 = ring_zeros(field, _correlation.shape());
NdArrayRef oup2 = ring_zeros(field, _correlation.shape());

std::vector<uint8_t> choices(N, 1);

DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;

auto out1_span = absl::MakeSpan(&oup1.at<u2k>(0), N);
auto out2_span = absl::MakeSpan(&oup2.at<u2k>(0), N);

NdArrayView<u2k> correlation(_correlation);

utils::simulate(kWorldSize, [&](std::shared_ptr<yacl::link::Context> ctx) {
auto conn = std::make_shared<Communicator>(ctx);
int rank = ctx->Rank();

EmpFerretOt ferret(conn, rank == 0);
if (rank == 0) {
ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw,
level);
ferret.Flush();

} else {
ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw,
level);
}
});

// Sample-major order
// n || n || n || .... || n
// k=level||k=level - 1||k=level - 2|| ....
for (int64_t i = 0; i < N; i += n) {
const auto cur_bw = bw - (i / n);
const auto mask = makeMask<ring2k_t>(cur_bw);
for (int64_t j = 0; j < n; ++j) {
ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask;
ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask;

ASSERT_EQ(c, e);
}
}
});
}
} // namespace spu::mpc::cheetah::test
65 changes: 65 additions & 0 deletions libspu/mpc/cheetah/ot/yacl/ferret_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,69 @@ TEST_P(FerretCOTTest, ChosenMsgChosenChoice) {
});
}

template <typename T>
T makeMask(int bw) {
if (bw == sizeof(T) * 8) {
return static_cast<T>(-1);
}
return (static_cast<T>(1) << bw) - 1;
}

TEST_P(FerretCOTTest, COT_Collapse) {
size_t kWorldSize = 2;
int64_t n = 8;
auto field = std::get<0>(GetParam());
auto use_ss = std::get<1>(GetParam());

const auto bw = SizeOf(field) * 8;
const int level = bw;

// generate random choices and correlation
const auto _correlation = ring_rand(field, {static_cast<int64_t>(n * level)});
const auto N = _correlation.numel();

NdArrayRef oup1 = ring_zeros(field, _correlation.shape());
NdArrayRef oup2 = ring_zeros(field, _correlation.shape());

std::vector<uint8_t> choices(N, 1);

DISPATCH_ALL_FIELDS(field, [&]() {
using u2k = std::make_unsigned<ring2k_t>::type;

auto out1_span = absl::MakeSpan(&oup1.at<u2k>(0), N);
auto out2_span = absl::MakeSpan(&oup2.at<u2k>(0), N);

NdArrayView<u2k> correlation(_correlation);

utils::simulate(kWorldSize, [&](std::shared_ptr<yacl::link::Context> ctx) {
auto conn = std::make_shared<Communicator>(ctx);
int rank = ctx->Rank();

YaclFerretOt ferret(conn, rank == 0, use_ss);
if (rank == 0) {
ferret.SendCAMCC_Collapse(makeConstSpan(correlation), out1_span, bw,
level);
ferret.Flush();

} else {
ferret.RecvCAMCC_Collapse(absl::MakeSpan(choices), out2_span, bw,
level);
}
});

// Sample-major order
// n || n || n || .... || n
// k=level||k=level - 1||k=level - 2|| ....
for (int64_t i = 0; i < N; i += n) {
const auto cur_bw = bw - (i / n);
const auto mask = makeMask<ring2k_t>(cur_bw);
for (int64_t j = 0; j < n; ++j) {
ring2k_t c = (-out1_span[i + j] + out2_span[i + j]) & mask;
ring2k_t e = (choices[i + j] ? correlation[i + j] : 0) & mask;

ASSERT_EQ(c, e);
}
}
});
}
} // namespace spu::mpc::cheetah::test
Loading