From c0c409a980001d6bc180fd442a144d176de318ad Mon Sep 17 00:00:00 2001 From: fionser Date: Thu, 16 May 2024 23:17:34 +0800 Subject: [PATCH 1/2] [BUGFIX] margin case for non-aligned bits --- libspu/mpc/cheetah/ot/basic_ot_prot.cc | 8 ++++---- libspu/mpc/cheetah/ot/basic_ot_prot_test.cc | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index 29fe5f96..1245a87b 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -74,15 +74,15 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { const int64_t n = inp.numel(); size_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits(); if (n >= 8) { - // 8bits-align + // 8bits-align for a larger input nbits = (nbits + 7) / 8 * 8; } SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); auto rand_bits = DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { - if ((n * inp.elsize()) & 7) { - // The SseTranspose requires the #columns is multiple of 8 - // Thus, we call the less efficient RandBits. + if ((nbits & 7) or (n * inp.elsize()) & 7) { + // The SseTranspose requires the $rows and #columns is multiple of 8 + // Thus, we call the less efficient RandBits. return RandBits(field, {static_cast(n * nbits)}); } diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc index 09fbb4df..aa18321c 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc @@ -129,10 +129,10 @@ TEST_P(BasicOTProtTest, SingleB2A) { TEST_P(BasicOTProtTest, PackedB2A) { size_t kWorldSize = 2; - Shape shape = {11, 12, 13}; + Shape shape = {2}; FieldType field = std::get<0>(GetParam()); auto ot_type = std::get<1>(GetParam()); - for (size_t nbits : {8}) { + for (size_t nbits : {3, 8, 9}) { size_t packed_nbits = nbits; auto boolean_t = makeType(field, packed_nbits); From f4e201416291e1eb9cd6e9b732c7fd8dda5b1235 Mon Sep 17 00:00:00 2001 From: fionser Date: Thu, 16 May 2024 23:20:01 +0800 Subject: [PATCH 2/2] typos --- libspu/mpc/cheetah/ot/basic_ot_prot.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index 1245a87b..201a57ad 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -81,8 +81,8 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { auto rand_bits = DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { if ((nbits & 7) or (n * inp.elsize()) & 7) { - // The SseTranspose requires the $rows and #columns is multiple of 8 - // Thus, we call the less efficient RandBits. + // The SseTranspose requires the #rows and #columns is multiple of 8. + // Thus, we call the less efficient RandBits on margin cases. return RandBits(field, {static_cast(n * nbits)}); }