Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

repo-sync-2024-09-24T14:07:49+0800 #866

Merged
merged 3 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b5_nightly_20240919.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b6_nightly_20240923.tar.gz",
],
strip_prefix = "yacl-0.4.5b5_nightly_20240919",
sha256 = "0ef295f6878dce6160fd44e6af59fa369099f736fa8d4a10f9685dda66aefa71",
strip_prefix = "yacl-0.4.5b6_nightly_20240923",
sha256 = "14eaaf7ad4aead7f2244e56453fead4a47973a020e23739ca0fe93873866bb5f",
)

def _libpsi():
Expand Down
5 changes: 4 additions & 1 deletion libspu/compiler/tools/spu-translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
#include "libspu/compiler/core/core.h"
#undef EXPOSE_PIPELINE_BUILDER

template <typename T>
struct fmt::formatter<xt::xarray<T>> : ostream_formatter {};

llvm::cl::opt<uint32_t> ProtocolKind(
"protocol_kind", llvm::cl::init(1),
llvm::cl::desc("1 for REF2k, 2 for SEMI2k, 3 for ABY3, 4 for Cheetah"));
Expand Down Expand Up @@ -72,7 +75,7 @@ void isEqual(const xt::xarray<T> &lhs, const xt::xarray<T> &rhs) {

auto error = lhs - rhs;

for (auto v : error) {
for (T v : error) {
if (v != 0) {
llvm::report_fatal_error(fmt::format("Diff = {}", v).c_str());
}
Expand Down
1 change: 1 addition & 0 deletions libspu/core/ndarray_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "absl/types/span.h"
#include "fmt/ostream.h"
#include "fmt/ranges.h"
#include "yacl/base/buffer.h"

#include "libspu/core/bit_utils.h"
Expand Down
1 change: 0 additions & 1 deletion libspu/core/trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <vector>

#include "absl/types/span.h"
#include "fmt/format.h"
#include "fmt/ranges.h"
#include "spdlog/spdlog.h"
#include "yacl/link/context.h"
Expand Down
8 changes: 7 additions & 1 deletion libspu/dialect/pphlo/IR/type_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,14 @@ LogicalResult inferDynamicUpdateSliceOp(
}

// dynamic_update_slice_c1
TypeTools tools(operand.getContext());
auto common_vis =
tools.computeCommonVisibility({tools.getTypeVisibility(operandType),
tools.getTypeVisibility(updateType)});

inferredReturnTypes.emplace_back(RankedTensorType::get(
operandType.getShape(), operandType.getElementType()));
operandType.getShape(),
tools.getType(operandType.getElementType(), common_vis)));
return success();
}

Expand Down
4 changes: 4 additions & 0 deletions libspu/mpc/common/prg_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@ inline NdArrayRef prgReplayArray(PrgSeed seed, const PrgArrayDesc& desc) {
return ring_rand(desc.field, desc.shape, seed, &counter);
}

