Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #827 #828

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libspu/mpc/cheetah/ot/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,7 @@ spu_cc_test(
deps = [
":ot_util",
"//libspu/mpc/utils:ring_ops",
"//libspu/mpc/common:communicator",
"//libspu/mpc/utils:simulate",
],
)
57 changes: 48 additions & 9 deletions libspu/mpc/cheetah/ot/ot_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ void U8ToBool(absl::Span<uint8_t> bits, uint8_t u8) {
}
}

// Builds a value of type T whose low `nbits` bits are set and whose remaining
// bits are clear. An `nbits` of 0 is treated as the full width of T, and
// asking for more bits than T holds trips the enforce check.
template <typename T>
static T _makeBitsMask(size_t nbits) {
  const size_t max = sizeof(T) * 8;
  // Zero is shorthand for "every bit of T".
  nbits = (nbits == 0) ? max : nbits;
  SPU_ENFORCE(nbits <= max);
  // The shift is only taken for widths strictly below the full width of T;
  // the full-width case yields the all-ones pattern directly.
  if (nbits < max) {
    return (static_cast<T>(1) << nbits) - 1;
  }
  return static_cast<T>(-1);
}

// ANDs every element reachable through a view of `array` with a mask of the
// low `bw` bits; the element type is selected by dispatching on `field`.
// `bw == 0` selects the full element width (see _makeBitsMask).
// NOTE(review): `array` is taken by value, so the caller-visible effect
// presumably relies on NdArrayRef copies sharing their storage -- confirm.
static void maskArray(NdArrayRef array, FieldType field, size_t bw) {
  DISPATCH_ALL_FIELDS(field, [&]() {
    NdArrayView<ring2k_t> elems(array);
    const auto keep = _makeBitsMask<ring2k_t>(bw);
    int64_t idx = 0;
    while (idx < elems.numel()) {
      elems[idx] &= keep;
      ++idx;
    }
  });
}

NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
std::shared_ptr<Communicator> conn) {
SPU_ENFORCE(conn != nullptr);
Expand All @@ -52,20 +76,27 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
nbits = fwidth;
}
SPU_ENFORCE(nbits <= fwidth, "nbits out-of-bound");
bool packable = fwidth > nbits;
if (not packable) {
return conn->allReduce(op, shr, "open");
}

size_t space_bits = op == ReduceOp::ADD ? nbits + 1 : nbits;
size_t numel = shr.numel();
size_t compact_numel = CeilDiv(numel * nbits, fwidth);
size_t compact_numel = CeilDiv(numel * space_bits, fwidth);

if (space_bits > nbits and 0 != (fwidth % space_bits)) {
// FIXME(lwj): for Add, we could have a better ZipArray that handles a ring
// element that is placed across two different blocks.
// For now, we use ZipArray for Add only when each element fits entirely
// within one block.
auto out = conn->allReduce(op, shr, "open");
maskArray(out, field, nbits);
return out;
}

NdArrayRef out(shr.eltype(), {(int64_t)numel});
DISPATCH_ALL_FIELDS(field, [&]() {
auto inp = absl::MakeConstSpan(&shr.at<ring2k_t>(0), numel);
auto oup = absl::MakeSpan(&out.at<ring2k_t>(0), compact_numel);

size_t used = ZipArray(inp, nbits, oup);
size_t used = ZipArray(inp, space_bits, oup);
SPU_ENFORCE_EQ(used, compact_numel);

std::vector<ring2k_t> opened;
Expand All @@ -76,8 +107,16 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
}

oup = absl::MakeSpan(&out.at<ring2k_t>(0), numel);
UnzipArray(absl::MakeConstSpan(opened), nbits, oup);
UnzipArray(absl::MakeConstSpan(opened), space_bits, oup);

if (space_bits > nbits and nbits < fwidth) {
auto msk = (static_cast<ring2k_t>(1) << nbits) - 1;
for (size_t i = 0; i < numel; ++i) {
oup[i] &= msk;
}
}
});

return out.reshape(shr.shape());
}

Expand All @@ -87,8 +126,8 @@ NdArrayRef OpenShare(const NdArrayRef &shr, ReduceOp op, size_t nbits,
#include "sse2neon.h"
#endif

#define INP(x, y) inp[(x) * ncols / 8 + (y) / 8]
#define OUT(x, y) out[(y) * nrows / 8 + (x) / 8]
#define INP(x, y) inp[(x)*ncols / 8 + (y) / 8]
#define OUT(x, y) out[(y)*nrows / 8 + (x) / 8]

#ifdef __x86_64__
__attribute__((target("sse2")))
Expand Down
78 changes: 76 additions & 2 deletions libspu/mpc/cheetah/ot/ot_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include "gtest/gtest.h"

#include "libspu/mpc/common/communicator.h"
#include "libspu/mpc/utils/ring_ops.h"
#include "libspu/mpc/utils/simulate.h"

namespace spu::mpc::cheetah::test {

Expand Down Expand Up @@ -81,7 +83,6 @@ TEST_P(OtUtilTest, ZipArrayBit) {
auto _zip = absl::MakeSpan(&zip.at<ring2k_t>(0), zip.numel());
auto _unzip = absl::MakeSpan(&unzip.at<ring2k_t>(0), unzip.numel());
pforeach(0, array.numel(), [&](int64_t i) { inp[i] &= mask; });

size_t zip_sze = ZipArrayBit<ring2k_t>(inp, bw, _zip);
SPU_ENFORCE(zip_sze == pack_sze);

Expand All @@ -99,4 +100,77 @@ TEST_P(OtUtilTest, ZipArrayBit) {
});
}

} // namespace spu::mpc::cheetah::test
// Test helper: returns a T with its low `nbits` bits set. Passing 0 requests
// the full width of T; a width larger than T fails the enforce check.
template <typename T>
T makeBitsMask(size_t nbits) {
  const size_t max = sizeof(T) * 8;
  if (nbits == 0) {
    // Zero stands for "all bits".
    nbits = max;
  }
  SPU_ENFORCE(nbits <= max);
  // Only shift when the requested width is strictly narrower than T;
  // otherwise hand back the all-ones pattern.
  return (nbits < max) ? static_cast<T>((static_cast<T>(1) << nbits) - 1)
                       : static_cast<T>(-1);
}

