Repo sync #558

Merged 1 commit on Feb 18, 2024.
4 changes: 2 additions & 2 deletions bazel/repositories.bzl
@@ -140,8 +140,8 @@ def _com_github_xtensor_xtl():
)

def _com_github_openxla_xla():
- OPENXLA_COMMIT = "f0946d01ef4cd9ecb1a27b4adb41ce6bcc846634"
- OPENXLA_SHA256 = "0af44c6e621da42a87c68746d343b0400ed6ca4b7dc0a7c7efc32f32c83d6be2"
+ OPENXLA_COMMIT = "d1cf2382e57b1efba3bb17d6dd9d8657453405ca"
+ OPENXLA_SHA256 = "a7f439d54a4e35c7977c2ea17b3a2493b306c9629ccc8071b4962c905ac9f692"

maybe(
http_archive,
4 changes: 2 additions & 2 deletions libspu/compiler/passes/hlo_legalize_to_pphlo.cc
@@ -1026,8 +1026,8 @@ class HloToPPHloOpConverter<stablehlo::ConvolutionOp>
// Apply dilation and padding to the input of a convolution.
Value applyConvolutionPadding(
Location loc, Value input,
- const std::optional<llvm::SmallVector<int64_t>> &padding,
- const std::optional<llvm::SmallVector<int64_t>> &lhs_dilation,
+ const std::optional<llvm::ArrayRef<int64_t>> &padding,
+ const std::optional<llvm::ArrayRef<int64_t>> &lhs_dilation,
llvm::ArrayRef<int64_t> dim_mappings, OpBuilder &rewriter) const {
if ((!padding || isAll(*padding, 0)) &&
(!lhs_dilation || isAll(*lhs_dilation, 1))) {
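The change above swaps owning llvm::SmallVector parameters of the convolution padding helper for non-owning llvm::ArrayRef views, matching the ArrayRef-typed values that newer XLA/StableHLO attribute accessors hand out. A minimal sketch of the idiom (the isAllZero helper and main are hypothetical illustrations, not part of this PR; the real helper in this file is isAll):

#include <cstdint>
#include <optional>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

// ArrayRef is a non-owning view over contiguous storage, so an optional
// ArrayRef parameter borrows the caller's data instead of copying it into
// a freshly allocated SmallVector.
static bool isAllZero(const std::optional<llvm::ArrayRef<int64_t>> &values) {
  if (!values) {
    return true;  // an absent attribute is treated as "no padding"
  }
  return llvm::all_of(*values, [](int64_t v) { return v == 0; });
}

int main() {
  llvm::SmallVector<int64_t> padding{0, 0, 0, 0};
  // A SmallVector converts implicitly to ArrayRef; no elements are copied.
  return isAllZero(llvm::ArrayRef<int64_t>(padding)) ? 0 : 1;
}

Call sites need not change: a SmallVector, a std::vector, or an attribute's asArrayRef() result all convert to ArrayRef for the duration of the call, with the usual caveat that the view must not outlive the underlying storage.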
10 changes: 5 additions & 5 deletions libspu/compiler/tests/hlo2pphlo/reduce_p.mlir
@@ -9,7 +9,7 @@ func.func @main(%arg1: tensor<1024x1xf32>) -> (tensor<1024xf32>) {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%2 = "stablehlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%2) : (tensor<f32>) -> ()
- }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1024x1xf32>, tensor<f32>) -> tensor<1024xf32>
+ }) {dimensions = array<i64: 1>} : (tensor<1024x1xf32>, tensor<f32>) -> tensor<1024xf32>
return %1 : tensor<1024xf32>
}

@@ -24,10 +24,10 @@ func.func @main(%arg0: tensor<3x2xi64>) -> tensor<2x2xi64> {
%1 = "stablehlo.add"(%arg2, %arg3) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%1) : (tensor<i64>) -> ()
}) {
- window_dimensions = dense<[2, 1]> : tensor<2xi64>,
- window_strides = dense<[4, 1]> : tensor<2xi64>,
- base_dilations = dense<[2, 1]> : tensor<2xi64>,
- window_dilations = dense<[3, 1]> : tensor<2xi64>,
+ window_dimensions = array<i64: 2, 1>,
+ window_strides = array<i64: 4, 1>,
+ base_dilations = array<i64: 2, 1>,
+ window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
return %result : tensor<2x2xi64>
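A note on the recurring change in these .mlir tests: StableHLO migrated the shape-like attributes of reduce and reduce_window (dimensions, window_dimensions, window_strides, base_dilations, window_dilations) from DenseIntElementsAttr, spelled dense<[...]> : tensor<Nxi64>, to the one-dimensional dense-array attribute, spelled array<i64: ...>. The values are unchanged; only the attribute encoding differs, so every such hunk below is a mechanical one-for-one rewrite. padding keeps its dense<...> : tensor<Nx2xi64> form because it is a two-dimensional table of (low, high) pairs, which the flat array attribute cannot express.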
10 changes: 5 additions & 5 deletions libspu/compiler/tests/hlo2pphlo/reduce_s.mlir
@@ -8,7 +8,7 @@ func.func @main(%arg1: tensor<1024x1xf32>) -> (tensor<1024xf32>) {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%2 = "stablehlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%2) : (tensor<f32>) -> ()
- }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1024x1xf32>, tensor<f32>) -> tensor<1024xf32>
+ }) {dimensions = array<i64: 1>} : (tensor<1024x1xf32>, tensor<f32>) -> tensor<1024xf32>
return %1 : tensor<1024xf32>
}

@@ -24,10 +24,10 @@ func.func @main(%arg0: tensor<3x2xi64>) -> tensor<2x2xi64> {
%1 = "stablehlo.add"(%arg2, %arg3) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%1) : (tensor<i64>) -> ()
}) {
- window_dimensions = dense<[2, 1]> : tensor<2xi64>,
- window_strides = dense<[4, 1]> : tensor<2xi64>,
- base_dilations = dense<[2, 1]> : tensor<2xi64>,
- window_dilations = dense<[3, 1]> : tensor<2xi64>,
+ window_dimensions = array<i64: 2, 1>,
+ window_strides = array<i64: 4, 1>,
+ base_dilations = array<i64: 2, 1>,
+ window_dilations = array<i64: 3, 1>,
padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>
} : (tensor<3x2xi64>, tensor<i64>) -> tensor<2x2xi64>
return %result : tensor<2x2xi64>
12 changes: 10 additions & 2 deletions libspu/compiler/tests/hlo2pphlo/select_and_scatter.mlir
@@ -18,7 +18,11 @@ func.func @main(%arg0: tensor<128x5x5x32xf32>, %arg1: tensor<128x4x4x32xf32>, %arg2: tensor<f32>) -> tensor<128x5x5x32xf32> {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.add %arg3, %arg4 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
- }) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor<128x5x5x32xf32>, tensor<128x4x4x32xf32>, tensor<f32>) -> tensor<128x5x5x32xf32>
+ }) {
+   padding = dense<0> : tensor<4x2xi64>,
+   window_dimensions = array<i64: 1, 2, 2, 1>,
+   window_strides = array<i64: 1, 1, 1, 1>
+ } : (tensor<128x5x5x32xf32>, tensor<128x4x4x32xf32>, tensor<f32>) -> tensor<128x5x5x32xf32>
return %0 : tensor<128x5x5x32xf32>
}

@@ -35,6 +39,10 @@ func.func @main(%arg0: tensor<128x16x16x64xf32>, %arg1: tensor<128x8x8x64xf32>, %arg2: tensor<f32>) -> tensor<128x16x16x64xf32> {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%1 = stablehlo.add %arg3, %arg4 : tensor<f32>
stablehlo.return %1 : tensor<f32>
- }) {padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<128x16x16x64xf32>, tensor<128x8x8x64xf32>, tensor<f32>) -> tensor<128x16x16x64xf32>
+ }) {
+   padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>,
+   window_dimensions = array<i64: 1, 3, 3, 1>,
+   window_strides = array<i64: 1, 2, 2, 1>
+ } : (tensor<128x16x16x64xf32>, tensor<128x8x8x64xf32>, tensor<f32>) -> tensor<128x16x16x64xf32>
return %0 : tensor<128x16x16x64xf32>
}
2 changes: 1 addition & 1 deletion libspu/compiler/tests/hlo2pphlo/vreduce_mixed.mlir
@@ -13,6 +13,6 @@ func.func @main(%arg0: tensor<1024x1xf32>, %arg1: tensor<1024x1xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
%2 = "stablehlo.add"(%arg2, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = "stablehlo.add"(%arg3, %arg5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%2, %3) : (tensor<f32>, tensor<f32>) -> ()
- }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1024x1xf32>, tensor<1024x1xf32>, tensor<f32>, tensor<f32>) -> (tensor<1024xf32>, tensor<1024xf32>)
+ }) {dimensions = array<i64: 1>} : (tensor<1024x1xf32>, tensor<1024x1xf32>, tensor<f32>, tensor<f32>) -> (tensor<1024xf32>, tensor<1024xf32>)
return %1#0, %1#1 : tensor<1024xf32>, tensor<1024xf32>
}
2 changes: 1 addition & 1 deletion libspu/compiler/tests/hlo2pphlo/vreduce_p.mlir
@@ -13,6 +13,6 @@ func.func @main(%arg0: tensor<1024x1xf32>, %arg1: tensor<1024x1xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
%2 = "stablehlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = "stablehlo.add"(%arg4, %arg5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%2, %3) : (tensor<f32>, tensor<f32>) -> ()
- }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1024x1xf32>, tensor<1024x1xf32>, tensor<f32>, tensor<f32>) -> (tensor<1024xf32>, tensor<1024xf32>)
+ }) {dimensions = array<i64: 1>} : (tensor<1024x1xf32>, tensor<1024x1xf32>, tensor<f32>, tensor<f32>) -> (tensor<1024xf32>, tensor<1024xf32>)
return %1#0, %1#1 : tensor<1024xf32>, tensor<1024xf32>
}
2 changes: 1 addition & 1 deletion libspu/compiler/tests/hlo2pphlo/vreduce_s.mlir
@@ -15,6 +15,6 @@ func.func @main(%arg0: tensor<1024x1xf32>, %arg1: tensor<1024x1xf32>) -> (tensor<1024xf32>, tensor<1024xf32>) {
%2 = "stablehlo.add"(%arg2, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%3 = "stablehlo.add"(%arg3, %arg5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%2, %3) : (tensor<f32>, tensor<f32>) -> ()
- }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1024x1xf32>, tensor<1024x1xf32>, tensor<f32>, tensor<f32>) -> (tensor<1024xf32>, tensor<1024xf32>)
+ }) {dimensions = array<i64: 1>} : (tensor<1024x1xf32>, tensor<1024x1xf32>, tensor<f32>, tensor<f32>) -> (tensor<1024xf32>, tensor<1024xf32>)
return %1#0, %1#1 : tensor<1024xf32>, tensor<1024xf32>
}
60 changes: 30 additions & 30 deletions libspu/device/pphlo/pphlo_executor_test.cc
@@ -434,11 +434,11 @@ func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<2x2xi32>) {
%2 = stablehlo.maximum %arg1, %arg2 : tensor<i32>
stablehlo.return %2 : tensor<i32>
}) {
- base_dilations = dense<[2, 1]> : tensor<2xi64>,
+ base_dilations = array<i64: 2, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>,
- window_dilations = dense<[3, 1]> : tensor<2xi64>,
- window_dimensions = dense<[2, 1]> : tensor<2xi64>,
- window_strides = dense<[ 4, 1 ]> : tensor<2xi64>
+ window_dilations = array<i64: 3, 1>,
+ window_dimensions = array<i64: 2, 1>,
+ window_strides = array<i64: 4, 1>
} : (tensor<3x2xi32>, tensor<i32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
})",
@@ -463,11 +463,11 @@ func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<1x2xi32>) {
%2 = stablehlo.maximum %arg1, %arg2 : tensor<i32>
stablehlo.return %2 : tensor<i32>
}) {
- base_dilations = dense<[2, 1]> : tensor<2xi64>,
+ base_dilations = array<i64: 2, 1>,
  padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>,
- window_dilations = dense<[3, 1]> : tensor<2xi64>,
- window_dimensions = dense<[3, 1]> : tensor<2xi64>,
- window_strides = dense<[4, 1]> : tensor<2xi64>
+ window_dilations = array<i64: 3, 1>,
+ window_dimensions = array<i64: 3, 1>,
+ window_strides = array<i64: 4, 1>
} : (tensor<3x2xi32>, tensor<i32>) -> tensor<1x2xi32>
return %1 : tensor<1x2xi32>
})",
@@ -583,11 +583,11 @@ func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<6x6xi32>) {
%2 = stablehlo.maximum %arg1, %arg2 : tensor<i32>
stablehlo.return %2 : tensor<i32>
}) {
- base_dilations = dense<2> : tensor<2xi64>,
+ base_dilations = array<i64: 2, 2>,
  padding = dense<0> : tensor<2x2xi64>,
- window_dilations = dense<1> : tensor<2xi64>,
- window_dimensions = dense<2> : tensor<2xi64>,
- window_strides = dense<1> : tensor<2xi64>
+ window_dilations = array<i64: 1, 1>,
+ window_dimensions = array<i64: 2, 2>,
+ window_strides = array<i64: 1, 1>
} : (tensor<4x4xi32>, tensor<i32>) -> tensor<6x6xi32>
return %1 : tensor<6x6xi32>
})",
@@ -615,11 +615,11 @@ func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) {
%2 = stablehlo.maximum %arg1, %arg2 : tensor<i32>
stablehlo.return %2 : tensor<i32>
}) {
- base_dilations = dense<2> : tensor<2xi64>,
+ base_dilations = array<i64: 2, 2>,
  padding = dense<0> : tensor<2x2xi64>,
- window_dilations = dense<1> : tensor<2xi64>,
- window_dimensions = dense<2> : tensor<2xi64>,
- window_strides = dense<2> : tensor<2xi64>
+ window_dilations = array<i64: 1, 1>,
+ window_dimensions = array<i64: 2, 2>,
+ window_strides = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i32>) -> tensor<3x3xi32>

return %1 : tensor<3x3xi32>
@@ -648,11 +648,11 @@ func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) {
%2 = stablehlo.maximum %arg1, %arg2 : tensor<i32>
stablehlo.return %2 : tensor<i32>
}) {
- base_dilations = dense<2> : tensor<2xi64>,
+ base_dilations = array<i64: 2, 2>,
  padding = dense<0> : tensor<2x2xi64>,
- window_dilations = dense<2> : tensor<2xi64>,
- window_dimensions = dense<2> : tensor<2xi64>,
- window_strides = dense<2> : tensor<2xi64>
+ window_dilations = array<i64: 2, 2>,
+ window_dimensions = array<i64: 2, 2>,
+ window_strides = array<i64: 2, 2>
} : (tensor<4x4xi32>, tensor<i32>) -> tensor<3x3xi32>

return %1 : tensor<3x3xi32>
@@ -681,11 +681,11 @@ func.func @main(%arg0: tensor<4x4xi32>) -> (tensor<3x3xi32>) {
%2 = stablehlo.maximum %arg1, %arg2 : tensor<i32>
stablehlo.return %2 : tensor<i32>
}) {
- base_dilations = dense<2> : tensor<2xi64>,
+ base_dilations = array<i64: 2, 2>,
  padding = dense<1> : tensor<2x2xi64>,
- window_dilations = dense<1> : tensor<2xi64>,
- window_dimensions = dense<3> : tensor<2xi64>,
- window_strides = dense<3> : tensor<2xi64>
+ window_dilations = array<i64: 1, 1>,
+ window_dimensions = array<i64: 3, 3>,
+ window_strides = array<i64: 3, 3>
} : (tensor<4x4xi32>, tensor<i32>) -> tensor<3x3xi32>

return %1 : tensor<3x3xi32>
@@ -2474,11 +2474,11 @@ func.func @main(%arg0: tensor<4x6xi32>, %arg1: tensor<2x2xi32>) -> (tensor<2x2xi32>, tensor<4x6xi32>) {
%3 = stablehlo.maximum %arg2, %arg3 : tensor<i32>
stablehlo.return %3 : tensor<i32>
}) {
- base_dilations = dense<1> : tensor<2xi64>,
+ base_dilations = array<i64: 1, 1>,
  padding = dense<0> : tensor<2x2xi64>,
- window_dilations = dense<1> : tensor<2xi64>,
- window_dimensions = dense<[2,3]> : tensor<2xi64>,
- window_strides = dense<[2, 3]> : tensor<2xi64>
+ window_dilations = array<i64: 1, 1>,
+ window_dimensions = array<i64: 2, 3>,
+ window_strides = array<i64: 2, 3>
} : (tensor<4x6xi32>, tensor<i32>) -> tensor<2x2xi32>
%2 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
@@ -2490,8 +2490,8 @@ func.func @main(%arg0: tensor<4x6xi32>, %arg1: tensor<2x2xi32>) -> (tensor<2x2xi32>, tensor<4x6xi32>) {
stablehlo.return %3 : tensor<i32>
}) {
padding = dense<0> : tensor<2x2xi64>,
- window_dimensions = dense<[2,3]> : tensor<2xi64>,
- window_strides = dense<[2,3]> : tensor<2xi64>
+ window_dimensions = array<i64: 2, 3>,
+ window_strides = array<i64: 2, 3>
} : (tensor<4x6xi32>, tensor<2x2xi32>, tensor<i32>) -> tensor<4x6xi32>
return %1, %2 : tensor<2x2xi32>, tensor<4x6xi32>
})",
7 changes: 7 additions & 0 deletions libspu/device/pphlo/pphlo_intrinsic_executor.cc
@@ -16,6 +16,7 @@

#include "spdlog/spdlog.h"

#include "libspu/kernel/hal/fxp_approx.h"
#include "libspu/kernel/hlo/casting.h"
#include "libspu/kernel/hlo/const.h"

@@ -48,6 +49,12 @@ std::vector<Value> intrinsic_dispatcher(SPUContext* ctx, llvm::StringRef name,
SPDLOG_INFO("Calling example intrinsic");
return {inputs.begin(), inputs.end()};
}

if (name == "mhlo.erf") {
SPU_ENFORCE(inputs.size() == 1 && inputs[0].isFxp());
return {kernel::hal::f_erf(ctx, inputs[0])};
}

SPU_THROW("Unhandled intrinsic call {}", name.str());
}

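With this branch in place, an mhlo.erf call that reaches the executor as a named intrinsic is routed to the fixed-point f_erf kernel introduced below; the SPU_ENFORCE restricts it to a single fixed-point argument, and unrecognized names still fall through to SPU_THROW.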
55 changes: 55 additions & 0 deletions libspu/kernel/hal/fxp_approx.cc
@@ -680,4 +680,59 @@ Value f_cosine(SPUContext* ctx, const Value& x) {
return detail::cos_chebyshev(ctx, x);
}

namespace {

Value EvaluatePolynomial(SPUContext* ctx, const Value& x,
absl::Span<const float> coefficients) {
auto poly = constant(ctx, coefficients[0], x.dtype(), x.shape());

for (size_t i = 1; i < coefficients.size(); ++i) {
auto c = constant(ctx, coefficients[i], x.dtype(), x.shape());
poly = f_mul(ctx, poly, x);
poly = f_add(ctx, poly, c);
}
return poly;
}

Value ErfImpl(SPUContext* ctx, const Value& x) {
static std::array<float, 5> kErfCoefficient{0.078108, 0.000972, 0.230389,
0.278393, 1.0};
auto one = constant(ctx, 1.0, x.dtype(), x.shape());

auto z = EvaluatePolynomial(ctx, x, kErfCoefficient);
z = f_square(ctx, z);
z = f_square(ctx, z);
z = detail::reciprocal_goldschmidt_positive(ctx, z);

return f_sub(ctx, one, z);
}

} // namespace

// Ref:
// Handbook of Mathematical Functions: with Formulas, Graphs, and Mathematical
// Tables, equation 7.1.27, maximum absolute error <= 5e-4
Value f_erf(SPUContext* ctx, const Value& x) {
if (x.isPublic()) {
return f_erf_p(ctx, x);
}
auto zero = constant(ctx, 0.0, x.dtype(), x.shape());
auto pred = f_less(ctx, x, zero);

auto abs_x = f_abs(ctx, x);

auto three = constant(ctx, 3.0, x.dtype(), x.shape());
auto cond = f_less(ctx, abs_x, three);

auto erf = ErfImpl(ctx, abs_x);

// we do this truncation because:
// 1. for large abs_x, reciprocal may overflow
// 2. error is sufficiently small (< 2.2e-5)
erf = _mux(ctx, cond, erf, constant(ctx, 1.0F, x.dtype(), x.shape()))
.setDtype(x.dtype());

return _mux(ctx, pred, f_negate(ctx, erf), erf).setDtype(x.dtype());
}

} // namespace spu::kernel::hal
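For reference, the Abramowitz & Stegun equation 7.1.27 cited above reads, for x >= 0,

    erf(x) ≈ 1 - 1 / (1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4)^4

with a1 = 0.278393, a2 = 0.230389, a3 = 0.000972, a4 = 0.078108, and absolute error at most 5e-4. kErfCoefficient stores {a4, a3, a2, a1, 1} so that EvaluatePolynomial can apply Horner's rule; the two f_square calls then raise the polynomial to the fourth power, and reciprocal_goldschmidt_positive stands in for the division, the expensive step under MPC. As a quick check, x = 1 gives about 0.84269 against erf(1) ≈ 0.84270. f_erf extends the positive-axis formula through the odd symmetry erf(-x) = -erf(x), and for |x| >= 3 returns 1 outright: there 1 - erf(x) is below 2.2e-5, while the growing polynomial could overflow the fixed-point reciprocal.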
2 changes: 2 additions & 0 deletions libspu/kernel/hal/fxp_approx.h
@@ -62,4 +62,6 @@ Value f_sqrt(SPUContext* ctx, const Value& x);

Value f_sigmoid(SPUContext* ctx, const Value& x);

Value f_erf(SPUContext* ctx, const Value& x);

} // namespace spu::kernel::hal
32 changes: 32 additions & 0 deletions libspu/kernel/hal/fxp_approx_test.cc
@@ -372,4 +372,36 @@ TEST(FxpTest, Cosine) {
}
}

TEST(FxpTest, Erf) {
// GIVEN
SPUContext ctx = test::makeSPUContext();

xt::xarray<float> x = xt::random::rand<float>({10}, -10, 10);

// public erf
{
Value a = constant(&ctx, x, DT_F32);
Value c = f_erf(&ctx, a);
EXPECT_EQ(c.dtype(), DT_F32);

auto y = dump_public_as<float>(&ctx, c);
EXPECT_TRUE(xt::allclose(xt::erf(x), y, 0.01, 0.001))
<< xt::erf(x) << std::endl
<< y;
}

// secret erf
{
Value a = test::makeValue(&ctx, x, VIS_SECRET);
Value c = f_erf(&ctx, a);
EXPECT_EQ(c.dtype(), DT_F32);

auto y = dump_public_as<float>(&ctx, reveal(&ctx, c));
// low precision
EXPECT_TRUE(xt::allclose(xt::erf(x), y, 0.01, 0.001))
<< xt::erf(x) << std::endl
<< y;
}
}

} // namespace spu::kernel::hal
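Both blocks assert with xt::allclose at rtol = 0.01 and atol = 0.001, which comfortably covers the approximation's 5e-4 error bound; the secret-shared path additionally accumulates fixed-point truncation error, hence the "low precision" remark.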
5 changes: 5 additions & 0 deletions libspu/kernel/hal/fxp_cleartext.cc
@@ -140,4 +140,9 @@ Value f_cosine_p(SPUContext* ctx, const Value& in) {
return applyFloatingPointFn(ctx, in, [](float x) { return std::cos(x); });
}

Value f_erf_p(SPUContext* ctx, const Value& in) {
SPU_TRACE_HAL_DISP(ctx, in);
return applyFloatingPointFn(ctx, in, [](float x) { return std::erf(x); });
}

} // namespace spu::kernel::hal
2 changes: 2 additions & 0 deletions libspu/kernel/hal/fxp_cleartext.h
@@ -38,4 +38,6 @@ Value f_sine_p(SPUContext* ctx, const Value& in);

Value f_cosine_p(SPUContext* ctx, const Value& in);

Value f_erf_p(SPUContext* ctx, const Value& in);

} // namespace spu::kernel::hal
2 changes: 1 addition & 1 deletion sml/metrics/classification/auc.py
@@ -106,7 +106,7 @@ def auc(x, y):
Area Under the Curve
"""
x, y = jax.lax.sort([x, y], num_keys=1)
- area = jnp.abs(jnp.trapz(y, x))
+ area = jnp.abs(jax.scipy.integrate.trapezoid(y, x))
return area
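jnp.trapz tracked NumPy's trapz, which NumPy deprecated and then removed in 2.0 in favor of trapezoid; jax.scipy.integrate.trapezoid is the drop-in replacement. Both compute the composite trapezoidal rule over the sorted ROC points,

    area ≈ sum_i (x_{i+1} - x_i) * (y_i + y_{i+1}) / 2

so the returned AUC is numerically unchanged.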

