Repo sync #616

Merged · 1 commit · Mar 20, 2024
6 changes: 6 additions & 0 deletions libspu/kernel/hal/fxp_cleartext.cc
@@ -145,4 +145,10 @@ Value f_erf_p(SPUContext* ctx, const Value& in) {
return applyFloatingPointFn(ctx, in, [](float x) { return std::erf(x); });
}

Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y) {
SPU_TRACE_HAL_DISP(ctx, x, y);
return applyFloatingPointFn(ctx, x, y,
[](float a, float b) { return std::pow(a, b); });
}

} // namespace spu::kernel::hal
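f_pow_p mirrors the other cleartext kernels in this file: applyFloatingPointFn decodes the fixed-point operands to floats, applies the lambda elementwise, and re-encodes the result. A minimal standalone sketch of that elementwise core, with plain std::vector standing in for Value (the decode/encode steps are assumed handled by the caller):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only: the elementwise work performed by the lambda that
// f_pow_p passes to applyFloatingPointFn.
std::vector<float> pow_elementwise(const std::vector<float>& x,
                                   const std::vector<float>& y) {
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::pow(x[i], y[i]);
  }
  return out;
}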
2 changes: 2 additions & 0 deletions libspu/kernel/hal/fxp_cleartext.h
@@ -40,4 +40,6 @@ Value f_cosine_p(SPUContext* ctx, const Value& in);

Value f_erf_p(SPUContext* ctx, const Value& in);

Value f_pow_p(SPUContext* ctx, const Value& x, const Value& y);

} // namespace spu::kernel::hal
28 changes: 25 additions & 3 deletions libspu/kernel/hal/polymorphic.cc
@@ -19,6 +19,7 @@
#include "libspu/core/trace.h"
#include "libspu/kernel/hal/fxp_approx.h"
#include "libspu/kernel/hal/fxp_base.h"
#include "libspu/kernel/hal/fxp_cleartext.h"
#include "libspu/kernel/hal/integer.h"
#include "libspu/kernel/hal/ring.h" // for fast fxp x int
#include "libspu/kernel/hal/type_cast.h"
@@ -329,15 +330,36 @@ Value min(SPUContext* ctx, const Value& x, const Value& y) {
Value power(SPUContext* ctx, const Value& x, const Value& y) {
SPU_TRACE_HAL_DISP(ctx, x, y);

if (x.isInt() && y.isInt()) {
if (x.isInt() || y.isInt()) {
auto x_f = dtype_cast(ctx, x, DT_F32);
auto y_f = dtype_cast(ctx, y, DT_F32);
auto ret = power(ctx, x_f, y_f);
return dtype_cast(ctx, ret, x.dtype());
return ret;
}
if (x.isPublic() && y.isPublic()) {
return f_pow_p(ctx, x, y);
}

auto msb = _msb(ctx, x);
auto msb_a = _prefer_a(ctx, msb);
auto x_abs = _mux(ctx, msb_a, _negate(ctx, x), x).setDtype(x.dtype());

// if x=0 is public, log(x) evaluates to -inf, and the output stays wrong
// after multiplying by y. So we force x to be secret: computing log(x) then
// yields a negative number (rather than -inf), so exp(y*log(x))=0.
auto x_s = x.isPublic() ? hal::seal(ctx, x_abs) : x_abs;
// x^y = e^(y*ln(x))
return exp(ctx, mul(ctx, y, log(ctx, x)));
// the precision depends heavily on the precision of exp and log, so we
// choose the most precise methods available here.
auto val = detail::exp_pade(ctx, mul(ctx, y, detail::log_minmax(ctx, x_s)));

// the final sign is determined by both the sign of x and the parity of y:
// when x<0 and y is odd, the result is negative, e.g. (-2)^3 = -8
auto odd = _and(ctx, _rshift(ctx, y, ctx->getFxpBits()),
_constant(ctx, 1, y.shape()));
auto sign = _and(ctx, msb, odd);

return _mux(ctx, sign, _negate(ctx, val), val).setDtype(x.dtype());
}

Value idiv(SPUContext* ctx, const Value& x, const Value& y) {
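The secret path introduced above reads as plain math: take |x|, compute exp(y*log|x|) with the most precise approximations available, then restore the sign from the MSB of x and the parity of y's integer part. A plain-C++ reference sketch of that flow (names hypothetical; fxp_bits stands in for ctx->getFxpBits(), y_fxp for y's raw fixed-point representation):

#include <cmath>
#include <cstdint>

double power_ref(double x, double y, int fxp_bits, int64_t y_fxp) {
  double x_abs = std::fabs(x);                 // _mux(msb_a, -x, x)
  double val = std::exp(y * std::log(x_abs));  // exp_pade(y * log_minmax(|x|))
  // (y_fxp >> fxp_bits) & 1: low bit of y's integer part, i.e. its parity
  bool y_odd = ((y_fxp >> fxp_bits) & 1) != 0;
  bool x_neg = x < 0.0;                        // _msb(ctx, x)
  return (x_neg && y_odd) ? -val : val;        // _mux(sign, -val, val)
}

Note the parity trick only recovers x^y for integer-valued y when x<0; for fractional y and negative x the true result is not real, so this is the sensible convention.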
46 changes: 31 additions & 15 deletions libspu/kernel/hal/polymorphic_test.cc
@@ -406,26 +406,42 @@ TYPED_TEST(MathTest, Pow) {
using LHS_VT = typename std::tuple_element<1, TypeParam>::type;
using RHS_DT = typename std::tuple_element<2, TypeParam>::type;
using RHS_VT = typename std::tuple_element<3, TypeParam>::type;
using RES_DT = typename std::tuple_element<4, TypeParam>::type;
// using RES_DT = typename std::tuple_element<4, TypeParam>::type;

if constexpr (!std::is_same_v<LHS_DT, RHS_DT> ||
!std::is_same_v<LHS_VT, RHS_VT> || std::is_integral_v<RHS_DT>) {
return;
// GIVEN
xt::xarray<LHS_DT> x;
xt::xarray<RHS_DT> y;
{
// random test
x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
y = test::xt_random<RHS_DT>({5, 6}, -2, 2);

// WHAT
auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);

// THEN
auto expected = xt::pow(x, y);
EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
<< y << std::endl
<< expected << std::endl
<< z << std::endl;
}

// GIVEN
const xt::xarray<LHS_DT> x = test::xt_random<LHS_DT>({5, 6}, 0, 100);
const xt::xarray<RHS_DT> y = test::xt_random<RHS_DT>({5, 6}, 0, 2);
{
// some fixed corner cases
x = {-1, -1, -3, 1, -3, 0, 1, 1, 5, 0};
y = {1, 0, -3, -3, 3, 0, 0, 2, 5, 2};

// WHAT
auto z = test::evalBinaryOp<RES_DT>(LHS_VT(), RHS_VT(), power, x, y);
// WHAT
auto z = test::evalBinaryOp<float>(LHS_VT(), RHS_VT(), power, x, y);

// THEN
auto expected = xt::pow(x, y);
EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
<< y << std::endl
<< expected << std::endl
<< z << std::endl;
// THEN
auto expected = xt::pow(x, y);
EXPECT_TRUE(xt::allclose(expected, z, 0.3, 0.03)) << x << std::endl
<< y << std::endl
<< expected << std::endl
<< z << std::endl;
}
}

using MathUnaryTestTypes = ::testing::Types<
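The test gates on xt::allclose(expected, z, 0.3, 0.03), i.e. a relative tolerance of 0.3 and an absolute tolerance of 0.03, which is loose because the secret path goes through fixed-point exp/log approximations rather than IEEE std::pow. Assuming xtensor follows the numpy-style closeness rule, the per-element predicate is:

#include <cmath>

// Sketch of the check assumed behind xt::allclose(a, b, rtol, atol):
// |a - b| <= atol + rtol * |b|, applied elementwise.
bool close_one(double a, double b, double rtol = 0.3, double atol = 0.03) {
  return std::fabs(a - b) <= atol + rtol * std::fabs(b);
}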
163 changes: 41 additions & 122 deletions libspu/mpc/cheetah/state.cc
@@ -28,7 +28,7 @@ namespace spu::mpc::cheetah {
namespace {
// Return num_workers for the given size of jobs
size_t InitOTState(KernelEvalContext* ctx, size_t njobs) {
constexpr size_t kMinWorkSize = 5000;
constexpr size_t kMinWorkSize = 2048;
if (njobs == 0) {
return 0;
}
@@ -139,86 +139,44 @@ std::array<NdArrayRef, 3> CheetahMulState::TakeCachedBeaver(FieldType field,

NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
OTUnaryFunc func) {
Shape shape = x.shape();
const Shape& shape = x.shape();
SPU_ENFORCE(shape.numel() > 0);
// (lazy) init OT
int64_t numel = x.numel();
int64_t nworker = InitOTState(ctx, numel);
int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);

int64_t slicing_dim = -1;
int64_t slice_numel = 1;
for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
slice_numel *= shape[dim];
if (slice_numel > workload) {
slice_numel /= shape[dim];
slicing_dim = dim;
break;
}
}

// get the slice num in the left outer dimensions
int64_t num_slice = 1;
for (int64_t dim = 0; dim < slicing_dim; dim++) {
num_slice *= shape[dim];
}

int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
if (slice_stride == 1) {
return func(x, ctx->getState<CheetahOTState>()->get(0));
}

int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);

// initialize slice indices
Index start_indices(shape.size());
Index end_indices(shape.begin(), shape.end());
end_indices[slicing_dim] = slice_stride;
for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
end_indices[dim] = 1;
if (shape.ndim() != 1) {
// TiledDispatchOTFunc over flattened input
return TiledDispatchOTFunc(ctx, x.reshape({numel}), func)
.reshape(x.shape());
}

SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
nworker = num_slice * num_slice_dim;

std::vector<NdArrayRef> outs(nworker);
std::vector<std::future<void>> futures;

Index sidx = start_indices;
Index eidx = end_indices;
for (int64_t wi = 0; wi < nworker; ++wi) {
auto slice_input = x.slice(sidx, eidx, {});
int64_t slice_end = 0;
for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
int64_t slice_bgn = wi * workload;
slice_end = std::min(numel, slice_bgn + workload);
auto slice_input = x.slice({slice_bgn}, {slice_end}, {});
futures.emplace_back(std::async(
[&](int64_t idx, const NdArrayRef& input) {
auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
outs[idx] = func(input, ot_instance);
},
wi, slice_input));

// update indices
if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
// carry out
sidx[slicing_dim] = 0;
eidx[slicing_dim] = slice_stride;
for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
sidx[dim] = (sidx[dim] + 1) % shape[dim];
eidx[dim] = eidx[dim] % shape[dim] + 1;
if (eidx[dim] != 1) {
break;
}
}
} else {
sidx[slicing_dim] += slice_stride;
eidx[slicing_dim] += slice_stride;
eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
}
}

auto slice_input = x.slice({slice_end}, {numel}, {1});
auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
outs[nworker - 1] = func(slice_input, ot_instance);

for (auto&& f : futures) {
f.get();
}

NdArrayRef out(x.eltype(), x.shape());
NdArrayRef out(outs[0].eltype(), x.shape());
int64_t offset = 0;

for (auto& out_slice : outs) {
@@ -232,89 +190,50 @@ NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,

NdArrayRef TiledDispatchOTFunc(KernelEvalContext* ctx, const NdArrayRef& x,
const NdArrayRef& y, OTBinaryFunc func) {
Shape shape = x.shape();
SPU_ENFORCE_EQ(x.shape(), y.shape());
const Shape& shape = x.shape();
SPU_ENFORCE(shape.numel() > 0);
SPU_ENFORCE_EQ(shape, y.shape());
// (lazy) init OT
int64_t numel = x.numel();
int64_t nworker = InitOTState(ctx, numel);
int64_t workload = nworker == 0 ? 0 : CeilDiv(numel, nworker);

int64_t slicing_dim = -1;
int64_t slice_numel = 1;
for (int64_t dim = shape.size() - 1; dim >= 0; dim--) {
slice_numel *= shape[dim];
if (slice_numel > workload) {
slice_numel /= shape[dim];
slicing_dim = dim;
break;
}
if (shape.ndim() != 1) {
// TiledDispatchOTFunc over flattened input
return TiledDispatchOTFunc(ctx, x.reshape({numel}), y.reshape({numel}),
func)
.reshape(x.shape());
}

// get the slice num in the left outer dimensions
int64_t num_slice = 1;
for (int64_t dim = 0; dim < slicing_dim; dim++) {
num_slice *= shape[dim];
}

int64_t slice_stride = (workload + slice_numel - 1) / slice_numel;
if (slice_stride == 1) {
return func(x, y, ctx->getState<CheetahOTState>()->get(0));
}

int64_t num_slice_dim = shape[slicing_dim] / slice_stride +
((shape[slicing_dim] % slice_stride) != 0 ? 1 : 0);

// initialize slice indices
Index start_indices(shape.size());
Index end_indices(shape.begin(), shape.end());
end_indices[slicing_dim] = slice_stride;
for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
end_indices[dim] = 1;
}

SPU_ENFORCE_LE(num_slice * num_slice_dim, nworker);
nworker = num_slice * num_slice_dim;

std::vector<NdArrayRef> outs(nworker);
std::vector<std::future<void>> futures;

Index sidx = start_indices;
Index eidx = end_indices;
for (int64_t wi = 0; wi < nworker; ++wi) {
auto x_slice = x.slice(sidx, eidx, {});
auto y_slice = y.slice(sidx, eidx, {});

int64_t slice_end = 0;
for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
int64_t slice_bgn = wi * workload;
slice_end = std::min(numel, slice_bgn + workload);
auto x_slice = x.slice({slice_bgn}, {slice_end}, {1});
auto y_slice = y.slice({slice_bgn}, {slice_end}, {1});
futures.emplace_back(std::async(
[&](int64_t idx, const NdArrayRef& input0, const NdArrayRef& input1) {
[&](int64_t idx, const NdArrayRef& inp0, const NdArrayRef& inp1) {
auto ot_instance = ctx->getState<CheetahOTState>()->get(idx);
outs[idx] = func(input0, input1, ot_instance);
outs[idx] = func(inp0, inp1, ot_instance);
},
wi, x_slice, y_slice));

// update indices
if (0 == (eidx[slicing_dim] % shape[slicing_dim])) {
// carry out
sidx[slicing_dim] = 0;
eidx[slicing_dim] = slice_stride;
for (int64_t dim = slicing_dim - 1; dim >= 0; dim--) {
sidx[dim] = (sidx[dim] + 1) % shape[dim];
eidx[dim] = eidx[dim] % shape[dim] + 1;
if (eidx[dim] != 1) {
break;
}
}
} else {
sidx[slicing_dim] += slice_stride;
eidx[slicing_dim] += slice_stride;
eidx[slicing_dim] = std::min(shape[slicing_dim], eidx[slicing_dim]);
}
}

auto x_slice = x.slice({slice_end}, {numel}, {});
auto y_slice = y.slice({slice_end}, {numel}, {});
auto ot_instance = ctx->getState<CheetahOTState>()->get(nworker - 1);
outs[nworker - 1] = func(x_slice, y_slice, ot_instance);

for (auto&& f : futures) {
f.get();
}

NdArrayRef out(x.eltype(), x.shape());
NdArrayRef out(outs[0].eltype(), x.shape());
int64_t offset = 0;

for (auto& out_slice : outs) {
std::memcpy(out.data<std::byte>() + offset, out_slice.data(),
out_slice.numel() * out.elsize());
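The rewrite drops the multi-dimensional slice bookkeeping in favor of a simpler scheme: flatten to 1-D, hand each worker a contiguous chunk of at most CeilDiv(numel, nworker) elements via std::async, run the final chunk on the calling thread, and stitch the partial outputs back together. A generic sketch of the pattern with plain vectors standing in for NdArrayRef and an arbitrary callable standing in for the OT functor (names hypothetical):

#include <algorithm>
#include <cstdint>
#include <future>
#include <vector>

// Flatten-and-chunk dispatch; assumes nworker >= 1. The final chunk
// runs on the calling thread, mirroring the `wi + 1 < nworker` loop
// bound in the diff.
template <typename Fn>
std::vector<int64_t> tiled_dispatch(const std::vector<int64_t>& x,
                                    int64_t nworker, Fn func) {
  const int64_t numel = static_cast<int64_t>(x.size());
  const int64_t workload = (numel + nworker - 1) / nworker;  // CeilDiv
  std::vector<std::vector<int64_t>> outs(nworker);
  std::vector<std::future<void>> futures;
  int64_t slice_end = 0;
  for (int64_t wi = 0; wi + 1 < nworker; ++wi) {
    // Clamp so a surplus of workers yields empty chunks, not overruns.
    const int64_t bgn = std::min(wi * workload, numel);
    slice_end = std::min(numel, bgn + workload);
    futures.emplace_back(std::async(
        std::launch::async, [&outs, &x, &func, wi, bgn, slice_end] {
          outs[wi] = func(std::vector<int64_t>(x.begin() + bgn,
                                               x.begin() + slice_end));
        }));
  }
  outs[nworker - 1] =
      func(std::vector<int64_t>(x.begin() + slice_end, x.end()));
  for (auto& f : futures) f.get();
  std::vector<int64_t> out;
  out.reserve(numel);
  for (auto& o : outs) out.insert(out.end(), o.begin(), o.end());
  return out;
}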
4 changes: 3 additions & 1 deletion libspu/mpc/cheetah/state.h
@@ -25,6 +25,8 @@
#include "libspu/mpc/cheetah/ot/basic_ot_prot.h"
#include "libspu/mpc/cheetah/rlwe/utils.h"

#include "libspu/spu.pb.h"

namespace spu::mpc::cheetah {

using OTUnaryFunc = std::function<NdArrayRef(
@@ -101,7 +103,7 @@ class CheetahOTState : public State {

mutable std::mutex lock_;

static constexpr size_t kMaxOTParallel = 24;
static constexpr size_t kMaxOTParallel = 48;

size_t maximum_instances_ = 0;
std::vector<ProtPtr> basic_ot_prot_;
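Together with the kMinWorkSize drop from 5000 to 2048 in state.cc, raising kMaxOTParallel from 24 to 48 lets small and medium inputs fan out across more OT instances. The worker-count policy the two constants imply is roughly the following sketch (an assumption about InitOTState's logic, not a verbatim copy):

#include <algorithm>
#include <cstddef>

// At least kMinWorkSize elements per worker, at most kMaxOTParallel workers.
size_t num_ot_workers(size_t njobs) {
  constexpr size_t kMinWorkSize = 2048;  // state.cc, after this PR
  constexpr size_t kMaxOTParallel = 48;  // state.h, after this PR
  if (njobs == 0) return 0;
  const size_t by_size = (njobs + kMinWorkSize - 1) / kMinWorkSize;  // CeilDiv
  return std::min(by_size, kMaxOTParallel);
}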