Skip to content

Commit

Permalink
[CHEETAH] Optimize the MulAA communication. (#850)
Browse files Browse the repository at this point in the history
The previous implementation realizes MulAA via two OLEs that compute the two
terms x0*y1 and x1*y0 separately, which introduces a larger communication overhead.

We switch to another strategy that computes the sum x0*y1+x1*y0
homomorphically. To further utilize the CPU resources, we split a long
vector into two subtasks and let Rank0 and Rank1 each handle one half.

# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:

- Backward compatibility:
  • Loading branch information
fionser authored Sep 18, 2024
1 parent 4720ef4 commit c7a5ba5
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 22 deletions.
167 changes: 167 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ struct CheetahMul::Impl : public EnableCPRNG {
NdArrayRef MulOLE(const NdArrayRef &shr, yacl::link::Context *conn,
bool evaluator, uint32_t msg_width_hint);

// Multiplies two additively shared arrays: `xshr` and `yshr` are this
// party's shares, and the result is this party's share of the product.
// Both inputs must be non-empty ring-typed arrays of the same shape.
// - conn: link context to use; nullptr falls back to lctx_.
// - evaluator: selects which side of the protocol this party runs. The
//   evaluator receives the peer's ciphertexts and sends back the masked
//   response; the other party encrypts its shares and decrypts the response.
// - msg_width_hint: message bit length; 0 means the full ring bit length.
NdArrayRef MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
yacl::link::Context *conn, bool evaluator,
uint32_t msg_width_hint);

protected:
void LocalExpandSEALContexts(size_t target);

Expand Down Expand Up @@ -167,6 +171,16 @@ struct CheetahMul::Impl : public EnableCPRNG {
absl::Span<const uint64_t> rnd_mask,
yacl::link::Context *conn = nullptr);

// Computes Enc(x0) * y1 + Enc(y0) * x1 - rand_mask for every ciphertext and
// sends the resulting ciphertexts to the next rank.
// - ciphers_x0 / ciphers_y0: serialized ciphertexts, one per (SEAL context,
//   split) pair; both spans must have the same size.
// - plains_x1 / plains_y1: encoded plaintexts, one per ciphertext.
//   ciphers_x0[j] is multiplied by plains_y1[j] and ciphers_y0[j] by
//   plains_x1[j].
// - rnd_mask: the mask subtracted from the sum; it must hold
//   num_elts values for each working SEAL context.
// - conn: link context to use; nullptr falls back to lctx_.
void FMAThenResponse(FieldType field, int64_t num_elts,
const Options &options,
absl::Span<const yacl::Buffer> ciphers_x0,
absl::Span<const yacl::Buffer> ciphers_y0,
absl::Span<const RLWEPt> plains_x1,
absl::Span<const RLWEPt> plains_y1,
absl::Span<const uint64_t> rnd_mask,
yacl::link::Context *conn = nullptr);

void PrepareRandomMask(FieldType field, int64_t size, const Options &options,
std::vector<uint64_t> &mask);

Expand Down Expand Up @@ -386,6 +400,79 @@ NdArrayRef CheetahMul::Impl::MulOLE(const NdArrayRef &shr,
return DecryptArray(field, numel, options, recv_ct).reshape(shr.shape());
}

// Multiplies two additively shared arrays. `xshr` and `yshr` are this
// party's shares of x and y; the return value is this party's additive share
// of x * y, shaped like `xshr`.
//
// With x = x0 + x1 and y = y0 + y1:
//   x * y = x0*y0 + <x0*y1 + x1*y0> + x1*y1
// Each party adds its own local product ring_mul(xshr, yshr); the bracketed
// cross term is computed homomorphically via FMAThenResponse.
//
// - conn: link context to use; nullptr selects lctx_.
// - evaluator: true for the party that receives the peer's ciphertexts and
//   answers with the masked response, false for the party that encrypts its
//   shares and decrypts that response.
// - msg_width_hint: message bit length; 0 means the full ring bit length.
NdArrayRef CheetahMul::Impl::MulShare(const NdArrayRef &xshr,
                                      const NdArrayRef &yshr,
                                      yacl::link::Context *conn, bool evaluator,
                                      uint32_t msg_width_hint) {
  if (conn == nullptr) {
    conn = lctx_.get();
  }

  // Input validation: both operands are non-empty ring arrays of equal shape.
  auto eltype = xshr.eltype();
  SPU_ENFORCE(eltype.isa<Ring2k>(), "must be ring_type, got={}", eltype);
  SPU_ENFORCE(yshr.eltype().isa<Ring2k>(), "must be ring_type, got={}",
              yshr.eltype());
  SPU_ENFORCE(xshr.numel() > 0);
  SPU_ENFORCE_EQ(xshr.shape(), yshr.shape());

  const auto field = eltype.as<Ring2k>()->field();

  Options options;
  options.ring_bitlen = SizeOf(field) * 8;
  // A zero hint means "use the whole ring width".
  options.msg_bitlen = options.ring_bitlen;
  if (msg_width_hint != 0) {
    options.msg_bitlen = msg_width_hint;
  }
  SPU_ENFORCE(options.msg_bitlen > 0 &&
              options.msg_bitlen <= options.ring_bitlen);
  LazyExpandSEALContexts(options, conn);
  LazyInitModSwitchHelper(options);

  const size_t total_elts = xshr.numel();
  const int peer = conn->NextRank();

  if (!evaluator) {
    // Encryptor side: ship Enc(x) then Enc(y), then decrypt the peer's
    // masked response. The second batch has the same count as the first.
    const size_t num_ct = EncryptArrayThenSend(xshr, options, conn);
    (void)EncryptArrayThenSend(yshr, options, conn);

    std::vector<yacl::Buffer> response(num_ct);
    for (auto &buf : response) {
      buf = conn->Recv(peer, "");
    }

    auto out = DecryptArray(field, total_elts, options, response)
                   .reshape(xshr.shape());
    ring_add_(out, ring_mul(xshr, yshr));
    return out;
  }

  // Evaluator side: encode the local shares as plaintexts.
  std::vector<RLWEPt> local_x_pt;
  std::vector<RLWEPt> local_y_pt;
  EncodeArray(xshr, false, options, &local_x_pt);
  EncodeArray(yshr, false, options, &local_y_pt);

  // Receive the peer's ciphertexts in the background (x batch first, then
  // the y batch) while the random mask is being sampled.
  const size_t num_ct = local_x_pt.size();
  std::vector<yacl::Buffer> peer_x_ct(num_ct);
  std::vector<yacl::Buffer> peer_y_ct(num_ct);
  auto recv_task = std::async(std::launch::async, [&]() {
    for (auto &buf : peer_x_ct) {
      buf = conn->Recv(peer, "");
    }
    for (auto &buf : peer_y_ct) {
      buf = conn->Recv(peer, "");
    }
  });

  std::vector<uint64_t> mask;
  PrepareRandomMask(field, xshr.numel(), options, mask);

  // All ciphertexts must have arrived before the homomorphic evaluation.
  recv_task.get();
  FMAThenResponse(field, total_elts, options, peer_x_ct, peer_y_ct, local_x_pt,
                  local_y_pt, absl::MakeConstSpan(mask), conn);

  // convert x \in [0, P) to [0, 2^k) by round(2^k*x/P)
  auto &ms_helper = ms_helpers_.find(options)->second;
  auto out = ms_helper.ModulusDownRNS(field, xshr.shape(), mask)
                 .reshape(xshr.shape());
  ring_add_(out, ring_mul(xshr, yshr));
  return out;
}

size_t CheetahMul::Impl::EncryptArrayThenSend(const NdArrayRef &array,
const Options &options,
yacl::link::Context *conn) {
Expand Down Expand Up @@ -573,6 +660,72 @@ void CheetahMul::Impl::MulThenResponse(FieldType, int64_t num_elts,
}
}

// Homomorphically evaluates, for every received ciphertext pair,
//   Re-randomize(Enc(x0) * y1 + Enc(y0) * x1) - rnd_mask
// and sends the resulting ciphertexts to the next rank.
//
// - num_elts: number of ring elements being multiplied.
// - ciphers_x0 / ciphers_y0: serialized ciphertexts, laid out context-major:
//   job_id = cntxt_id * num_splits + split_id. Both spans must be non-empty
//   and hold WorkingContextSize(options) * CeilDiv(num_elts, num_slots())
//   entries.
// - plains_x1 / plains_y1: encoded plaintexts, one per ciphertext.
//   ciphers_x0[j] is multiplied by plains_y1[j], ciphers_y0[j] by
//   plains_x1[j].
// - rnd_mask: num_elts mask values per SEAL context, stored context after
//   context.
// - conn: link context to use; nullptr falls back to lctx_.
void CheetahMul::Impl::FMAThenResponse(
    FieldType, int64_t num_elts, const Options &options,
    absl::Span<const yacl::Buffer> ciphers_x0,
    absl::Span<const yacl::Buffer> ciphers_y0,
    absl::Span<const RLWEPt> plains_x1, absl::Span<const RLWEPt> plains_y1,
    absl::Span<const uint64_t> rnd_mask, yacl::link::Context *conn) {
  SPU_ENFORCE(!ciphers_x0.empty(), "CheetahMul: empty cipher");
  SPU_ENFORCE(!ciphers_y0.empty(), "CheetahMul: empty cipher");
  SPU_ENFORCE_EQ(ciphers_x0.size(), ciphers_y0.size());
  SPU_ENFORCE_EQ(plains_x1.size(), ciphers_x0.size(),
                 "CheetahMul: ct/pt size mismatch");
  SPU_ENFORCE_EQ(plains_y1.size(), ciphers_y0.size(),
                 "CheetahMul: ct/pt size mismatch");

  const int64_t num_splits = CeilDiv(num_elts, num_slots());
  const int64_t num_seal_ctx = WorkingContextSize(options);
  const int64_t num_ciphers = num_seal_ctx * num_splits;
  SPU_ENFORCE(ciphers_x0.size() == static_cast<size_t>(num_ciphers),
              "CheetahMul : expect {} != {}", num_ciphers, ciphers_x0.size());
  SPU_ENFORCE(rnd_mask.size() == static_cast<size_t>(num_elts) *
                                     static_cast<size_t>(num_seal_ctx),
              "CheetahMul: rnd_mask size mismatch");

  std::vector<yacl::Buffer> response(num_ciphers);
  yacl::parallel_for(0, num_ciphers, [&](int64_t job_bgn, int64_t job_end) {
    // Scratch ciphertexts reused across the jobs of this worker range.
    RLWECt ct_x;
    RLWECt ct_y;
    for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) {
      const int64_t cntxt_id = job_id / num_splits;
      const int64_t split_id = job_id % num_splits;

      // The last split of a context may be shorter than num_slots().
      int64_t slice_bgn = split_id * num_slots();
      const int64_t slice_n = std::min(num_slots(), num_elts - slice_bgn);
      // offset by context id
      slice_bgn += cntxt_id * num_elts;

      DecodeSEALObject(ciphers_x0[job_id], seal_cntxts_[cntxt_id], &ct_x);
      DecodeSEALObject(ciphers_y0[job_id], seal_cntxts_[cntxt_id], &ct_y);

      // ct_x <- Re-randomize(ct_x * pt_y + ct_y * pt_x) - random_mask
      simd_mul_instances_[cntxt_id]->FMAThenReshareInplace(
          {&ct_x, 1}, {&ct_y, 1}, plains_y1.subspan(job_id, 1),
          plains_x1.subspan(job_id, 1), rnd_mask.subspan(slice_bgn, slice_n),
          *peer_pub_key_, seal_cntxts_[cntxt_id]);

      response[job_id] = EncodeSEALObject(ct_x);
    }
  });

  if (conn == nullptr) {
    conn = lctx_.get();
  }

  // Send in batches of kCtAsyncParallel: the first ciphertext of each batch
  // goes through the blocking Send, the remaining ones through SendAsync.
  const int nxt_rank = conn->NextRank();
  for (int64_t i = 0; i < num_ciphers; i += kCtAsyncParallel) {
    const int64_t this_batch = std::min(num_ciphers - i, kCtAsyncParallel);
    conn->Send(nxt_rank, response[i],
               fmt::format("FMAThenResponse ct[{}] to rank{}", i, nxt_rank));
    for (int64_t j = 1; j < this_batch; ++j) {
      conn->SendAsync(
          nxt_rank, response[i + j],
          fmt::format("FMAThenResponse ct[{}] to rank{}", i + j, nxt_rank));
    }
  }
}

