Skip to content

Commit

Permalink
fix #827 (#828)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #827 

## Possible side effects?

- Performance:

- Backward compatibility:
  • Loading branch information
fionser authored Aug 21, 2024
1 parent d51f4e7 commit 3117206
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 11 deletions.
2 changes: 2 additions & 0 deletions libspu/mpc/cheetah/ot/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,7 @@ spu_cc_test(
deps = [
":ot_util",
"//libspu/mpc/utils:ring_ops",
"//libspu/mpc/common:communicator",
"//libspu/mpc/utils:simulate",
],
)
57 changes: 48 additions & 9 deletions libspu/mpc/cheetah/ot/ot_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ void U8ToBool(absl::Span<uint8_t> bits, uint8_t u8) {
}
}

// Returns a value of type T with its lowest `nbits` bits set.
//
// `nbits == 0` is shorthand for "all bits of T". Requesting more bits than T
// holds fails the SPU_ENFORCE check.
template <typename T>
static T _makeBitsMask(size_t nbits) {
  const size_t max = sizeof(T) * 8;
  nbits = (nbits == 0) ? max : nbits;
  SPU_ENFORCE(nbits <= max);

  if (nbits < max) {
    return (static_cast<T>(1) << nbits) - 1;
  }
  // Full-width request: a shift by the whole width of T is not allowed, so the
  // all-ones pattern is produced directly.
  return static_cast<T>(-1);
}

// ANDs every element of `array` with a mask that keeps only the low `bw` bits
// (`bw == 0` keeps the whole element, see _makeBitsMask).
//
// NOTE(review): `array` is received by value; the caller in this file reads
// the result back through its own handle, so NdArrayRef copies presumably
// share the underlying buffer - confirm.
static void maskArray(NdArrayRef array, FieldType field, size_t bw) {
  DISPATCH_ALL_FIELDS(field, [&]() {
    NdArrayView<ring2k_t> elems(array);
    const auto keep = _makeBitsMask<ring2k_t>(bw);
    const int64_t total = elems.numel();
    for (int64_t idx = 0; idx < total; ++idx) {
      elems[idx] &= keep;
    }
  });
}

NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
std::shared_ptr<Communicator> conn) {
SPU_ENFORCE(conn != nullptr);
Expand All @@ -52,20 +76,27 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
nbits = fwidth;
}
SPU_ENFORCE(nbits <= fwidth, "nbits out-of-bound");
bool packable = fwidth > nbits;
if (not packable) {
return conn->allReduce(op, shr, "open");
}

size_t space_bits = op == ReduceOp::ADD ? nbits + 1 : nbits;
size_t numel = shr.numel();
size_t compact_numel = CeilDiv(numel * nbits, fwidth);
size_t compact_numel = CeilDiv(numel * space_bits, fwidth);

if (space_bits > nbits and 0 != (fwidth % space_bits)) {
// FIXME(lwj): for Add, we could have a better ZipArray that handles a ring
// element which is placed across two different blocks.
// For now, we use ZipArray for Add only when no element straddles two
// blocks, i.e. when fwidth is a multiple of space_bits.
auto out = conn->allReduce(op, shr, "open");
maskArray(out, field, nbits);
return out;
}

NdArrayRef out(shr.eltype(), {(int64_t)numel});
DISPATCH_ALL_FIELDS(field, [&]() {
auto inp = absl::MakeConstSpan(&shr.at<ring2k_t>(0), numel);
auto oup = absl::MakeSpan(&out.at<ring2k_t>(0), compact_numel);

size_t used = ZipArray(inp, nbits, oup);
size_t used = ZipArray(inp, space_bits, oup);
SPU_ENFORCE_EQ(used, compact_numel);

std::vector<ring2k_t> opened;
Expand All @@ -76,8 +107,16 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
}

oup = absl::MakeSpan(&out.at<ring2k_t>(0), numel);
UnzipArray(absl::MakeConstSpan(opened), nbits, oup);
UnzipArray(absl::MakeConstSpan(opened), space_bits, oup);

if (space_bits > nbits and nbits < fwidth) {
auto msk = (static_cast<ring2k_t>(1) << nbits) - 1;
for (size_t i = 0; i < numel; ++i) {
oup[i] &= msk;
}
}
});

return out.reshape(shr.shape());
}

Expand All @@ -87,8 +126,8 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
#include "sse2neon.h"
#endif

#define INP(x, y) inp[(x) * ncols / 8 + (y) / 8]
#define OUT(x, y) out[(y) * nrows / 8 + (x) / 8]
#define INP(x, y) inp[(x)*ncols / 8 + (y) / 8]
#define OUT(x, y) out[(y)*nrows / 8 + (x) / 8]

#ifdef __x86_64__
__attribute__((target("sse2")))
Expand Down
78 changes: 76 additions & 2 deletions libspu/mpc/cheetah/ot/ot_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include "gtest/gtest.h"

#include "libspu/mpc/common/communicator.h"
#include "libspu/mpc/utils/ring_ops.h"
#include "libspu/mpc/utils/simulate.h"