inline NdArrayRef prgReplayArrayMutable(PrgSeed seed, PrgArrayDesc& desc) {
return ring_rand(desc.field, desc.shape, seed, &desc.prg_counter);
}

} // namespace spu::mpc
23 changes: 12 additions & 11 deletions libspu/mpc/semi2k/beaver/beaver_impl/beaver_tfp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Mul(FieldType field, int64_t size,

if (lctx_->Rank() == 0) {
ops[2].seeds = seeds_;
auto adjust = TrustedParty::adjustMul(ops);
auto adjust = TrustedParty::adjustMul(absl::MakeSpan(ops));
ring_add_(c, adjust);
}

Expand Down Expand Up @@ -158,7 +158,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Square(FieldType field, int64_t size,

if (lctx_->Rank() == 0) {
ops[1].seeds = seeds_;
auto adjust = TrustedParty::adjustSquare(ops);
auto adjust = TrustedParty::adjustSquare(absl::MakeSpan(ops));
ring_add_(b, adjust);
}

Expand Down Expand Up @@ -223,7 +223,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::Dot(FieldType field, int64_t m,

if (lctx_->Rank() == 0) {
ops[2].seeds = seeds_;
auto adjust = TrustedParty::adjustDot(ops);
auto adjust = TrustedParty::adjustDot(absl::MakeSpan(ops));
ring_add_(c, adjust);
}

Expand All @@ -250,7 +250,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::And(int64_t size) {
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustAnd(ops);
auto adjust = TrustedParty::adjustAnd(absl::MakeSpan(ops));
ring_xor_(c, adjust);
}

Expand All @@ -276,7 +276,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Trunc(FieldType field, int64_t size,
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustTrunc(ops, bits);
auto adjust = TrustedParty::adjustTrunc(absl::MakeSpan(ops), bits);
ring_add_(b, adjust);
}

Expand All @@ -300,7 +300,7 @@ BeaverTfpUnsafe::Triple BeaverTfpUnsafe::TruncPr(FieldType field, int64_t size,
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjusts = TrustedParty::adjustTruncPr(ops, bits);
auto adjusts = TrustedParty::adjustTruncPr(absl::MakeSpan(ops), bits);
ring_add_(rc, std::get<0>(adjusts));
ring_add_(rb, std::get<1>(adjusts));
}
Expand All @@ -322,7 +322,7 @@ BeaverTfpUnsafe::Array BeaverTfpUnsafe::RandBit(FieldType field, int64_t size) {
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustRandBit(ops);
auto adjust = TrustedParty::adjustRandBit(absl::MakeSpan(ops));
ring_add_(a, adjust);
}

Expand All @@ -348,10 +348,11 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::PermPair(
auto pv_buf = lctx_->Recv(perm_rank, kTag);

ring_add_(b, TrustedParty::adjustPerm(
ops, absl::MakeSpan(pv_buf.data<int64_t>(),
pv_buf.size() / sizeof(int64_t))));
absl::MakeSpan(ops),
absl::MakeSpan(pv_buf.data<int64_t>(),
pv_buf.size() / sizeof(int64_t))));
} else {
ring_add_(b, TrustedParty::adjustPerm(ops, perm_vec));
ring_add_(b, TrustedParty::adjustPerm(absl::MakeSpan(ops), perm_vec));
}
} else if (perm_rank == lctx_->Rank()) {
lctx_->SendAsync(
Expand Down Expand Up @@ -380,7 +381,7 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::Eqz(FieldType field, int64_t size) {
for (auto& op : ops) {
op.seeds = seeds_;
}
auto adjust = TrustedParty::adjustEqz(ops);
auto adjust = TrustedParty::adjustEqz(absl::MakeSpan(ops));
ring_xor_(b, adjust);
}

Expand Down
132 changes: 73 additions & 59 deletions libspu/mpc/semi2k/beaver/beaver_impl/beaver_ttp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,53 +98,48 @@ class StreamReader : public brpc::StreamInputHandler {
kStreamFailed,
};

StreamReader() {
StreamReader(int32_t num_buf, size_t buf_len) {
SPU_ENFORCE(num_buf > 0);
SPU_ENFORCE(buf_len > 0);
buf_vec_.resize(num_buf);
buf_len_ = buf_len;
future_finished_ = promise_finished_.get_future();
future_closed_ = promise_closed_.get_future();
}

int on_received_messages(brpc::StreamId id, butil::IOBuf* const messages[],
size_t size) override {
SPDLOG_DEBUG("on_received_messages, stream id: {}", id);
if (status_ != Status::kNotFinished) {
SPDLOG_ERROR("unexpected messages received");
return -1;
}

for (size_t i = 0; i < size; ++i) {
if (status_ != Status::kNotFinished) {
SPDLOG_ERROR("unexpected messages received");
return -1;
}

SPDLOG_DEBUG("receive buf size: {}", messages[i]->size());
const auto& message = messages[i];
if (!buf_lens_.has_value()) {
beaver::ttp_server::BeaverDownStreamMeta meta{};
message->copy_to(&meta, sizeof(meta));
message->pop_front(sizeof(meta));
if (meta.err_code != 0) {
SPDLOG_ERROR("response error from server, err_code: {}, err_text: {}",
meta.err_code, message->to_string());
status_ = Status::kAbnormalFinished;
promise_finished_.set_value(status_);
return -2;
}
SPU_ENFORCE(meta.total_buf_num > 0);
buf_.emplace_back();
buf_lens_.emplace(meta.total_buf_num);
size_t meta_bytes = meta.total_buf_num * sizeof(uint64_t);
SPU_ENFORCE(message->length() >= meta_bytes);
message->copy_to(buf_lens_.value().data(), meta_bytes);
message->pop_front(meta_bytes);
beaver::ttp_server::BeaverDownStreamMeta meta;
message->copy_to(&meta, sizeof(meta));
message->pop_front(sizeof(meta));
if (meta.err_code != 0) {
SPDLOG_ERROR("response error from server, err_code: {}, err_text: {}",
meta.err_code, message->to_string());
status_ = Status::kAbnormalFinished;
promise_finished_.set_value(status_);
return -2;
}

size_t cur_buf_idx = buf_.size() - 1;
size_t cur_buf_size = buf_lens_.value().at(cur_buf_idx);
buf_.back().append(message->movable());
SPU_ENFORCE(buf_.back().length() <= cur_buf_size);
if (buf_.back().length() == cur_buf_size) {
if (cur_buf_idx == buf_lens_.value().size() - 1) {
status_ = Status::kNormalFinished;
promise_finished_.set_value(status_);
} else {
buf_.emplace_back();
}
SPU_ENFORCE(message->length() % buf_vec_.size() == 0);
size_t msg_len = message->length() / buf_vec_.size();
for (size_t buf_idx = 0; buf_idx < buf_vec_.size(); ++buf_idx) {
message->append_to(&buf_vec_[buf_idx], msg_len, buf_idx * msg_len);
}

SPU_ENFORCE(buf_vec_[0].length() <= buf_len_,
"unexpected bytes received");
if (buf_vec_[0].length() == buf_len_) {
status_ = Status::kNormalFinished;
promise_finished_.set_value(status_);
}
}
return 0;
Expand All @@ -169,23 +164,41 @@ class StreamReader : public brpc::StreamInputHandler {

const auto& GetBufVecRef() const {
SPU_ENFORCE(status_ == Status::kNormalFinished);
return buf_;
return buf_vec_;
}

Status WaitFinished() { return future_finished_.get(); };

void WaitClosed() { future_closed_.wait(); }

private:
std::vector<butil::IOBuf> buf_;
std::optional<std::vector<uint64_t>> buf_lens_;
std::vector<butil::IOBuf> buf_vec_;
size_t buf_len_;
Status status_ = Status::kNotFinished;
std::promise<Status> promise_finished_;
std::promise<void> promise_closed_;
std::future<Status> future_finished_;
std::future<void> future_closed_;
};

// Obtain a tuple containing num_buf and buf_len
template <class AdjustRequest>
std::tuple<int32_t, int64_t> GetBufferLength(const AdjustRequest& req) {
if constexpr (std::is_same_v<AdjustRequest,
beaver::ttp_server::AdjustDotRequest>) {
SPU_ENFORCE_EQ(req.prg_inputs().size(), 3);
return {1, req.prg_inputs()[2].buffer_len()};
} else if constexpr (std::is_same_v<
AdjustRequest,
beaver::ttp_server::AdjustTruncPrRequest>) {
SPU_ENFORCE_GE(req.prg_inputs().size(), 1);
return {2, req.prg_inputs()[0].buffer_len()};
} else {
SPU_ENFORCE_GE(req.prg_inputs().size(), 1);
return {1, req.prg_inputs()[0].buffer_len()};
}
}

template <class AdjustRequest>
std::vector<NdArrayRef> RpcCall(
brpc::Channel& channel, AdjustRequest req, FieldType ret_field,
Expand All @@ -194,9 +207,10 @@ std::vector<NdArrayRef> RpcCall(
beaver::ttp_server::BeaverService::Stub stub(&channel);
beaver::ttp_server::AdjustResponse rsp;

StreamReader reader;
auto [num_buf, buf_len] = GetBufferLength(req);
StreamReader reader(num_buf, buf_len);
brpc::StreamOptions stream_options;
stream_options.max_buf_size = 0;
stream_options.max_buf_size = 2 * beaver::ttp_server::kUpStreamChunkSize;
stream_options.handler = &reader;
brpc::StreamId stream_id;
SPU_ENFORCE_EQ(brpc::StreamCreate(&stream_id, cntl, &stream_options), 0,
Expand All @@ -206,14 +220,6 @@ std::vector<NdArrayRef> RpcCall(
reader.WaitClosed();
});

if (upstream_messages != nullptr) {
for (const auto& message : *upstream_messages) {
SPU_ENFORCE_EQ(brpc::StreamWrite(stream_id, message), 0);
SPDLOG_DEBUG("write buf size {} to stream id {}", message.length(),
stream_id);
}
}

if constexpr (std::is_same_v<AdjustRequest,
beaver::ttp_server::AdjustMulRequest>) {
stub.AdjustMul(&cntl, &req, &rsp, nullptr);
Expand Down Expand Up @@ -255,6 +261,19 @@ std::vector<NdArrayRef> RpcCall(
"Adjust server failed code={}, error={}",
ErrorCode_Name(rsp.code()), rsp.message());

if (upstream_messages != nullptr) {
for (const auto& message : *upstream_messages) {
int ret = brpc::StreamWrite(stream_id, message);
if (ret == EAGAIN) {
SPU_ENFORCE_EQ(brpc::StreamWait(stream_id, nullptr), 0);
ret = brpc::StreamWrite(stream_id, message);
}
SPU_ENFORCE_EQ(ret, 0, "Write stream failed");
SPDLOG_DEBUG("write buf size {} to stream id {}", message.length(),
stream_id);
}
}

auto status = reader.WaitFinished();
SPU_ENFORCE(status == StreamReader::Status::kNormalFinished,
"Stream reader finished abnormally, status: {}",
Expand Down Expand Up @@ -590,25 +609,20 @@ BeaverTtp::Pair BeaverTtp::PermPair(FieldType field, int64_t size,
if (lctx_->Rank() == perm_rank) {
auto req = BuildAdjustRequest<beaver::ttp_server::AdjustPermRequest>(
descs, descs_seed);
std::vector<butil::IOBuf> buf_vec;
beaver::ttp_server::BeaverPermUpStreamMeta meta{};
meta.total_buf_size = perm_vec.size() * sizeof(int64_t);
std::vector<butil::IOBuf> stream_data;
size_t left_buf_size = perm_vec.size() * sizeof(int64_t);
size_t chunk_idx = 0;
while (left_buf_size > 0) {
using beaver::ttp_server::kUpStreamChunkSize;
size_t cur_chunk_size = std::min(left_buf_size, kUpStreamChunkSize);
buf_vec.emplace_back();
if (chunk_idx == 0) {
buf_vec.back().append(&meta, sizeof(meta));
}
buf_vec.back().append(reinterpret_cast<const char*>(perm_vec.data()) +
(chunk_idx * kUpStreamChunkSize),
cur_chunk_size);
stream_data.emplace_back();
stream_data.back().append(reinterpret_cast<const char*>(perm_vec.data()) +
(chunk_idx * kUpStreamChunkSize),
cur_chunk_size);
++chunk_idx;
left_buf_size -= cur_chunk_size;
}
auto adjusts = RpcCall(channel_, req, field, &buf_vec);
auto adjusts = RpcCall(channel_, req, field, &stream_data);
SPU_ENFORCE_EQ(adjusts.size(), 1U);
ring_add_(b, adjusts[0].reshape(b.shape()));
}
Expand Down
Loading
Loading