NdArrayRef CheetahMul::Impl::DecryptArray(
FieldType field, int64_t size, const Options &options,
const std::vector<yacl::Buffer> &ct_array) {
Expand Down Expand Up @@ -625,6 +778,20 @@ size_t CheetahMul::OLEBatchSize() const {
return impl_->OLEBatchSize();
}

// Share multiplication over a caller-supplied link context. Requires an
// initialized implementation and a non-null `conn`; the work is forwarded
// to Impl::MulShare.
NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
                                yacl::link::Context *conn, bool is_evaluator,
                                uint32_t msg_width_hint) {
  SPU_ENFORCE(impl_ != nullptr);
  SPU_ENFORCE(conn != nullptr);

  auto product_share =
      impl_->MulShare(xshr, yshr, conn, is_evaluator, msg_width_hint);
  return product_share;
}

// Share multiplication without an explicit link context: a null connection
// is forwarded, which makes Impl::MulShare fall back to its own link.
NdArrayRef CheetahMul::MulShare(const NdArrayRef &xshr, const NdArrayRef &yshr,
                                bool is_evaluator, uint32_t msg_width_hint) {
  SPU_ENFORCE(impl_ != nullptr);

  yacl::link::Context *const default_conn = nullptr;
  return impl_->MulShare(xshr, yshr, default_conn, is_evaluator,
                         msg_width_hint);
}

NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, yacl::link::Context *conn,
bool is_evaluator, uint32_t msg_width_hint) {
SPU_ENFORCE(impl_ != nullptr);
Expand Down
13 changes: 13 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,27 @@ class CheetahMul {

void LazyInitKeys(FieldType field, uint32_t msg_width_hint = 0);

// x, y => [x*y] for two private inputs
// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, yacl::link::Context* conn,
bool is_evaluator, uint32_t msg_width_hint = 0);

// x, y => [x*y] for two private inputs
// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, bool is_evaluator,
uint32_t msg_width_hint = 0);

// [x], [y] => [x*y] for two shares
// `x` and `y` are this party's additive shares; they must be non-empty ring
// arrays of the same shape. The result is this party's share of x*y.
// `conn` must not be null. `is_evaluator` selects the protocol role: the
// evaluator receives the peer's ciphertexts and responds, the other party
// encrypts its shares and decrypts the response.
// `msg_width_hint` = 0 means the full ring bit length is used.
// NOTE: make sure to call InitKeys first
NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y,
yacl::link::Context* conn, bool is_evaluator,
uint32_t msg_width_hint = 0);

// [x], [y] => [x*y] for two shares
// Same as the overload taking a link context, except that no context is
// passed: the implementation falls back to its own link context.
// NOTE: make sure to call InitKeys first
NdArrayRef MulShare(const NdArrayRef& x, const NdArrayRef& y,
bool is_evaluator, uint32_t msg_width_hint = 0);

