diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index 29fe5f96..201a57ad 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -74,15 +74,15 @@ NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { const int64_t n = inp.numel(); size_t nbits = share_t->nbits() == 0 ? 1 : share_t->nbits(); if (n >= 8) { - // 8bits-align + // 8bits-align for a larger input nbits = (nbits + 7) / 8 * 8; } SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); auto rand_bits = DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { - if ((n * inp.elsize()) & 7) { - // The SseTranspose requires the #columns is multiple of 8 - // Thus, we call the less efficient RandBits. + if ((nbits & 7) or (n * inp.elsize()) & 7) { + // SseTranspose requires #rows and #columns to be multiples of 8. + // Thus, we call the less efficient RandBits for the edge cases. return RandBits(field, {static_cast(n * nbits)}); } diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc index 09fbb4df..aa18321c 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc @@ -129,10 +129,10 @@ TEST_P(BasicOTProtTest, SingleB2A) { TEST_P(BasicOTProtTest, PackedB2A) { size_t kWorldSize = 2; - Shape shape = {11, 12, 13}; + Shape shape = {2}; FieldType field = std::get<0>(GetParam()); auto ot_type = std::get<1>(GetParam()); - for (size_t nbits : {8}) { + for (size_t nbits : {3, 8, 9}) { size_t packed_nbits = nbits; auto boolean_t = makeType(field, packed_nbits);