namespace spu::mpc::cheetah::test {

Expand Down Expand Up @@ -81,7 +83,6 @@ TEST_P(OtUtilTest, ZipArrayBit) {
auto _zip = absl::MakeSpan(&zip.at<ring2k_t>(0), zip.numel());
auto _unzip = absl::MakeSpan(&unzip.at<ring2k_t>(0), unzip.numel());
pforeach(0, array.numel(), [&](int64_t i) { inp[i] &= mask; });

size_t zip_sze = ZipArrayBit<ring2k_t>(inp, bw, _zip);
SPU_ENFORCE(zip_sze == pack_sze);

Expand All @@ -99,4 +100,77 @@ TEST_P(OtUtilTest, ZipArrayBit) {
});
}

} // namespace spu::mpc::cheetah::test
// Test helper: produces a T whose lowest `nbits` bits are ones.
//
// Passing 0 selects the full width of T; any width above sizeof(T) * 8 trips
// the SPU_ENFORCE check.
template <typename T>
T makeBitsMask(size_t nbits) {
  const size_t max = sizeof(T) * 8;
  if (nbits == 0) {
    nbits = max;
  }
  SPU_ENFORCE(nbits <= max);

  // The shift is only performed for widths strictly below the width of T;
  // the full-width case falls back to the all-ones pattern.
  return nbits < max ? static_cast<T>((static_cast<T>(1) << nbits) - 1)
                     : static_cast<T>(-1);
}

// Test helper: keeps only the low `bw` bits of every element of `array`
// (`bw == 0` keeps the whole element, see makeBitsMask).
//
// NOTE(review): `array` is passed by value while the tests below inspect the
// caller's object afterwards, so NdArrayRef copies presumably alias the same
// buffer - confirm.
void MaskArray(NdArrayRef array, FieldType field, size_t bw) {
  DISPATCH_ALL_FIELDS(field, [&]() {
    NdArrayView<ring2k_t> elems(array);
    const auto keep = makeBitsMask<ring2k_t>(bw);
    const int64_t count = elems.numel();
    for (int64_t pos = 0; pos < count; ++pos) {
      elems[pos] &= keep;
    }
  });
}

// Opens additive shares between two simulated parties for several bit widths
// (field width minus 0, 15 and 17 bits). Rank 1 checks that the opened value
// is byte-for-byte equal to the sum of both inputs masked to `bw` bits.
TEST_P(OtUtilTest, OpenShare_ADD) {
  const auto field = GetParam();
  Shape shape = {1000L};
  const size_t field_bits = SizeOf(field) * 8;

  for (size_t bw_offset : {0, 15, 17}) {
    const size_t bw = field_bits - bw_offset;
    // One input per party, indexed by rank; filled inside the simulation.
    NdArrayRef shares[2];

    utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
      const int rank = lctx->Rank();

      // Each party draws a random share and truncates it to `bw` bits.
      shares[rank] = ring_rand(field, shape);
      MaskArray(shares[rank], field, bw);

      auto comm = std::make_shared<Communicator>(lctx);
      auto opened = OpenShare(shares[rank], ReduceOp::ADD, bw, comm);

      // Only rank 1 verifies; rank 0 is done once the open has completed.
      if (rank != 0) {
        auto expected = ring_add(shares[0], shares[1]);
        MaskArray(expected, field, bw);

        ASSERT_TRUE(std::memcmp(&opened.at<uint8_t>(0), &expected.at<uint8_t>(0),
                                opened.elsize() * opened.numel()) == 0);
      }
    });
  }
}

// Opens XOR shares between two simulated parties for several bit widths
// (field width minus 0, 3 and 15 bits). Rank 1 checks that the opened value is
// byte-for-byte equal to ring_xor of both inputs. The expectation is not
// masked again here; both inputs were already masked to `bw` bits.
TEST_P(OtUtilTest, OpenShare_XOR) {
  const auto field = GetParam();
  Shape shape = {1000L};
  const size_t field_bits = SizeOf(field) * 8;

  for (size_t bw_offset : {0, 3, 15}) {
    const size_t bw = field_bits - bw_offset;
    // One input per party, indexed by rank; filled inside the simulation.
    NdArrayRef shares[2];

    utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
      const int rank = lctx->Rank();

      // Each party draws a random share and truncates it to `bw` bits.
      shares[rank] = ring_rand(field, shape);
      MaskArray(shares[rank], field, bw);

      auto comm = std::make_shared<Communicator>(lctx);
      auto opened = OpenShare(shares[rank], ReduceOp::XOR, bw, comm);

      // Only rank 1 verifies; rank 0 is done once the open has completed.
      if (rank != 0) {
        auto expected = ring_xor(shares[0], shares[1]);

        ASSERT_TRUE(std::memcmp(&opened.at<uint8_t>(0), &expected.at<uint8_t>(0),
                                opened.elsize() * opened.numel()) == 0);
      }
    });
  }
}

} // namespace spu::mpc::cheetah::test

0 comments on commit 3117206

Please sign in to comment.