int Rank() const;

size_t OLEBatchSize() const;
Expand Down
33 changes: 33 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,37 @@ TEST_P(CheetahMulTest, MixedRingSizeMul) {
EXPECT_TRUE(ring_all_equal(expected2, computed2, kMaxDiff));
}

// Checks that the two output shares of MulShare add up to the product of
// the reconstructed inputs (within 1 when approximation is allowed).
TEST_P(CheetahMulTest, MulShare) {
  constexpr size_t kWorldSize = 2;
  const auto field = std::get<0>(GetParam());
  const int64_t n = std::get<1>(GetParam());
  const bool allow_approx = std::get<2>(GetParam());

  // Plain operands.
  auto a_bits = ring_rand(field, {n});
  auto b_bits = ring_rand(field, {n});

  // Additive sharing: share 0 is random, share 1 is the difference.
  std::vector<NdArrayRef> a_shr(kWorldSize);
  std::vector<NdArrayRef> b_shr(kWorldSize);
  a_shr[0] = ring_rand(field, {n});
  b_shr[0] = ring_rand(field, {n});
  a_shr[1] = ring_sub(a_bits, a_shr[0]);
  b_shr[1] = ring_sub(b_bits, b_shr[0]);

  std::vector<NdArrayRef> result(kWorldSize);
  utils::simulate(kWorldSize, [&](std::shared_ptr<yacl::link::Context> lctx) {
    // (a0 + a1) * (b0 + b1)
    // a0*b0 + a0*b1 + a1*b0 + a1*b1
    const int rank = lctx->Rank();
    const bool is_evaluator = (rank == 0);
    auto mul = std::make_shared<CheetahMul>(lctx, allow_approx);

    result[rank] = mul->MulShare(a_shr[rank], b_shr[rank], is_evaluator);
  });

  auto expected = ring_mul(a_bits, b_bits);
  auto computed = ring_add(result[0], result[1]);

  const int64_t kMaxDiff = allow_approx ? 1 : 0;
  EXPECT_TRUE(ring_all_equal(expected, computed, kMaxDiff));
}

} // namespace spu::mpc::cheetah::test
60 changes: 60 additions & 0 deletions libspu/mpc/cheetah/arith/simd_mul_prot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,66 @@ void SIMDMulProt::MulThenReshareInplace(absl::Span<RLWECt> ct,
}
}