// Test helper: ANDs each element seen through a view of `array` with a mask
// of the low `bw` bits, using the element type that `field` dispatches to.
// NOTE(review): `array` is passed by value; the tests below rely on the
// change being visible to the caller, which presumably means NdArrayRef
// copies share storage -- confirm.
void MaskArray(NdArrayRef array, FieldType field, size_t bw) {
  DISPATCH_ALL_FIELDS(field, [&]() {
    NdArrayView<ring2k_t> elems(array);
    const auto keep = makeBitsMask<ring2k_t>(bw);
    int64_t pos = 0;
    while (pos < elems.numel()) {
      elems[pos] &= keep;
      ++pos;
    }
  });
}

// Checks OpenShare with ReduceOp::ADD: two parties each hold a random input
// masked to `bw` bits, and the opened result must equal the ring sum of both
// inputs masked to `bw` bits. Repeated for several bit widths.
TEST_P(OtUtilTest, OpenShare_ADD) {
const auto field = GetParam();
Shape shape = {1000L};

// Offset 0 exercises the full ring width; 15 and 17 exercise narrower widths.
for (size_t bw_offset : {0, 15, 17}) {
size_t bw = SizeOf(field) * 8 - bw_offset;
// One input slot per rank, filled from inside the callback below.
NdArrayRef inp[2];
// NOTE(review): presumably runs the callback once per simulated party,
// each with its own link context -- confirm against utils::simulate.
utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> ctx) {
int rank = ctx->Rank();

// Draw this party's input and keep only its low `bw` bits.
inp[rank] = ring_rand(field, shape);
MaskArray(inp[rank], field, bw);

auto conn = std::make_shared<Communicator>(ctx);
auto opened = OpenShare(inp[rank], ReduceOp::ADD, bw, conn);
// Rank 0 stops here; only the other rank performs the comparison.
if (rank == 0) return;
// NOTE(review): this reads inp[0], which the other party wrote; it
// presumably relies on the OpenShare exchange having ordered that write
// before this read -- confirm.
auto expected = ring_add(inp[0], inp[1]);
// The sum may carry past `bw` bits, so the reference is masked as well.
MaskArray(expected, field, bw);

// Byte-wise comparison of the opened buffer against the reference buffer.
ASSERT_TRUE(std::memcmp(&opened.at<uint8_t>(0), &expected.at<uint8_t>(0),
opened.elsize() * opened.numel()) == 0);
});
}
}

// Checks OpenShare with ReduceOp::XOR: two parties each hold a random input
// masked to `bw` bits, and the opened result must equal the XOR of both
// inputs. Repeated for several bit widths.
TEST_P(OtUtilTest, OpenShare_XOR) {
const auto field = GetParam();
Shape shape = {1000L};

// Offset 0 exercises the full ring width; 3 and 15 exercise narrower widths.
for (size_t bw_offset : {0, 3, 15}) {
size_t bw = SizeOf(field) * 8 - bw_offset;
// One input slot per rank, filled from inside the callback below.
NdArrayRef inp[2];
// NOTE(review): presumably runs the callback once per simulated party,
// each with its own link context -- confirm against utils::simulate.
utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> ctx) {
int rank = ctx->Rank();

// Draw this party's input and keep only its low `bw` bits.
inp[rank] = ring_rand(field, shape);
MaskArray(inp[rank], field, bw);

auto conn = std::make_shared<Communicator>(ctx);
auto opened = OpenShare(inp[rank], ReduceOp::XOR, bw, conn);
// Rank 0 stops here; only the other rank performs the comparison.
if (rank == 0) return;
// NOTE(review): this reads inp[0], which the other party wrote; it
// presumably relies on the OpenShare exchange having ordered that write
// before this read -- confirm.
// The reference is not masked here, unlike in the ADD test: XOR of two
// `bw`-bit values cannot set a bit above `bw`.
auto expected = ring_xor(inp[0], inp[1]);

// Byte-wise comparison of the opened buffer against the reference buffer.
ASSERT_TRUE(std::memcmp(&opened.at<uint8_t>(0), &expected.at<uint8_t>(0),
opened.elsize() * opened.numel()) == 0);
});
}
}

} // namespace spu::mpc::cheetah::test
Loading