Skip to content

Commit

Permalink
repo-sync-2024-11-01T13:48:20+0800 (#899)
Browse files Browse the repository at this point in the history
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:

- Backward compatibility:
  • Loading branch information
tpppppub authored Nov 1, 2024
1 parent c7055bc commit a3e1e09
Show file tree
Hide file tree
Showing 10 changed files with 24 additions and 19 deletions.
8 changes: 8 additions & 0 deletions libspu/dialect/pphlo/IR/fold.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ OpFoldResult ReverseOp::fold(FoldAdaptor) {
dims, [&](int64_t dim) { return shapedType.getDimSize(dim) == 1; })) {
return input;
}

// reverse(reverse(x, dims), dims) = x
if (auto prev = input.getDefiningOp<ReverseOp>()) {
if (prev.getDimensions() == dims) {
return prev.getOperand();
}
}

return {};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ std::vector<NdArrayRef> reduce(ReduceOp op,
ring_mul_(rs[idx], t);
}
} else {
SPU_ENFORCE("not supported reduction op");
SPU_THROW("not supported reduction op");
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions libspu/mpc/semi2k/prime_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ NdArrayRef MulPriv(KernelEvalContext* ctx, const NdArrayRef& x) {

// P0 sends (x+a) to P1 ; P1 sends (y+b) to P0
comm->sendAsync(comm->nextRank(), ring_add(a_or_b, x), "(x + a) or (y + b)");
xa_or_yb = comm->recv(comm->prevRank(), x.eltype(), "(x + a) or (y + b)");
xa_or_yb = comm->recv(comm->prevRank(), x.eltype(), "(x + a) or (y + b)")
.reshape(x.shape());
// note that our rings are commutative.
if (comm->getRank() == 0) {
ring_add_(c, ring_mul(std::move(xa_or_yb), x));
Expand Down Expand Up @@ -198,4 +199,4 @@ NdArrayRef ConvMP(KernelEvalContext* ctx, const NdArrayRef& h,
return x;
}

} // namespace spu::mpc::semi2k
} // namespace spu::mpc::semi2k
5 changes: 5 additions & 0 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"executionEnvironments": [
{"root": "."}
]
}
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
grpcio>=1.42.0,!=1.48.0
numpy>=1.22.0
numpy>=1.22.0, <2 # FIXME: for SF compatibility
protobuf>=4, <5
cloudpickle>=2.0.0
multiprocess>=0.70.12.2
Expand Down
8 changes: 2 additions & 6 deletions sml/linear_model/emulations/quantile_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import jax.numpy as jnp
Expand Down Expand Up @@ -48,15 +49,10 @@ def proc(X, y):
def generate_data():
from jax import random

# 设置随机种子
key = random.PRNGKey(42)
# 生成 X 数据
key, subkey = random.split(key)
X = random.normal(subkey, (100, 2))
# 生成 y 数据
y = (
5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
) # 高相关性,带有小噪声
y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
return X, y

try:
Expand Down
2 changes: 1 addition & 1 deletion sml/linear_model/tests/quantile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def generate_data():
# run
# Larger max_iter can give higher accuracy, but it will take more time to run
proc = proc_wrapper(
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=200
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=20
)
result, coef, intercept = spsim.sim_jax(sim, proc)(X, y)
rmse_encrpted = jnp.sqrt(jnp.mean((y - result) ** 2))
Expand Down
3 changes: 0 additions & 3 deletions sml/linear_model/utils/_linprog_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def _pivot_col(T, tol=1e-5):

all_masked = jnp.all(mask)

# 定义根据最小值选择列的函数
ma = jnp.where(mask, jnp.inf, T[-1, :-1])
min_col = jnp.argmin(ma)

Expand All @@ -44,12 +43,10 @@ def _pivot_row(T, pivcol, phase, tol=1e-5, max_val=1e10):

q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

# 选择最小比值的行
min_rows = jnp.nanargmin(q)
all_masked = jnp.all(mask)

row = min_rows
# 处理全被掩盖的情况
row = jnp.where(all_masked, 0, row)

return ~all_masked, row
Expand Down
1 change: 0 additions & 1 deletion spu/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ py_binary(
srcs = ["jnp_debug.py"],
deps = [
"//spu:api",
"//spu/intrinsic:all_intrinsics",
"//spu/utils:simulation",
],
)
Expand Down
7 changes: 3 additions & 4 deletions spu/tests/jnp_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import jax.numpy as jnp
import numpy as np

import spu.intrinsic as si
import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as ppsim

Expand All @@ -31,9 +30,9 @@
copts.disable_div_sqrt_rewrite = True

x = np.random.randn(3, 4)
y = np.random.randn(5, 6)
fn = lambda x, y: si.example_binary(x, y)
# fn = lambda x, y: jnp.matmul(x, y)
y = np.random.randn(4, 5)
fn = lambda x, y: jnp.matmul(x, y)

spu_fn = ppsim.sim_jax(sim, fn, copts=copts)
z = spu_fn(x, y)

Expand Down

0 comments on commit a3e1e09

Please sign in to comment.