// Compute ct0 * pt0 + ct1 * pt1 - mask mod p
//
// For every index i the result overwrites ct0[i]. `share_mask` supplies up
// to simd_lane_ mask values per ciphertext (the last slice may be shorter).
// `public_key` is used to encrypt the zero that re-randomizes each result.
void SIMDMulProt::FMAThenReshareInplace(absl::Span<RLWECt> ct0,
                                        absl::Span<const RLWECt> ct1,
                                        absl::Span<const RLWEPt> pt0,
                                        absl::Span<const RLWEPt> pt1,
                                        absl::Span<const uint64_t> share_mask,
                                        const RLWEPublicKey &public_key,
                                        const seal::SEALContext &context) {
  SPU_ENFORCE_EQ(ct0.size(), ct1.size());
  SPU_ENFORCE_EQ(pt0.size(), pt1.size());
  SPU_ENFORCE_EQ(ct0.size(), pt0.size());
  SPU_ENFORCE_EQ(CeilDiv(share_mask.size(), (size_t)simd_lane_), ct0.size());

  // Choose the level to switch down to: start from the last level and move
  // towards the first one until the coefficient modulus is at least
  // kMarginBitsForDec bits wider than the plain modulus.
  constexpr int kMarginBitsForDec = 10;
  seal::parms_id_type target_level = context.last_parms_id();
  while (target_level != context.first_parms_id()) {
    const auto level_data = context.get_context_data(target_level);
    const auto have_bits = level_data->total_coeff_modulus_bit_count();
    const auto need_bits =
        kMarginBitsForDec + level_data->parms().plain_modulus().bit_count();
    if (have_bits >= need_bits) {
      break;
    }
    target_level = level_data->prev_context_data()->parms_id();
  }

  seal::Evaluator evaluator(context);
  // Scratch objects reused by every iteration.
  RLWECt fresh_zero;
  RLWECt second_term;
  RLWEPt mask_pt;

  const size_t lane = static_cast<size_t>(simd_lane_);
  const size_t num_ct = ct0.size();
  for (size_t idx = 0; idx < num_ct; ++idx) {
    RLWECt &acc = ct0[idx];

    // 1. Ct-Pt Mul: acc <- ct0 * pt0 + ct1 * pt1
    evaluator.multiply_plain_inplace(acc, pt0[idx]);
    evaluator.multiply_plain(ct1[idx], pt1[idx], second_term);
    evaluator.add_inplace(acc, second_term);

    // 2. Noise flooding
    NoiseFloodInplace(acc, context);

    // 3. Drop some modulus for a smaller communication
    evaluator.mod_switch_to_inplace(acc, target_level);

    // 4. Re-randomize via adding enc(0)
    seal::util::encrypt_zero_asymmetric(public_key, context, acc.parms_id(),
                                        acc.is_ntt_form(), fresh_zero);
    evaluator.add_inplace(acc, fresh_zero);

    // 5. Additive share: subtract this ciphertext's slice of the mask
    const size_t mask_bgn = idx * lane;
    const size_t mask_len = std::min(lane, share_mask.size() - mask_bgn);
    EncodeSingle(share_mask.subspan(mask_bgn, mask_len), mask_pt);
    evaluator.sub_plain_inplace(acc, mask_pt);

    // 6. Truncate for smaller communication
    if (acc.coeff_modulus_size() == 1) {
      TruncateBFVForDecryption(acc, context);
    }
  }
}

