From bfa5d61bacf076d808900ded31073332b76caa3b Mon Sep 17 00:00:00 2001 From: fionser Date: Wed, 21 Aug 2024 12:37:29 +0800 Subject: [PATCH] fix #827 --- libspu/mpc/cheetah/ot/BUILD.bazel | 2 + libspu/mpc/cheetah/ot/ot_util.cc | 57 ++++++++++++++++---- libspu/mpc/cheetah/ot/ot_util_test.cc | 78 ++++++++++++++++++++++++++- 3 files changed, 126 insertions(+), 11 deletions(-) diff --git a/libspu/mpc/cheetah/ot/BUILD.bazel b/libspu/mpc/cheetah/ot/BUILD.bazel index 54663cc6..5e82e1ac 100644 --- a/libspu/mpc/cheetah/ot/BUILD.bazel +++ b/libspu/mpc/cheetah/ot/BUILD.bazel @@ -73,5 +73,7 @@ spu_cc_test( deps = [ ":ot_util", "//libspu/mpc/utils:ring_ops", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:simulate", ], ) diff --git a/libspu/mpc/cheetah/ot/ot_util.cc b/libspu/mpc/cheetah/ot/ot_util.cc index 3d593512..e5693513 100644 --- a/libspu/mpc/cheetah/ot/ot_util.cc +++ b/libspu/mpc/cheetah/ot/ot_util.cc @@ -40,6 +40,30 @@ void U8ToBool(absl::Span bits, uint8_t u8) { } } +template +static T _makeBitsMask(size_t nbits) { + size_t max = sizeof(T) * 8; + if (nbits == 0) { + nbits = max; + } + SPU_ENFORCE(nbits <= max); + T mask = static_cast(-1); + if (nbits < max) { + mask = (static_cast(1) << nbits) - 1; + } + return mask; +} + +static void maskArray(NdArrayRef array, FieldType field, size_t bw) { + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView view(array); + auto msk = _makeBitsMask(bw); + for (int64_t i = 0; i < view.numel(); ++i) { + view[i] &= msk; + } + }); +} + NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, std::shared_ptr conn) { SPU_ENFORCE(conn != nullptr); @@ -52,20 +76,27 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, nbits = fwidth; } SPU_ENFORCE(nbits <= fwidth, "nbits out-of-bound"); - bool packable = fwidth > nbits; - if (not packable) { - return conn->allReduce(op, shr, "open"); - } + size_t space_bits = op == ReduceOp::ADD ? 
nbits + 1 : nbits; size_t numel = shr.numel(); - size_t compact_numel = CeilDiv(numel * nbits, fwidth); + size_t compact_numel = CeilDiv(numel * space_bits, fwidth); + + if (space_bits > nbits and 0 != (fwidth % space_bits)) { + // FIXME(lwj): for Add, a better ZipArray could handle a ring element + // that is split across two different blocks. + // For now, we use ZipArray for Add only when elements fit evenly into a + // block, i.e. fwidth is a multiple of space_bits. + auto out = conn->allReduce(op, shr, "open"); + maskArray(out, field, nbits); + return out; + } NdArrayRef out(shr.eltype(), {(int64_t)numel}); DISPATCH_ALL_FIELDS(field, [&]() { auto inp = absl::MakeConstSpan(&shr.at(0), numel); auto oup = absl::MakeSpan(&out.at(0), compact_numel); - size_t used = ZipArray(inp, nbits, oup); + size_t used = ZipArray(inp, space_bits, oup); SPU_ENFORCE_EQ(used, compact_numel); std::vector opened; @@ -76,8 +107,16 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, } oup = absl::MakeSpan(&out.at(0), numel); - UnzipArray(absl::MakeConstSpan(opened), nbits, oup); + UnzipArray(absl::MakeConstSpan(opened), space_bits, oup); + + if (space_bits > nbits and nbits < fwidth) { + auto msk = (static_cast(1) << nbits) - 1; + for (size_t i = 0; i < numel; ++i) { + oup[i] &= msk; + } + } }); + return out.reshape(shr.shape()); } @@ -87,8 +126,8 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits, #include "sse2neon.h" #endif -#define INP(x, y) inp[(x) * ncols / 8 + (y) / 8] -#define OUT(x, y) out[(y) * nrows / 8 + (x) / 8] +#define INP(x, y) inp[(x)*ncols / 8 + (y) / 8] +#define OUT(x, y) out[(y)*nrows / 8 + (x) / 8] #ifdef __x86_64__ __attribute__((target("sse2"))) diff --git a/libspu/mpc/cheetah/ot/ot_util_test.cc b/libspu/mpc/cheetah/ot/ot_util_test.cc index 149e9694..2105f408 100644 --- a/libspu/mpc/cheetah/ot/ot_util_test.cc +++ b/libspu/mpc/cheetah/ot/ot_util_test.cc @@ -16,7 +16,9 @@ #include "gtest/gtest.h" +#include "libspu/mpc/common/communicator.h" #include 
"libspu/mpc/utils/ring_ops.h" +#include "libspu/mpc/utils/simulate.h" namespace spu::mpc::cheetah::test { @@ -81,7 +83,6 @@ TEST_P(OtUtilTest, ZipArrayBit) { auto _zip = absl::MakeSpan(&zip.at(0), zip.numel()); auto _unzip = absl::MakeSpan(&unzip.at(0), unzip.numel()); pforeach(0, array.numel(), [&](int64_t i) { inp[i] &= mask; }); - size_t zip_sze = ZipArrayBit(inp, bw, _zip); SPU_ENFORCE(zip_sze == pack_sze); @@ -99,4 +100,77 @@ TEST_P(OtUtilTest, ZipArrayBit) { }); } -} // namespace spu::mpc::cheetah::test \ No newline at end of file +template +T makeBitsMask(size_t nbits) { + size_t max = sizeof(T) * 8; + if (nbits == 0) { + nbits = max; + } + SPU_ENFORCE(nbits <= max); + T mask = static_cast(-1); + if (nbits < max) { + mask = (static_cast(1) << nbits) - 1; + } + return mask; +} + +void MaskArray(NdArrayRef array, FieldType field, size_t bw) { + DISPATCH_ALL_FIELDS(field, [&]() { + NdArrayView view(array); + auto msk = makeBitsMask(bw); + for (int64_t i = 0; i < view.numel(); ++i) { + view[i] &= msk; + } + }); +} + +TEST_P(OtUtilTest, OpenShare_ADD) { + const auto field = GetParam(); + Shape shape = {1000L}; + + for (size_t bw_offset : {0, 15, 17}) { + size_t bw = SizeOf(field) * 8 - bw_offset; + NdArrayRef inp[2]; + utils::simulate(2, [&](std::shared_ptr ctx) { + int rank = ctx->Rank(); + + inp[rank] = ring_rand(field, shape); + MaskArray(inp[rank], field, bw); + + auto conn = std::make_shared(ctx); + auto opened = OpenShare(inp[rank], ReduceOp::ADD, bw, conn); + if (rank == 0) return; + auto expected = ring_add(inp[0], inp[1]); + MaskArray(expected, field, bw); + + ASSERT_TRUE(std::memcmp(&opened.at(0), &expected.at(0), + opened.elsize() * opened.numel()) == 0); + }); + } +} + +TEST_P(OtUtilTest, OpenShare_XOR) { + const auto field = GetParam(); + Shape shape = {1000L}; + + for (size_t bw_offset : {0, 3, 15}) { + size_t bw = SizeOf(field) * 8 - bw_offset; + NdArrayRef inp[2]; + utils::simulate(2, [&](std::shared_ptr ctx) { + int rank = ctx->Rank(); + + 
inp[rank] = ring_rand(field, shape); + MaskArray(inp[rank], field, bw); + + auto conn = std::make_shared(ctx); + auto opened = OpenShare(inp[rank], ReduceOp::XOR, bw, conn); + if (rank == 0) return; + auto expected = ring_xor(inp[0], inp[1]); + + ASSERT_TRUE(std::memcmp(&opened.at(0), &expected.at(0), + opened.elsize() * opened.numel()) == 0); + }); + } +} + +} // namespace spu::mpc::cheetah::test