Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repo sync #826

Merged
merged 1 commit into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .circleci/asan-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ version: 2.1
parameters:
run-asan:
type: boolean
default: false
default: true

# Define a job to be invoked later in a workflow.
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
Expand Down Expand Up @@ -55,7 +55,7 @@ jobs:
command: |
set +e
declare -i test_status
bazel test //libspu/... --features=asan --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]}
bazel test //libspu/... --features=asan --test_timeout=500 --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]}

sh ../devtools/rename-junit-xml.sh
find bazel-testlogs/ -type f -name "test.log" -print0 | xargs -0 tar -cvzf test_logs.tar.gz
Expand Down
4 changes: 4 additions & 0 deletions libspu/kernel/hal/prot_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ Value _s2v(SPUContext* ctx, const Value& in, int owner) {
MAP_UNARY_OP(not_p)
MAP_UNARY_OP(not_s)
MAP_UNARY_OP(not_v)
// Negate family
MAP_UNARY_OP(negate_p)
MAP_UNARY_OP(negate_s)
MAP_UNARY_OP(negate_v)
// Msb family
MAP_UNARY_OP(msb_p)
MAP_UNARY_OP(msb_s)
Expand Down
4 changes: 4 additions & 0 deletions libspu/kernel/hal/prot_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ Value _not_p(SPUContext* ctx, const Value& in);
Value _not_s(SPUContext* ctx, const Value& in);
Value _not_v(SPUContext* ctx, const Value& in);

Value _negate_p(SPUContext* ctx, const Value& in);
Value _negate_s(SPUContext* ctx, const Value& in);
Value _negate_v(SPUContext* ctx, const Value& in);

Value _msb_p(SPUContext* ctx, const Value& in);
Value _msb_s(SPUContext* ctx, const Value& in);
Value _msb_v(SPUContext* ctx, const Value& in);
Expand Down
9 changes: 1 addition & 8 deletions libspu/kernel/hal/ring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ Value _cast_type(SPUContext* ctx, const Value& x, const Type& to) {
SPU_THROW("unsupport unary op={} for {}", #Name, in); \
} \
}

IMPL_UNARY_OP(_not)
IMPL_UNARY_OP(_negate)
IMPL_UNARY_OP(_msb)
IMPL_UNARY_OP(_square)

Expand Down Expand Up @@ -438,13 +438,6 @@ Value _equal(SPUContext* ctx, const Value& x, const Value& y) {
_xor(ctx, _less(ctx, y, x), _k1));
}

Value _negate(SPUContext* ctx, const Value& x) {
SPU_TRACE_HAL_LEAF(ctx, x);

// negate(x) = not(x) + 1
return _add(ctx, _not(ctx, x), _constant(ctx, 1, x.shape()));
}

Value _sign(SPUContext* ctx, const Value& x) {
SPU_TRACE_HAL_LEAF(ctx, x);

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/ab_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Value rand_b(SPUContext* ctx, const Shape& shape) {
FORCE_DISPATCH(ctx, shape);
}

Value not_a(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }
Value negate_a(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }

Value add_ap(SPUContext* ctx, const Value& x, const Value& y) {
FORCE_DISPATCH(ctx, x, y);
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/ab_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Value msb_a2b(SPUContext* ctx, const Value& x);
Value rand_a(SPUContext* ctx, const Shape& shape);
Value rand_b(SPUContext* ctx, const Shape& shape);

Value not_a(SPUContext* ctx, const Value& x);
Value negate_a(SPUContext* ctx, const Value& x);

Value equal_ap(SPUContext* ctx, const Value& x, const Value& y);
Value equal_aa(SPUContext* ctx, const Value& x, const Value& y);
Expand Down
8 changes: 4 additions & 4 deletions libspu/mpc/ab_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ TEST_P(ArithmeticTest, MatMulAV) {
});
}

TEST_P(ArithmeticTest, NotA) {
TEST_P(ArithmeticTest, NegateA) {
const auto factory = std::get<0>(GetParam());
const RuntimeConfig& conf = std::get<1>(GetParam());
const size_t npc = std::get<2>(GetParam());
Expand All @@ -450,15 +450,15 @@ TEST_P(ArithmeticTest, NotA) {

/* WHEN */
auto prev = obj->prot()->getState<Communicator>()->getStats();
auto r_a = not_a(obj.get(), a0);
auto r_a = negate_a(obj.get(), a0);
auto cost = obj->prot()->getState<Communicator>()->getStats() - prev;

auto r_p = a2p(obj.get(), r_a);
auto r_pp = a2p(obj.get(), not_a(obj.get(), a0));
auto r_pp = a2p(obj.get(), negate_a(obj.get(), a0));

/* THEN */
EXPECT_VALUE_EQ(r_p, r_pp);
EXPECT_TRUE(verifyCost(obj->prot()->getKernel("not_a"), "not_a",
EXPECT_TRUE(verifyCost(obj->prot()->getKernel("negate_a"), "negate_a",
conf.field(), kShape, npc, cost));
});
}
Expand Down
12 changes: 1 addition & 11 deletions libspu/mpc/aby3/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,10 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
});
}

NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto* comm = ctx->getState<Communicator>();
NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
const auto* in_ty = in.eltype().as<AShrTy>();
const auto field = in_ty->field();

auto rank = comm->getRank();

return DISPATCH_ALL_FIELDS(field, [&]() {
using el_t = std::make_unsigned_t<ring2k_t>;
using shr_t = std::array<el_t, 2>;
Expand All @@ -315,16 +312,9 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
NdArrayView<shr_t> _out(out);
NdArrayView<shr_t> _in(in);

// neg(x) = not(x) + 1
// not(x) = neg(x) - 1
pforeach(0, in.numel(), [&](int64_t idx) {
_out[idx][0] = -_in[idx][0];
_out[idx][1] = -_in[idx][1];
if (rank == 0) {
_out[idx][1] -= 1;
} else if (rank == 1) {
_out[idx][0] -= 1;
}
});

return out;
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/aby3/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ class RandA : public RandKernel {
NdArrayRef proc(KernelEvalContext* ctx, const Shape& shape) const override;
};

class NotA : public UnaryKernel {
class NegateA : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_a";
static constexpr char kBindName[] = "negate_a";

ce::CExpr latency() const override { return ce::Const(0); }

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/aby3/protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void regAby3Protocol(SPUContext* ctx,
aby3::B2P, aby3::P2B, aby3::A2B, // Conversion2
aby3::B2ASelector, /*aby3::B2AByOT, aby3::B2AByPPA*/ // B2A
aby3::CastTypeB, // Cast
aby3::NotA, // Not
aby3::NegateA, // Negate
aby3::AddAP, aby3::AddAA, // Add
aby3::MulAP, aby3::MulAA, aby3::MulA1B, // Mul
aby3::MatMulAP, aby3::MatMulAA, // MatMul
Expand Down
40 changes: 30 additions & 10 deletions libspu/mpc/api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,22 +273,42 @@ Value rand_s(SPUContext* ctx, const Shape& shape) {
return rand_a(ctx, shape);
}

// only works for Z2k.
// Neg(x) = Not(x) + 1
// Not(x) = Neg(x) - 1
Value not_v(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
auto k1 = make_p(ctx, 1, x.shape());
return add_vp(ctx, negate_v(ctx, x), negate_p(ctx, k1));
}

Value not_p(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
auto k1 = make_p(ctx, 1, x.shape());
return add_pp(ctx, negate_p(ctx, x), negate_p(ctx, k1));
}

Value not_s(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
if (x.storage_type().isa<BShare>()) {
auto ones = make_p(ctx, -1, x.shape());
return xor_bp(ctx, x, ones);
} else {
SPU_ENFORCE(x.storage_type().isa<Secret>());
auto k1 = make_p(ctx, 1, x.shape());
return add_sp(ctx, negate_s(ctx, x), negate_p(ctx, k1));
}
}

Value negate_s(SPUContext* ctx, const Value& x) {
SPU_TRACE_MPC_DISP(ctx, x);
TRY_DISPATCH(ctx, x);
// TODO: Both A&B could handle not(invert).
// if (x.eltype().isa<BShare>()) {
// return not_b(ctx, x);
//} else {
// SPU_ENFORCE(x.eltype().isa<AShare>());
// return not_a(ctx, x);
//}
return not_a(ctx, _2a(ctx, x));
return negate_a(ctx, _2a(ctx, x));
}

Value not_v(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }
Value negate_v(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }

Value not_p(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }
Value negate_p(SPUContext* ctx, const Value& x) { FORCE_DISPATCH(ctx, x); }

//////////////////////////////////////////////////////////////////////////////

Expand Down
7 changes: 6 additions & 1 deletion libspu/mpc/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,16 @@ Value make_p(SPUContext* ctx, uint128_t init, const Shape& shape);
Value rand_p(SPUContext* ctx, const Shape& shape);
Value rand_s(SPUContext* ctx, const Shape& shape);

// Compute bitwise_not(invert) of a value in ring 2k space.
// Compute bitwise not of a value.
Value not_p(SPUContext* ctx, const Value& x);
Value not_s(SPUContext* ctx, const Value& x);
Value not_v(SPUContext* ctx, const Value& x);

// Compute negate of a value.
Value negate_p(SPUContext* ctx, const Value& x);
Value negate_s(SPUContext* ctx, const Value& x);
Value negate_v(SPUContext* ctx, const Value& x);

Value msb_p(SPUContext* ctx, const Value& x);
Value msb_s(SPUContext* ctx, const Value& x);
Value msb_v(SPUContext* ctx, const Value& x);
Expand Down
1 change: 1 addition & 0 deletions libspu/mpc/api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ TEST_BINARY_OP(xor)
TEST_UNARY_OP_V(OP) \
TEST_UNARY_OP_P(OP)

TEST_UNARY_OP(negate)
TEST_UNARY_OP(not )
TEST_UNARY_OP_V(msb)
TEST_UNARY_OP_P(msb)
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/cheetah/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class V2A : public UnaryKernel {
NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override;
};

class NotA : public UnaryKernel {
class NegateA : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_a";
static constexpr char kBindName[] = "negate_a";

ce::CExpr latency() const override { return ce::Const(0); }

Expand Down
7 changes: 1 addition & 6 deletions libspu/mpc/cheetah/arithmetic_semi2k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,8 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
return x.as(makeType<AShrTy>(field));
}

NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto* comm = ctx->getState<Communicator>();
NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto res = ring_neg(in);
if (comm->getRank() == 0) {
const auto field = in.eltype().as<Ring2k>()->field();
ring_add_(res, ring_not(ring_zeros(field, in.shape())));
}

return res.as(in.eltype());
}
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void regCheetahProtocol(SPUContext* ctx,
ctx->prot()
->regKernel<cheetah::P2A, cheetah::A2P, cheetah::V2A, cheetah::A2V, //
cheetah::B2P, cheetah::P2B, cheetah::A2B, cheetah::B2A, //
cheetah::NotA, //
cheetah::NegateA, //
cheetah::AddAP, cheetah::AddAA, //
cheetah::MulAP, cheetah::MulAA, cheetah::MulAV, //
cheetah::SquareA, //
Expand Down
14 changes: 7 additions & 7 deletions libspu/mpc/common/pv2k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,31 +136,31 @@ class RandP : public RandKernel {
}
};

class NotP : public UnaryKernel {
class NegateP : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_p";
static constexpr char kBindName[] = "negate_p";

ce::CExpr latency() const override { return ce::Const(0); }

ce::CExpr comm() const override { return ce::Const(0); }

NdArrayRef proc(KernelEvalContext*, const NdArrayRef& in) const override {
const auto field = in.eltype().as<Ring2k>()->field();
return ring_not(in).as(makeType<Pub2kTy>(field));
return ring_neg(in).as(makeType<Pub2kTy>(field));
}
};

class NotV : public UnaryKernel {
class NegateV : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_v";
static constexpr char kBindName[] = "negate_v";

ce::CExpr latency() const override { return ce::Const(0); }

ce::CExpr comm() const override { return ce::Const(0); }

NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override {
if (isOwner(ctx, in.eltype())) {
return ring_not(in).as(in.eltype());
return ring_neg(in).as(in.eltype());
} else {
return in;
}
Expand Down Expand Up @@ -954,7 +954,7 @@ void regPV2kTypes() {
void regPV2kKernels(Object* obj) {
obj->regKernel<V2P, P2V, //
MakeP, RandP, //
NotV, NotP, //
NegateV, NegateP, //
EqualVVV, EqualVP, EqualPP, //
AddVVV, AddVP, AddPP, //
MulVVV, MulVP, MulPP, //
Expand Down
8 changes: 4 additions & 4 deletions libspu/mpc/ref2k/ref2k.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,17 +200,17 @@ class Ref2kRandS : public RandKernel {
}
};

class Ref2kNotS : public UnaryKernel {
class Ref2kNegateS : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_s";
static constexpr char kBindName[] = "negate_s";

ce::CExpr latency() const override { return ce::Const(0); }

ce::CExpr comm() const override { return ce::Const(0); }

NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override {
const auto field = in.eltype().as<Ring2k>()->field();
return ring_not(in).as(makeType<Ref2kSecrTy>(field));
return ring_neg(in).as(makeType<Ref2kSecrTy>(field));
}
};

Expand Down Expand Up @@ -488,7 +488,7 @@ void regRef2kProtocol(SPUContext* ctx,
ctx->prot()
->regKernel<Ref2kCommonTypeS, Ref2kCommonTypeV, Ref2kCastTypeS, //
Ref2kP2S, Ref2kS2P, Ref2kV2S, Ref2kS2V, //
Ref2kNotS, //
Ref2kNegateS, //
Ref2kAddSS, Ref2kAddSP, //
Ref2kMulSS, Ref2kMulSP, //
Ref2kMatMulSS, Ref2kMatMulSP, //
Expand Down
24 changes: 1 addition & 23 deletions libspu/mpc/securenn/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,30 +132,8 @@ NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
return out.as(makeType<Pub2kTy>(field));
}

NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto* comm = ctx->getState<Communicator>();

// First, let's show negate could be locally processed.
// let X = sum(Xi) % M
// let Yi = neg(Xi) = M-Xi
//
// we get
// Y = sum(Yi) % M
// = n*M - sum(Xi) % M
// = -sum(Xi) % M
// = -X % M
//
// 'not' could be processed accordingly.
// not(X)
// = M-1-X # by definition, not is the complement of 2^k
// = neg(X) + M-1
//
NdArrayRef NegateA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const {
auto res = ring_neg(in);
if (comm->getRank() == 0) {
const auto field = in.eltype().as<Ring2k>()->field();
ring_add_(res, ring_not(ring_zeros(field, in.shape())));
}

return res.as(in.eltype());
}

Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/securenn/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ class A2P : public UnaryKernel {
NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override;
};

class NotA : public UnaryKernel {
class NegateA : public UnaryKernel {
public:
static constexpr char kBindName[] = "not_a";
static constexpr char kBindName[] = "negate_a";

ce::CExpr latency() const override { return ce::Const(0); }

Expand Down
Loading
Loading