Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Repo sync #795

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
## staging
>
> please add your unreleased change here.

- [Feature] Add more send/recv actions profiling

## 20240716
Expand Down
16 changes: 8 additions & 8 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ def _yacl():
http_archive,
name = "yacl",
urls = [
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3.tar.gz",
"https://github.com/secretflow/yacl/archive/refs/tags/0.4.5b3_nightly_20240722.tar.gz",
],
strip_prefix = "yacl-0.4.5b3",
sha256 = "bd89d63312e5e83eff5e001e2cf2135baff321c4b72a309f7d00cc53ce02e1a1",
strip_prefix = "yacl-0.4.5b3_nightly_20240722",
sha256 = "ccca599e6ded6089c5afbb87c8f5e09383195af256caacd50089f0c7443e8604",
)

def _libpsi():
maybe(
http_archive,
name = "psi",
urls = [
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.0beta.tar.gz",
"https://github.com/secretflow/psi/archive/refs/tags/v0.4.1.dev240722.tar.gz",
],
strip_prefix = "psi-0.4.0beta",
sha256 = "c2fbf486a66eca9d3ec1725a81d93a7c6e80a9206ef1c9263a1608e0bef95e1a",
strip_prefix = "psi-0.4.1.dev240722",
sha256 = "878cd8af2c7b9850944a27adf91f21dd4937d09d38e8365baad3b5165db8b39a",
)

def _rules_proto_grpc():
Expand Down Expand Up @@ -136,8 +136,8 @@ def _bazel_skylib():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "8533a6869ae02fb3b15a8a12739a982fc3c9f6e7"
OPENXLA_SHA256 = "d5b076825c992f59542f6b94e5480c7e7c6c627cd18c80ec60b6d5b295c160d4"
OPENXLA_COMMIT = "04f2bfe797408c9efe742b89e2e4db6cf526ebb7"
OPENXLA_SHA256 = "7e1d24737815be7607eed5f02fe7f81d97ffe358dfb7b4876f97bce8f48b3b3e"

