Skip to content

Commit

Permalink
Repo sync (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Nov 13, 2023
1 parent b0674fd commit 61e10e3
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 114 deletions.
21 changes: 21 additions & 0 deletions bazel/curve25519-donna.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2022 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:defs.bzl", "cc_library")

cc_library(
name = "curve25519_donna",
srcs = ["curve25519.c"],
hdrs = glob(["*.h"]),
visibility = ["//visibility:public"],
)
4 changes: 2 additions & 2 deletions bazel/emp-tool.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@yacl//bazel:yacl.bzl", "yacl_cmake_external")
load("@spulib//bazel:spu.bzl", "spu_cmake_external")

package(default_visibility = ["//visibility:public"])

Expand All @@ -21,7 +21,7 @@ filegroup(
srcs = glob(["**"]),
)

yacl_cmake_external(
spu_cmake_external(
name = "emp-tool",
cache_entries = {
"OPENSSL_ROOT_DIR": "$EXT_BUILD_DEPS/openssl",
Expand Down
22 changes: 0 additions & 22 deletions bazel/patches/xla.patch

This file was deleted.

30 changes: 20 additions & 10 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

SECRETFLOW_GIT = "https://github.com/secretflow"

YACL_COMMIT_ID = "6be4330542e92b6503317c45a999c99e654ced58"
YACL_COMMIT_ID = "0953593df3ca6544442236f2b6d78a5b89035e24"

def spu_deps():
_rules_cuda()
Expand All @@ -43,6 +43,7 @@ def spu_deps():
_com_github_microsoft_kuku()
_com_google_flatbuffers()
_com_github_nvidia_cutlass()
_com_github_floodyberry_curve25519_donna()

maybe(
git_repository,
Expand Down Expand Up @@ -159,17 +160,17 @@ def _com_github_xtensor_xtl():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "75a7973c2850fcc33278c84e1b62eff8f0ad35f8"
OPENXLA_SHA256 = "4534c3230853e990ac613898c2ff39626d1beacb0c3675fbea502dce3e32f620"
OPENXLA_COMMIT = "d5791b01aa7541e3400224ac0a2985cc0f6940cb"
OPENXLA_SHA256 = "82dd50e6f51d79e8da69f109a234e33b8036f7b8798e41a03831b19c0c64d6e5"

SKYLIB_VERSION = "1.3.0"
SKYLIB_SHA256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506"

maybe(
http_archive,
name = "bazel_skylib",
sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506",
sha256 = SKYLIB_SHA256,
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/{version}/bazel-skylib-{version}.tar.gz".format(version = SKYLIB_VERSION),
"https://github.com/bazelbuild/bazel-skylib/releases/download/{version}/bazel-skylib-{version}.tar.gz".format(version = SKYLIB_VERSION),
],
)
Expand All @@ -181,10 +182,6 @@ def _com_github_openxla_xla():
sha256 = OPENXLA_SHA256,
strip_prefix = "xla-" + OPENXLA_COMMIT,
type = ".tar.gz",
patch_args = ["-p1"],
patches = [
"@spulib//bazel:patches/xla.patch",
],
urls = [
"https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = OPENXLA_COMMIT),
],
Expand Down Expand Up @@ -370,3 +367,16 @@ def _com_github_nvidia_cutlass():
sha256 = "9637961560a9d63a6bb3f407faf457c7dbc4246d3afb54ac7dc1e014dd7f172f",
build_file = "@spulib//bazel:nvidia_cutlass.BUILD",
)

def _com_github_floodyberry_curve25519_donna():
maybe(
http_archive,
name = "com_github_floodyberry_curve25519_donna",
strip_prefix = "curve25519-donna-2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2",
sha256 = "ba57d538c241ad30ff85f49102ab2c8dd996148456ed238a8c319f263b7b149a",
type = "tar.gz",
build_file = "@spulib//bazel:curve25519-donna.BUILD",
urls = [
"https://github.com/floodyberry/curve25519-donna/archive/2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2.tar.gz",
],
)
12 changes: 2 additions & 10 deletions bazel/spu.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ warpper bazel cc_xx to modify flags.
"""

load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake", "configure_make")
load("@yacl//bazel:yacl.bzl", "yacl_cmake_external")

WARNING_FLAGS = [
"-Wall",
Expand Down Expand Up @@ -68,15 +68,7 @@ def spu_cc_library(
**kargs
)

def spu_cmake_external(**attrs):
if "generate_args" not in attrs:
attrs["generate_args"] = ["-GNinja"]
return cmake(**attrs)

def spu_configure_make(**attrs):
if "args" not in attrs:
attrs["args"] = ["-j 4"]
return configure_make(**attrs)
spu_cmake_external = yacl_cmake_external

def _spu_version_file_impl(ctx):
out = ctx.actions.declare_file(ctx.attr.filename)
Expand Down
2 changes: 1 addition & 1 deletion libspu/compiler/front_end/hlo_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) {
break;
}
}
debug_options.set_xla_detailed_logging_and_dumping(true);
debug_options.set_xla_enable_dumping(true);
}

auto module_config =
Expand Down
33 changes: 21 additions & 12 deletions libspu/kernel/hal/ring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,22 +346,31 @@ Value _mmul(SPUContext* ctx, const Value& x, const Value& y) {
const auto& row_blocks = ret_blocks[r];
for (int64_t c = 0; c < static_cast<int64_t>(row_blocks.size()); c++) {
const auto& block = row_blocks[c];
SPU_ENFORCE(block.data().isCompact());
const int64_t block_rows = block.shape()[0];
const int64_t block_cols = block.shape()[1];
if (n_blocks == 1) {
SPU_ENFORCE(row_blocks.size() == 1);
SPU_ENFORCE(block_cols == n);
char* dst = &ret.data().at<char>({r * m_step, 0});
const char* src = &block.data().at<char>({0, 0});
size_t cp_len = block.elsize() * block.numel();
std::memcpy(dst, src, cp_len);
if (block.data().isCompact()) {
if (n_blocks == 1) {
SPU_ENFORCE(row_blocks.size() == 1);
SPU_ENFORCE(block_cols == n);
char* dst = &ret.data().at<char>({r * m_step, 0});
const char* src = &block.data().at<char>({0, 0});
size_t cp_len = block.elsize() * block.numel();
std::memcpy(dst, src, cp_len);
} else {
for (int64_t i = 0; i < block_rows; i++) {
char* dst = &ret.data().at<char>({r * m_step + i, c * n_step});
const char* src = &block.data().at<char>({i, 0});
size_t cp_len = block.elsize() * block_cols;
std::memcpy(dst, src, cp_len);
}
}
} else {
for (int64_t i = 0; i < block_rows; i++) {
char* dst = &ret.data().at<char>({r * m_step + i, c * n_step});
const char* src = &block.data().at<char>({i, 0});
size_t cp_len = block.elsize() * block_cols;
std::memcpy(dst, src, cp_len);
for (int64_t j = 0; j < block_cols; j++) {
char* dst = &ret.data().at<char>({r * m_step + i, c * n_step + j});
const char* src = &block.data().at<char>({i, j});
std::memcpy(dst, src, block.elsize());
}
}
}
}
Expand Down
29 changes: 7 additions & 22 deletions libspu/kernel/hal/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Value Permute1D(SPUContext *, const Value &x, const Index &indices) {
return Value(x.data().linear_gather(indices), x.dtype());
}

// FIXME: move to mpc layer
// Vectorized Prefix Sum
// Ref: https://en.algorithmica.org/hpc/algorithms/prefix/
Value PrefixSum(SPUContext *ctx, const Value &x) {
Expand Down Expand Up @@ -242,10 +243,9 @@ spu::Value GenInvPermByTwoBitVectors(SPUContext *ctx, const spu::Value &x,
{reshape(ctx, f0, new_shape), reshape(ctx, f1, new_shape),
reshape(ctx, f2, new_shape), reshape(ctx, f3, new_shape)},
1);
auto s = f.clone();

// calculate prefix sum
auto ps = PrefixSum(ctx, s);
auto ps = PrefixSum(ctx, f);

// mul f and s
auto fs = _mul(ctx, f, ps);
Expand Down Expand Up @@ -294,10 +294,9 @@ spu::Value GenInvPermByBitVector(SPUContext *ctx, const spu::Value &x) {
Shape new_shape = {1, numel};
auto f = concatenate(
ctx, {reshape(ctx, rev_x, new_shape), reshape(ctx, x, new_shape)}, 1);
auto s = f.clone();

// calculate prefix sum
auto ps = PrefixSum(ctx, s);
auto ps = PrefixSum(ctx, f);

// mul f and s
auto fs = _mul(ctx, f, ps);
Expand Down Expand Up @@ -339,25 +338,11 @@ std::vector<spu::Value> BitDecompose(SPUContext *ctx, const spu::Value &x,
? static_cast<size_t>(valid_bits)
: x_bshare.storage_type().as<BShare>()->nbits();
rets.reserve(nbits);
std::vector<std::unique_ptr<SPUContext>> sub_ctxs;
for (size_t bit = 0; bit < nbits; ++bit) {
sub_ctxs.push_back(ctx->fork());
}

std::vector<std::future<spu::Value>> futures;
for (size_t bit = 0; bit < nbits; ++bit) {
auto async_res = std::async(
[&](size_t bit, const spu::Value &x, const spu::Value &k1) {
auto sub_ctx = sub_ctxs[bit].get();
auto x_bshare_shift = right_shift_logical(sub_ctx, x, bit);
auto lowest_bit = _and(sub_ctx, x_bshare_shift, k1);
return _prefer_a(sub_ctx, lowest_bit);
},
bit, x_bshare, k1);
futures.push_back(std::move(async_res));
}
for (size_t bit = 0; bit < nbits; ++bit) {
rets.emplace_back(futures[bit].get());
auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit);
auto lowest_bit = _and(ctx, x_bshare_shift, k1);
rets.emplace_back(_prefer_a(ctx, lowest_bit));
}

return rets;
Expand Down Expand Up @@ -601,4 +586,4 @@ std::vector<spu::Value> simple_sort1d(SPUContext *ctx,
}
}

} // namespace spu::kernel::hal
} // namespace spu::kernel::hal
26 changes: 21 additions & 5 deletions spu/libspu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void BindLink(py::module& m) {
using yacl::link::CertInfo;
using yacl::link::Context;
using yacl::link::ContextDesc;
using yacl::link::RetryOptions;
using yacl::link::SSLOptions;
using yacl::link::VerifyOptions;

Expand All @@ -96,6 +97,25 @@ void BindLink(py::module& m) {
.def_readwrite("ca_file_path", &VerifyOptions::ca_file_path,
"the trusted CA file path");

py::class_<RetryOptions>(m, "RetryOptions",
"The options used for channel retry")
.def_readwrite("max_retry", &RetryOptions::max_retry, "max retry count")
.def_readwrite("retry_interval_ms", &RetryOptions::retry_interval_ms,
"first retry interval")
.def_readwrite("retry_interval_incr_ms",
&RetryOptions::retry_interval_incr_ms,
"the amount of time to increase between retries")
.def_readwrite("max_retry_interval_ms",
&RetryOptions::max_retry_interval_ms,
"the max interval between retries")
.def_readwrite("error_codes", &RetryOptions::error_codes,
"retry on these error codes, if empty, retry on all codes")
.def_readwrite(
"http_codes", &RetryOptions::http_codes,
"retry on these http codes, if empty, retry on all http codes")
.def_readwrite("aggressive_retry", &RetryOptions::aggressive_retry,
"do aggressive retry");

py::class_<SSLOptions>(m, "SSLOptions", "The options used for ssl")
.def_readwrite("cert", &SSLOptions::cert,
"certificate used for authentication")
Expand Down Expand Up @@ -132,12 +152,8 @@ void BindLink(py::module& m) {
.def_readwrite("enable_ssl", &ContextDesc::enable_ssl)
.def_readwrite("client_ssl_opts", &ContextDesc::client_ssl_opts)
.def_readwrite("server_ssl_opts", &ContextDesc::server_ssl_opts)
.def_readwrite("brpc_retry_count", &ContextDesc::brpc_retry_count)
.def_readwrite("brpc_retry_interval_ms",
&ContextDesc::brpc_retry_interval_ms)
.def_readwrite("brpc_aggressive_retry",
&ContextDesc::brpc_aggressive_retry)
.def_readwrite("link_type", &ContextDesc::link_type)
.def_readwrite("retry_opts", &ContextDesc::retry_opts)
.def(
"add_party",
[](ContextDesc& desc, std::string id, std::string host) {
Expand Down
Loading

0 comments on commit 61e10e3

Please sign in to comment.