void SIMDMulProt::NoiseFloodInplace(RLWECt &ct,
const seal::SEALContext &context) {
SPU_ENFORCE(seal::is_metadata_valid_for(ct, context));
Expand Down
18 changes: 13 additions & 5 deletions libspu/mpc/cheetah/arith/simd_mul_prot.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,19 @@ class SIMDMulProt : public EnableCPRNG {
const RLWEPublicKey& public_key,
const seal::SEALContext& context);

void MulThenReshareInplaceOneBit(absl::Span<RLWECt> ct,
absl::Span<const RLWEPt> pt,
absl::Span<uint64_t> share_mask,
const RLWEPublicKey& public_key,
const seal::SEALContext& context);
// ct0 * pt0 + ct1 * pt1 - mask
// The result is written into `ct0`; the four ct/pt spans must have the same
// size. `share_mask` provides the mask values that are subtracted, and
// `public_key` is used to re-randomize each result with an encryption of
// zero.
void FMAThenReshareInplace(absl::Span<RLWECt> ct0,
absl::Span<const RLWECt> ct1,
absl::Span<const RLWEPt> pt0,
absl::Span<const RLWEPt> pt1,
absl::Span<const uint64_t> share_mask,
const RLWEPublicKey& public_key,
const seal::SEALContext& context);

[[deprecated]] void MulThenReshareInplaceOneBit(
absl::Span<RLWECt> ct, absl::Span<const RLWEPt> pt,
absl::Span<uint64_t> share_mask, const RLWEPublicKey& public_key,
const seal::SEALContext& context);

inline int64_t SIMDLane() const { return simd_lane_; }

Expand Down
20 changes: 5 additions & 15 deletions libspu/mpc/cheetah/arith/simd_mul_prot_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class SIMDMulTest : public ::testing::TestWithParam<bool>, public EnableCPRNG {
};

INSTANTIATE_TEST_SUITE_P(
Cheetah, SIMDMulTest, testing::Values(true, false),
Cheetah, SIMDMulTest, testing::Values(true),
[](const testing::TestParamInfo<SIMDMulTest::ParamType> &p) {
return fmt::format("{}", p.param ? "NoiseFlood" : "Approx");
});
Expand Down Expand Up @@ -116,20 +116,10 @@ TEST_P(SIMDMulTest, Basic) {
simd_mul_prot_->SymEncrypt(encode_b, *rlwe_sk_, *context_, false,
absl::MakeSpan(encrypt_b));

if (GetParam()) {
RandomPlain(absl::MakeSpan(out_a));
simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a,
absl::MakeConstSpan(out_a),
*rlwe_pk_, *context_);
} else {
simd_mul_prot_->MulThenReshareInplaceOneBit(
absl::MakeSpan(encrypt_b), encode_a, absl::MakeSpan(out_a), *rlwe_pk_,
*context_);
}
if (rep == 0) {
printf("rep ct.L %zd\n", encrypt_b[0].coeff_modulus_size());
}

RandomPlain(absl::MakeSpan(out_a));
simd_mul_prot_->MulThenReshareInplace(absl::MakeSpan(encrypt_b), encode_a,
absl::MakeConstSpan(out_a), *rlwe_pk_,
*context_);
auto _out_b = absl::MakeSpan(out_b);
for (size_t i = 0; i < num_pt; ++i) {
seal::Plaintext pt;
Expand Down
Loading

0 comments on commit c7a5ba5

Please sign in to comment.