# We need openxla to handle xla/mhlo/stablehlo
maybe(
Expand Down
1 change: 0 additions & 1 deletion libspu/compiler/tests/interpret/and.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT

Expand Down
8 changes: 5 additions & 3 deletions libspu/compiler/tests/interpret/generate_mlir_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ def TestCase(inputs, expected, checker='expect_eq', tol=None):
f.write(
"// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s\n"
)
f.write(
"// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n"
)
# FIXME: these tests are not stable for cheetah now
if test not in ["xor", "or", "and"]:
f.write(
"// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s\n"
)
# Some test values in max and min are not supported by protocol 5.
if test not in ["max", "min"]:
f.write(
Expand Down
1 change: 0 additions & 1 deletion libspu/compiler/tests/interpret/or.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT

Expand Down
1 change: 0 additions & 1 deletion libspu/compiler/tests/interpret/xor.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: spu-translate --protocol_kind=1 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=2 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=3 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=4 --interpret -split-input-file %s
// RUN: spu-translate --protocol_kind=5 --interpret -split-input-file %s
// AUTO GENERATED, DO NOT EDIT

Expand Down
4 changes: 2 additions & 2 deletions libspu/compiler/tests/passes/optimizations/ops_negative.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func.func @main() -> tensor<i32> {

func.func @main() -> tensor<i32> {
%0 = pphlo.constant dense<[0.000000e+00, -3.40282347E+38]> : tensor<2xf32>
// expected-error @+1 {{op broadcast_dimensions contains invalid value -6 for result with rank 1}}
// expected-error @+1 {{broadcast_dimensions contains invalid value -6 for result with rank 1}}
%1 = pphlo.broadcast %0, dims = [-6] : (tensor<2xf32>) -> tensor<2xf32>
%2 = pphlo.constant dense<5> : tensor<i32>
pphlo.return %2 : tensor<i32>
Expand All @@ -33,7 +33,7 @@ func.func @main() -> tensor<i32> {
// -----

func.func @main() -> tensor<i32> {
// expected-error @+1 {{op iota dimension cannot go beyond the output rank or be negative}}
// expected-error @+1 {{iota dimension cannot go beyond the output rank}}
%0 = pphlo.iota dim = 1000 : tensor<1xi32>
%1 = pphlo.constant dense<5> : tensor<i32>
pphlo.return %1 : tensor<i32>
Expand Down
1 change: 1 addition & 0 deletions libspu/compiler/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include "llvm/ADT/Twine.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir::spu {
Expand Down
80 changes: 4 additions & 76 deletions libspu/dialect/pphlo/IR/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,12 @@
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "stablehlo/dialect/TypeInference.h"

#include "libspu/dialect/pphlo/IR/ops.h.inc"

namespace mlir::spu::pphlo {

namespace {

// Checks if the vector `nums` has duplicates.
bool hasDuplicates(const ArrayRef<int64_t> nums) {
llvm::SmallDenseSet<int64_t> set(nums.begin(), nums.end());
return set.size() != nums.size();
}

} // namespace

template <typename T>
static LogicalResult Verify(T /*op*/) {
return success();
Expand Down Expand Up @@ -386,75 +377,12 @@ LogicalResult ConcatenateOp::verify() {
}

LogicalResult BroadcastOp::verify() {
auto operandType = mlir::dyn_cast<RankedTensorType>(getOperand().getType());

auto operandRank = operandType.getRank();

if (getBroadcastDimensions().empty()) {
if (operandRank == 0) {
return success();
}
return emitOpError(
llvm::formatv("broadcast_dimensions is absent, but required because "
"operand has non-zero rank ({0})",
operandRank));
}

auto dimensionsSize = getBroadcastDimensions().size();
if (static_cast<int64_t>(dimensionsSize) != operandRank) {
return emitOpError(llvm::formatv(
"broadcast_dimensions size ({0}) does not match operand rank ({1})",
dimensionsSize, operandRank));
}

auto dimensions = getBroadcastDimensions();
if (hasDuplicates(dimensions)) {
return emitOpError("broadcast_dimensions should not have duplicates");
}

auto resultType = mlir::dyn_cast<RankedTensorType>(getResult().getType());
auto resultRank = resultType.getRank();

for (size_t i = 0; i != dimensionsSize; ++i) {
auto dimIndex = dimensions[i];
if ((dimIndex >= resultRank) || (dimIndex < 0)) {
return emitOpError(
llvm::formatv("broadcast_dimensions contains invalid value {0} for "
"result with rank {1}",
dimIndex, resultRank));
}

if (!operandType.isDynamicDim(i)) {
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
if (dimSize != 1 && dimSize != resultDimSize) {
return emitOpError(
llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
"1 or size of result dimension {2} ({3})",
i, dimSize, dimIndex, resultDimSize));
}
}
}

return success();
return hlo::verifyBroadcastInDimOp(getLoc(), getOperand(),
getBroadcastDimensions(), getResult());
}

LogicalResult IotaOp::verify() {
auto shape = mlir::dyn_cast<ShapedType>(getType());
if (!shape.hasRank()) {
return success();
}

if (shape.getRank() == 0) {
return emitOpError() << "does not support scalars.";
}

auto iotaDimension = static_cast<int64_t>(this->getIotaDimension());
if (iotaDimension >= shape.getRank() || iotaDimension < 0) {
return emitOpError()
<< "iota dimension cannot go beyond the output rank or be negative.";
}
return success();
return hlo::verifyIotaOp(getLoc(), getIotaDimension(), getResult());
}

LogicalResult SliceOp::verify() {
Expand Down
10 changes: 10 additions & 0 deletions libspu/dialect/pphlo/IR/print_parse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,19 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
if (parser.parseRParen()) {
return failure();
}
// Parse optional properties
if (succeeded(parser.parseOptionalLess()) &&
(failed(parser.parseAttribute(result.propertiesAttr)) ||
failed(parser.parseGreater()))) {
return failure();
}

// Parse optional attributes
if (parser.parseOptionalAttrDict(result.attributes)) {
return failure();
}

// Parse type signature
if (parser.parseColon() || parser.parseLParen() || parser.parseRParen() ||
parser.parseArrow()) {
return failure();
Expand Down
5 changes: 2 additions & 3 deletions libspu/dialect/pphlo/transforms/inline_secret_control_flow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ class CaseConverter : public OpRewritePattern<CaseOp> {
if (target_type.getNumElements() == in_type.getNumElements()) {
return rewriter.create<ReshapeOp>(loc, broadcasted_mask_type, in);
} else {
return rewriter.create<BroadcastOp>(
loc, broadcasted_mask_type, in,
llvm::SmallVector<int64_t>(target_type.getRank(), 0));
return rewriter.create<BroadcastOp>(loc, broadcasted_mask_type, in,
llvm::SmallVector<int64_t>{0});
}
}

Expand Down
6 changes: 3 additions & 3 deletions libspu/mpc/aby3/boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in,

// TODO: the hal dtype should tell us about the max number of possible bits.
const auto field = ctx->getState<Z2kState>()->getDefaultField();
const size_t out_nbits =
std::min<size_t>(in_ty->nbits() + *std::max_element(bits.begin(), bits.end()),
SizeOf(field) * 8);
const size_t out_nbits = std::min<size_t>(
in_ty->nbits() + *std::max_element(bits.begin(), bits.end()),
SizeOf(field) * 8);
const PtType out_btype = calcBShareBacktype(out_nbits);
bool is_splat = bits.size() == 1;

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/cheetah/ot/yacl/yacl_ote_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
#include "yacl/kernel/algorithms/base_ot.h"
#include "yacl/kernel/algorithms/ferret_ote.h"
#include "yacl/kernel/algorithms/iknp_ote.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/algorithms/softspoken_ote.h"
#include "yacl/kernel/type/ot_store.h"

#include "libspu/core/prelude.h"
#include "libspu/mpc/cheetah/ot/ot_util.h"
Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/securenn/boolean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in,

int64_t out_nbits = in.eltype().as<BShare>()->nbits() +
*std::max_element(shift.begin(), shift.end());
out_nbits =
std::clamp<int64_t>(out_nbits, 0L, static_cast<int64_t>(SizeOf(field) * 8));
out_nbits = std::clamp<int64_t>(out_nbits, 0L,
static_cast<int64_t>(SizeOf(field) * 8));

return makeBShare(ring_lshift(in, shift), field, out_nbits);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ int main(int argc, char* argv[]) {
std::string key;
SPU_ENFORCE(
butil::Base64Decode(ttp_server_config::FLAGS_server_private_key, &key));
decode_private_key =
yacl::Buffer(decode_private_key.data(), decode_private_key.size());
decode_private_key = yacl::Buffer(key.data(), key.size());
}

spu::mpc::semi2k::beaver::ttp_server::ServerOptions ops{
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ spu_cc_library(
"//libspu/mpc/spdz2k/ot:tiny_ot",
"//libspu/mpc/utils:ring_ops",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/kernel/algorithms:ot_store",
"@yacl//yacl/kernel/type:ot_store",
"@yacl//yacl/link",
"@yacl//yacl/utils:matrix_utils",
"@yacl//yacl/utils:serialize",
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/beaver_tinyot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include "yacl/crypto/rand/rand.h"
#include "yacl/crypto/tools/prg.h"
#include "yacl/kernel/algorithms/base_ot.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"
#include "yacl/utils/serialize.h"

#include "libspu/mpc/common/prg_tensor.h"
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/beaver/beaver_tinyot.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#pragma once

#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"
#include "yacl/link/context.h"

#include "libspu/mpc/common/prg_state.h"
Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ spu_cc_library(
"//libspu/mpc/utils:ring_ops",
"@com_github_emptoolkit_emp_tool//:emp-tool",
"@yacl//yacl/crypto/tools:prg",
"@yacl//yacl/kernel/algorithms:ot_store",
"@yacl//yacl/kernel/type:ot_store",
"@yacl//yacl/link",
],
)
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/kos_ote.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#pragma once
#include "absl/types/span.h"
#include "yacl/base/dynamic_bitset.h"
#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"
#include "yacl/link/link.h"
namespace spu::mpc::spdz2k {

Expand Down
2 changes: 1 addition & 1 deletion libspu/mpc/spdz2k/ot/tiny_ot.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.
#include <vector>

#include "yacl/kernel/algorithms/ot_store.h"
#include "yacl/kernel/type/ot_store.h"

#include "libspu/mpc/common/communicator.h"

Expand Down
Loading