From 9f6419af4ba4ea86f8575536095c5e8a6364f8c4 Mon Sep 17 00:00:00 2001 From: xbw886 <1042740841@qq.com> Date: Tue, 24 Sep 2024 09:15:28 +0800 Subject: [PATCH 1/4] fixed simplex --- sml/linear_model/BUILD.bazel | 8 + sml/linear_model/emulations/BUILD.bazel | 9 + sml/linear_model/emulations/quantile_emul.py | 102 ++++++++ sml/linear_model/quantile.py | 233 +++++++++++++++++++ sml/linear_model/tests/BUILD.bazel | 10 + sml/linear_model/tests/quantile_test.py | 93 ++++++++ sml/linear_model/utils/BUILD.bazel | 5 + sml/linear_model/utils/_linprog_simplex.py | 197 ++++++++++++++++ 8 files changed, 657 insertions(+) create mode 100644 sml/linear_model/emulations/quantile_emul.py create mode 100644 sml/linear_model/quantile.py create mode 100644 sml/linear_model/tests/quantile_test.py create mode 100644 sml/linear_model/utils/_linprog_simplex.py diff --git a/sml/linear_model/BUILD.bazel b/sml/linear_model/BUILD.bazel index 6fa078a8c..202fb8800 100644 --- a/sml/linear_model/BUILD.bazel +++ b/sml/linear_model/BUILD.bazel @@ -54,3 +54,11 @@ py_binary( "//sml/linear_model/utils:solver", ], ) + +py_library( + name = "quantile", + srcs = ["quantile.py"], + deps = [ + "//sml/linear_model/utils:_linprog_simplex", + ], +) \ No newline at end of file diff --git a/sml/linear_model/emulations/BUILD.bazel b/sml/linear_model/emulations/BUILD.bazel index 6778cd0c2..1f603fe1f 100644 --- a/sml/linear_model/emulations/BUILD.bazel +++ b/sml/linear_model/emulations/BUILD.bazel @@ -62,3 +62,12 @@ py_binary( "//sml/utils:emulation", ], ) + +py_binary( + name = "quantile_emul", + srcs = ["quantile_emul.py"], + deps = [ + "//sml/linear_model:quantile", + "//sml/utils:emulation", + ], +) \ No newline at end of file diff --git a/sml/linear_model/emulations/quantile_emul.py b/sml/linear_model/emulations/quantile_emul.py new file mode 100644 index 000000000..64ed95d39 --- /dev/null +++ b/sml/linear_model/emulations/quantile_emul.py @@ -0,0 +1,102 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
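"""Emulation entry for the SML QuantileRegressor.

A sketch of the flow implemented below: generate plaintext data, fit
sklearn's QuantileRegressor as a plaintext baseline, seal X and y with the
emulator, then fit and predict with the SML model inside SPU (ABY3, 3PC)
and compare the two scores.
"""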
+import time

import jax.numpy as jnp
from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

import sml.utils.emulation as emulation
from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor

CONFIG_FILE = emulation.CLUSTER_ABY3_3PC


def emul_quantile(mode=emulation.Mode.MULTIPROCESS):
    def proc_wrapper(
        quantile,
        alpha,
        fit_intercept,
        lr,
        max_iter,
    ):
        quantile_custom = SmlQuantileRegressor(
            quantile=quantile,
            alpha=alpha,
            fit_intercept=fit_intercept,
            lr=lr,
            max_iter=max_iter,
        )

        def proc(X, y):
            quantile_custom_fit = quantile_custom.fit(X, y)
            result = quantile_custom_fit.predict(X)
            return result

        return proc

    def generate_data():
        from jax import random

        # set the random seed
        key = random.PRNGKey(42)
        # generate X
        key, subkey = random.split(key)
        X = random.normal(subkey, (100, 2))
        # generate y
        y = (
            5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
        )  # strongly correlated, with a small amount of noise
        return X, y

    try:
        # bandwidth and latency only work for docker mode
        emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20)
        emulator.up()

        # load mock data
        X, y = generate_data()

        # compare with sklearn
        quantile_sklearn = SklearnQuantileRegressor(
            quantile=0.3, alpha=0.1, fit_intercept=True, solver='highs'
        )
        start = time.time()
        quantile_sklearn_fit = quantile_sklearn.fit(X, y)
        score_plain = jnp.mean(y <= quantile_sklearn_fit.predict(X))
        end = time.time()
        print(f"Running time in SKlearn: {end - start:.2f}s")

        # mark these data to be protected in SPU
        X_spu, y_spu = emulator.seal(X, y)

        # run
        proc = proc_wrapper(
            quantile=0.3, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
        )
        start = time.time()
        result = emulator.run(proc)(X_spu, y_spu)
        end = time.time()
        score_encrypted = jnp.mean(y <= result)
        print(f"Running time in SPU: {end - start:.2f}s")

        # print accuracy
        print(f"Accuracy in SKlearn: {score_plain:.2f}")
        print(f"Accuracy in SPU: {score_encrypted:.2f}")

    finally:
        emulator.down()


if __name__ == "__main__":
    emul_quantile(emulation.Mode.MULTIPROCESS)
diff --git a/sml/linear_model/quantile.py b/sml/linear_model/quantile.py
new file mode 100644
index 000000000..e14b7fe8a
--- /dev/null
+++ b/sml/linear_model/quantile.py
@@ -0,0 +1,233 @@
+# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import warnings
from warnings import warn

import jax
import jax.numpy as jnp
import pandas as pd
from jax import grad

from sml.linear_model.utils._linprog_simplex import _linprog_simplex


class QuantileRegressor:
    """
    Initialize the quantile regression model.
    Parameters
    ----------
    quantile : float, default=0.5
        The quantile to be predicted. Must be between 0 and 1.
        A quantile of 0.5 corresponds to the median (50th percentile).
    alpha : float, default=1.0
        Regularization strength; must be a positive float.
+ Larger values specify stronger regularization, reducing model complexity. + fit_intercept : bool, default=True + Whether to calculate the intercept for the model. + If False, no intercept will be used in calculations, meaning the model will + assume that the data is already centered. + lr : float, default=0.01 + Learning rate for the optimization process. This controls the size of + the steps taken in each iteration towards minimizing the objective function. + max_iter : int, default=1000 + The maximum number of iterations for the optimization algorithm. + This controls how long the model will continue to update the weights + before stopping. + Attributes + ---------- + coef_ : array-like of shape (n_features,) + The coefficients (weights) assigned to the input features. These will be + learned during model fitting. + intercept_ : float + The intercept (bias) term. If `fit_intercept=True`, this will be + learned during model fitting. + """ + + def __init__( + self, quantile=0.5, alpha=1.0, fit_intercept=True, lr=0.01, max_iter=1000 + ): + self.quantile = quantile + self.alpha = alpha + self.fit_intercept = fit_intercept + self.lr = lr + self.max_iter = max_iter + + self.coef_ = None + self.intercept_ = None + + def fit(self, X, y, sample_weight=None): + """ + Fit the quantile regression model using linear programming. + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training data. + y : array-like of shape (n_samples,) + Target values. + sample_weight : array-like of shape (n_samples,), optional + Individual weights for each sample. If not provided, all samples + are assumed to have equal weight. + Returns + ------- + self : object + Returns an instance of self. + Steps: + 1. Determine the number of parameters (`n_params`), accounting for the intercept if needed. + 2. Define the objective function `c`, incorporating both the L1 regularization and the pinball loss. + 3. Set up the equality constraint matrix `A_eq` and vector `b_eq` based on the input data `X` and `y`. + 4. Solve the linear programming problem using `_linprog_simplex`. + 5. Extract the model parameters (intercept and coefficients) from the solution. 
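
        Example (dimension bookkeeping only): with `n_samples=100`,
        `n_features=2` and `fit_intercept=True`, `n_params=3`, so the LP
        has 2 * 3 + 2 * 100 = 206 nonnegative variables, `c` has length
        206, `A_eq` has shape (100, 206) and `b_eq = y` has shape (100,).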
+ """ + n_samples, n_features = X.shape + n_params = n_features + + # sample_weight = jnp.ones((n_samples,)) + if sample_weight is None: + sample_weight = jnp.ones((n_samples,)) + + if self.fit_intercept: + n_params += 1 + + alpha = jnp.sum(sample_weight) * self.alpha + + # After rescaling alpha, the minimization problem is + # min sum(pinball loss) + alpha * L1 + # Use linear programming formulation of quantile regression + # min_x c x + # A_eq x = b_eq + # 0 <= x + # x = (s0, s, t0, t, u, v) = slack variables >= 0 + # intercept = s0 - t0 + # coef = s - t + # c = (0, alpha * 1_p, 0, alpha * 1_p, quantile * 1_n, (1-quantile) * 1_n) + # residual = y - X@coef - intercept = u - v + # A_eq = (1_n, X, -1_n, -X, diag(1_n), -diag(1_n)) + # b_eq = y + # p = n_features + # n = n_samples + # 1_n = vector of length n with entries equal one + # see https://stats.stackexchange.com/questions/384909/ + c = jnp.concatenate( + [ + jnp.full(2 * n_params, fill_value=alpha), + sample_weight * self.quantile, + sample_weight * (1 - self.quantile), + ] + ) + + if self.fit_intercept: + c = c.at[0].set(0) + c = c.at[n_params].set(0) + + eye = jnp.eye(n_samples) + if self.fit_intercept: + ones = jnp.ones((n_samples, 1)) + A = jnp.concatenate([ones, X, -ones, -X, eye, -eye], axis=1) + else: + A = jnp.concatenate([X, -X, eye, -eye], axis=1) + + b = y + + result = _linprog_simplex(c, A, b, maxiter=self.max_iter, tol=1e-3) + + solution = result[0] + + params = solution[:n_params] - solution[n_params : 2 * n_params] + + if self.fit_intercept: + self.coef_ = params[1:] + self.intercept_ = params[0] + else: + self.coef_ = params + self.intercept_ = 0.0 + return self + + def predict(self, X): + """ + Predict target values using the fitted quantile regression model. + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Input data for which predictions are to be made. + Returns + ------- + y_pred : array-like of shape (n_samples,) + Predicted target values. + Notes + ----- + The predict method computes the predicted target values using the model's + learned coefficients and intercept (if fit_intercept=True). + - If the model includes an intercept, a column of ones is added to the input data `X` to account + for the intercept in the linear combination. + - The method then computes the dot product between the modified `X` and the stacked vector of + intercept and coefficients. + - If there is no intercept, the method simply computes the dot product between `X` and the coefficients. 
+ """ + + if self.fit_intercept: + X = jnp.column_stack((jnp.ones(X.shape[0]), X)) + + return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_])) + else: + return jnp.dot(X, self.coef_) + + +def generate_data(): + import numpy as np + + np.random.seed(42) + + X = np.random.normal(size=(100, 2)) + + y = 3 * X[:, 0] + 2 * X[:, 1] + np.random.normal(size=(100,)) * 0.1 + return X, y + + +if __name__ == "__main__": + X, y = generate_data() + + quantile = 0.7 + alpha = 0.1 + fit_intercept = True + lr = 0.01 + max_iter = 300 + + quantile_custom = QuantileRegressor( + quantile=quantile, + alpha=alpha, + fit_intercept=fit_intercept, + lr=lr, + max_iter=max_iter, + ) + + quantile_custom_fit = quantile_custom.fit(X, y) + result = quantile_custom_fit.predict(X) + acc_custom = jnp.mean(y <= result) + print(result) + print(f"Accuracy in SPU: {acc_custom:.2f}") + print(quantile_custom_fit.coef_) + print(quantile_custom_fit.intercept_) + + from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor + + quantile_sklearn = SklearnQuantileRegressor( + quantile=quantile, alpha=alpha, fit_intercept=fit_intercept, solver='highs' + ) + quantile_sklearn_fit = quantile_sklearn.fit(X, y) + acc_sklearn = jnp.mean(y <= quantile_sklearn_fit.predict(X)) + print(f"Accuracy in SKlearn: {acc_sklearn:.2f}") + print(quantile_sklearn_fit.coef_) + print(quantile_sklearn_fit.intercept_) diff --git a/sml/linear_model/tests/BUILD.bazel b/sml/linear_model/tests/BUILD.bazel index 1fa04f861..f729c2067 100644 --- a/sml/linear_model/tests/BUILD.bazel +++ b/sml/linear_model/tests/BUILD.bazel @@ -70,3 +70,13 @@ py_test( "//spu/utils:simulation", ], ) + +py_test( + name = "quantile_test", + srcs = ["quantile_test.py"], + deps = [ + "//sml/linear_model:quantile", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/sml/linear_model/tests/quantile_test.py b/sml/linear_model/tests/quantile_test.py new file mode 100644 index 000000000..511f4bf9d --- /dev/null +++ b/sml/linear_model/tests/quantile_test.py @@ -0,0 +1,93 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import jax.numpy as jnp + +# import numpy as np +from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor + +import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.utils.simulation as spsim +from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor + +class UnitTests(unittest.TestCase): + def test_quantile(self): + def proc_wrapper( + quantile, + alpha, + fit_intercept, + lr, + max_iter, + ): + quantile_custom = SmlQuantileRegressor( + quantile=quantile, + alpha=alpha, + fit_intercept=fit_intercept, + lr=lr, + max_iter=max_iter, + ) + + def proc(X, y): + quantile_custom_fit = quantile_custom.fit(X, y) + result = quantile_custom_fit.predict(X) + return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_ + + return proc + + n_samples, n_features = 100, 2 + + def generate_data(): + from jax import random + key = random.PRNGKey(42) + key, subkey = random.split(key) + X = random.normal(subkey, (100, 2)) + y = ( + 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 + ) + return X, y + + # bandwidth and latency only work for docker mode + sim = spsim.Simulator.simple( + 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + ) + + X, y = generate_data() + + # compare with sklearn + quantile_sklearn = SklearnQuantileRegressor( + quantile=0.2, alpha=0.1, fit_intercept=True, solver='revised simplex' + ) + quantile_sklearn_fit = quantile_sklearn.fit(X, y) + acc_sklearn = jnp.mean(y <= quantile_sklearn_fit.predict(X)) + print(f"Accuracy in SKlearn: {acc_sklearn:.2f}") + print(quantile_sklearn_fit.coef_) + print(quantile_sklearn_fit.intercept_) + + # run + proc = proc_wrapper( + quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300 + ) + result, coef, intercept = spsim.sim_jax(sim, proc)(X, y) + acc_custom = jnp.mean(y <= result) + + # print accuracy + print(f"Accuracy in SPU: {acc_custom:.2f}") + print(coef) + print(intercept) + + +if __name__ == "__main__": + unittest.main() diff --git a/sml/linear_model/utils/BUILD.bazel b/sml/linear_model/utils/BUILD.bazel index 7c13def5f..273290734 100644 --- a/sml/linear_model/utils/BUILD.bazel +++ b/sml/linear_model/utils/BUILD.bazel @@ -31,3 +31,8 @@ py_library( name = "solver", srcs = ["solver.py"], ) + +py_library( + name = "_linprog_simplex", + srcs = ["_linprog_simplex.py"], +) diff --git a/sml/linear_model/utils/_linprog_simplex.py b/sml/linear_model/utils/_linprog_simplex.py new file mode 100644 index 000000000..c1e8c7365 --- /dev/null +++ b/sml/linear_model/utils/_linprog_simplex.py @@ -0,0 +1,197 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
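"""A two-phase tableau simplex solver written with JAX primitives.

It solves linear programs of the standard form

    min  c @ x    s.t.  A @ x == b,  x >= 0

as used by sml.linear_model.quantile. Control flow is kept
data-independent (masks and jnp.where instead of Python branches over
values) so the solver can run under SPU.
"""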

import warnings
from warnings import warn

import jax
import jax.numpy as jnp
from jax import jit, lax


def _pivot_col(T, tol=1e-5, bland=False):
    mask = T[-1, :-1] >= -tol

    all_masked = jnp.all(mask)

    bland_first_col = jnp.argmin(jnp.where(mask, jnp.inf, jnp.arange(T.shape[1] - 1)))
    # select the column with the minimum value in the objective row
    ma = jnp.where(mask, jnp.inf, T[-1, :-1])
    min_col = jnp.argmin(ma)

    result = jnp.where(bland, bland_first_col, min_col)

    valid = ~all_masked
    result = jnp.where(all_masked, 0, result)

    return valid, result

def _pivot_row(T, basis, pivcol, phase, tol=1e-5, bland=False):

    def true_mask_func(T, pivcol):
        mask = T[:-2, pivcol] <= tol
        ma = jnp.where(mask, jnp.inf, T[:-2, pivcol])
        mb = jnp.where(mask, jnp.inf, T[:-2, -1])

        q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)

        # select the row with the minimum ratio
        min_rows = jnp.nanargmin(q)
        all_masked = jnp.all(mask)
        return min_rows, all_masked

    def false_mask_func(T, pivcol):
        mask = T[:-1, pivcol] <= tol
        ma = jnp.where(mask, jnp.inf, T[:-1, pivcol])
        mb = jnp.where(mask, jnp.inf, T[:-1, -1])

        q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)

        # select the row with the minimum ratio
        min_rows = jnp.nanargmin(q)
        all_masked = jnp.all(mask)
        return min_rows, all_masked

    true_min_rows, true_all_masked = true_mask_func(T, pivcol)
    false_min_rows, false_all_masked = false_mask_func(T, pivcol)
    min_rows = jnp.where(phase == 1, true_min_rows, false_min_rows)
    all_masked = jnp.where(phase == 1, true_all_masked, false_all_masked)

    # check whether the mask array is fully masked
    has_valid_row = min_rows.size > 0
    row = min_rows

    # handle the case where all rows are masked
    row = jnp.where(all_masked, 0, row)

    # handle the case where no row satisfies the condition
    row = jnp.where(has_valid_row, row, 0)

    return ~all_masked & has_valid_row, row


def _apply_pivot(T, basis, pivrow, pivcol, tol=1e-5):
    pivrow = jnp.int32(pivrow)
    pivcol = jnp.int32(pivcol)

    basis = basis.at[pivrow].set(pivcol)

    pivrow_one_hot = jax.nn.one_hot(pivrow, T.shape[0])
    pivcol_one_hot = jax.nn.one_hot(pivcol, T.shape[1])

    pivval = jnp.dot(pivrow_one_hot, jnp.dot(T, pivcol_one_hot))

    updated_row = T[pivrow] / pivval
    T = pivrow_one_hot[:, None] * updated_row + T * (1 - pivrow_one_hot[:, None])

    scalar = jnp.dot(T, pivcol_one_hot).reshape(-1, 1)

    updated_T = T - scalar * T[pivrow]

    row_restore_matrix = pivrow_one_hot[:, None] * T[pivrow]
    updated_T = row_restore_matrix + updated_T * (1 - pivrow_one_hot[:, None])

    return updated_T, basis


def _solve_simplex(
    T,
    n,
    basis,
    maxiter=300,
    tol=1e-5,
    phase=2,
    bland=False,
):
    status = 0
    complete = False

    num = 0
    pivcol = 0
    pivrow = 0
    while num < maxiter:
        pivcol_found, pivcol = _pivot_col(T, tol, bland)

        def cal_pivcol_found_True(
            T, basis, pivcol, phase, tol, bland, status, complete
        ):
            pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol, bland)

            pivrow_isnot_found = pivrow_found == False
            status = jnp.where(pivrow_isnot_found, 1, status)
            complete = jnp.where(pivrow_isnot_found, True, complete)

            return pivrow, status, complete

        pivcol_isnot_found = pivcol_found == False
        pivcol = jnp.where(pivcol_isnot_found, 0, pivcol)
        pivrow = jnp.where(pivcol_isnot_found, 0, pivrow)
        status = jnp.where(pivcol_isnot_found, 0, status)
        complete = jnp.where(pivcol_isnot_found, True, complete)

        pivcol_is_found = pivcol_found == True
        pivrow_True, status_True, complete_True = (
            cal_pivcol_found_True(T, basis, pivcol, phase, tol, bland, status, complete)
        )

        pivrow = jnp.where(pivcol_is_found, pivrow_True, pivrow)
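        # Both candidate outcomes are computed on every iteration and
        # jnp.where selects the live one, so the loop body contains no
        # data-dependent Python branching.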
status = jnp.where(pivcol_is_found, status_True, status) + complete = jnp.where(pivcol_is_found, complete_True, complete) + + complete_is_False = complete == False + apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol, tol) + T = jnp.where(complete_is_False, apply_T, T) + basis = jnp.where(complete_is_False, apply_basis, basis) + num = num + 1 + + return T, basis, status + + +def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, bland=False): + status = 0 + n, m = A.shape + + # All constraints must have b >= 0. + is_negative_constraint = jnp.less(b, 0) + A = jnp.where(is_negative_constraint[:, None], A * -1, A) + b = jnp.where(is_negative_constraint, b * -1, b) + + av = jnp.arange(n) + m + basis = av.copy() + + row_constraints = jnp.hstack((A, jnp.eye(n), b[:, jnp.newaxis])) + row_objective = jnp.hstack((c, jnp.zeros(n), c0)) + row_pseudo_objective = -row_constraints.sum(axis=0) + row_pseudo_objective = row_pseudo_objective.at[av].set(0) + T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective)) + + # phase 1 + T, basis, status = _solve_simplex( + T, n, basis, maxiter=maxiter, tol=tol, phase=1, bland=bland + ) + + status = jnp.where(jnp.abs(T[-1, -1]) < tol, status, 1) + + T_new = T[:-1, :] + jit_delete = jit(jnp.delete, static_argnames=['assume_unique_indices']) + T = jnp.delete(T_new, av, 1, assume_unique_indices=True) + + # phase 2 + T, basis, status = _solve_simplex(T, n, basis, maxiter, tol, 2, bland) + + solution = jnp.zeros(n + m) + solution = solution.at[basis[:n]].set(T[:n, -1]) + x = solution[:m] + + return x, status From 178a477c6f777b4fec6c94482eb9c2f1f036de43 Mon Sep 17 00:00:00 2001 From: xbw886 <1042740841@qq.com> Date: Tue, 24 Sep 2024 09:19:41 +0800 Subject: [PATCH 2/4] fixed simplex --- libspu/dialect/pphlo/IR/type_inference.cc | 7 +++- sml/linear_model/BUILD.bazel | 2 +- sml/linear_model/emulations/BUILD.bazel | 2 +- sml/linear_model/quantile.py | 48 ---------------------- sml/linear_model/tests/quantile_test.py | 6 +-- sml/linear_model/utils/_linprog_simplex.py | 7 ++-- 6 files changed, 15 insertions(+), 57 deletions(-) diff --git a/libspu/dialect/pphlo/IR/type_inference.cc b/libspu/dialect/pphlo/IR/type_inference.cc index 6c95aac31..6974c2e9e 100644 --- a/libspu/dialect/pphlo/IR/type_inference.cc +++ b/libspu/dialect/pphlo/IR/type_inference.cc @@ -420,8 +420,13 @@ LogicalResult inferDynamicUpdateSliceOp( } // dynamic_update_slice_c1 + TypeTools tools(operand.getContext()); + auto common_vis = + tools.computeCommonVisibility({tools.getTypeVisibility(operandType), + tools.getTypeVisibility(updateType)}); inferredReturnTypes.emplace_back(RankedTensorType::get( - operandType.getShape(), operandType.getElementType())); + operandType.getShape(), + tools.getType(operandType.getElementType(), common_vis))); return success(); } diff --git a/sml/linear_model/BUILD.bazel b/sml/linear_model/BUILD.bazel index 202fb8800..fa4fdd158 100644 --- a/sml/linear_model/BUILD.bazel +++ b/sml/linear_model/BUILD.bazel @@ -61,4 +61,4 @@ py_library( deps = [ "//sml/linear_model/utils:_linprog_simplex", ], -) \ No newline at end of file +) diff --git a/sml/linear_model/emulations/BUILD.bazel b/sml/linear_model/emulations/BUILD.bazel index 1f603fe1f..46df0f929 100644 --- a/sml/linear_model/emulations/BUILD.bazel +++ b/sml/linear_model/emulations/BUILD.bazel @@ -70,4 +70,4 @@ py_binary( "//sml/linear_model:quantile", "//sml/utils:emulation", ], -) \ No newline at end of file +) diff --git a/sml/linear_model/quantile.py b/sml/linear_model/quantile.py index 
e14b7fe8a..ad7a2c8ba 100644 --- a/sml/linear_model/quantile.py +++ b/sml/linear_model/quantile.py @@ -183,51 +183,3 @@ def predict(self, X): return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_])) else: return jnp.dot(X, self.coef_) - - -def generate_data(): - import numpy as np - - np.random.seed(42) - - X = np.random.normal(size=(100, 2)) - - y = 3 * X[:, 0] + 2 * X[:, 1] + np.random.normal(size=(100,)) * 0.1 - return X, y - - -if __name__ == "__main__": - X, y = generate_data() - - quantile = 0.7 - alpha = 0.1 - fit_intercept = True - lr = 0.01 - max_iter = 300 - - quantile_custom = QuantileRegressor( - quantile=quantile, - alpha=alpha, - fit_intercept=fit_intercept, - lr=lr, - max_iter=max_iter, - ) - - quantile_custom_fit = quantile_custom.fit(X, y) - result = quantile_custom_fit.predict(X) - acc_custom = jnp.mean(y <= result) - print(result) - print(f"Accuracy in SPU: {acc_custom:.2f}") - print(quantile_custom_fit.coef_) - print(quantile_custom_fit.intercept_) - - from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor - - quantile_sklearn = SklearnQuantileRegressor( - quantile=quantile, alpha=alpha, fit_intercept=fit_intercept, solver='highs' - ) - quantile_sklearn_fit = quantile_sklearn.fit(X, y) - acc_sklearn = jnp.mean(y <= quantile_sklearn_fit.predict(X)) - print(f"Accuracy in SKlearn: {acc_sklearn:.2f}") - print(quantile_sklearn_fit.coef_) - print(quantile_sklearn_fit.intercept_) diff --git a/sml/linear_model/tests/quantile_test.py b/sml/linear_model/tests/quantile_test.py index 511f4bf9d..54daa5ce4 100644 --- a/sml/linear_model/tests/quantile_test.py +++ b/sml/linear_model/tests/quantile_test.py @@ -23,6 +23,7 @@ import spu.utils.simulation as spsim from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor + class UnitTests(unittest.TestCase): def test_quantile(self): def proc_wrapper( @@ -51,12 +52,11 @@ def proc(X, y): def generate_data(): from jax import random + key = random.PRNGKey(42) key, subkey = random.split(key) X = random.normal(subkey, (100, 2)) - y = ( - 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 - ) + y = 5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1 return X, y # bandwidth and latency only work for docker mode diff --git a/sml/linear_model/utils/_linprog_simplex.py b/sml/linear_model/utils/_linprog_simplex.py index c1e8c7365..a08c2bd17 100644 --- a/sml/linear_model/utils/_linprog_simplex.py +++ b/sml/linear_model/utils/_linprog_simplex.py @@ -37,6 +37,7 @@ def _pivot_col(T, tol=1e-5, bland=False): return valid, result + def _pivot_row(T, basis, pivcol, phase, tol=1e-5, bland=False): def true_mask_func(T, pivcol): @@ -116,7 +117,7 @@ def _solve_simplex( ): status = 0 complete = False - + num = 0 pivcol = 0 pivrow = 0 @@ -141,8 +142,8 @@ def cal_pivcol_found_True( complete = jnp.where(pivcol_isnot_found, True, complete) pivcol_is_found = pivcol_found == True - pivrow_True, status_True, complete_True = ( - cal_pivcol_found_True(T, basis, pivcol, phase, tol, bland, status, complete) + pivrow_True, status_True, complete_True = cal_pivcol_found_True( + T, basis, pivcol, phase, tol, bland, status, complete ) pivrow = jnp.where(pivcol_is_found, pivrow_True, pivrow) From 8c1f00ef4a89c8a2c2aaa95e56f2d21c353a9889 Mon Sep 17 00:00:00 2001 From: xbw886 <1042740841@qq.com> Date: Thu, 24 Oct 2024 13:06:52 +0800 Subject: [PATCH 3/4] fix problems --- sml/linear_model/emulations/quantile_emul.py | 24 +++-- sml/linear_model/quantile.py | 37 +++++--- sml/linear_model/tests/quantile_test.py | 16 ++-- 
 sml/linear_model/utils/_linprog_simplex.py   | 95 ++++++++------------
 4 files changed, 85 insertions(+), 87 deletions(-)

diff --git a/sml/linear_model/emulations/quantile_emul.py b/sml/linear_model/emulations/quantile_emul.py
index 64ed95d39..c7c08d5f6 100644
--- a/sml/linear_model/emulations/quantile_emul.py
+++ b/sml/linear_model/emulations/quantile_emul.py
@@ -41,7 +41,7 @@ def proc_wrapper(
         def proc(X, y):
             quantile_custom_fit = quantile_custom.fit(X, y)
             result = quantile_custom_fit.predict(X)
-            return result
+            return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_

         return proc

@@ -69,30 +69,36 @@ def generate_data():
         # compare with sklearn
         quantile_sklearn = SklearnQuantileRegressor(
-            quantile=0.3, alpha=0.1, fit_intercept=True, solver='highs'
+            quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs'
         )
         start = time.time()
         quantile_sklearn_fit = quantile_sklearn.fit(X, y)
-        score_plain = jnp.mean(y <= quantile_sklearn_fit.predict(X))
+        y_pred_plain = quantile_sklearn_fit.predict(X)
+        rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
         end = time.time()
         print(f"Running time in SKlearn: {end - start:.2f}s")
+        print(quantile_sklearn_fit.coef_)
+        print(quantile_sklearn_fit.intercept_)

         # mark these data to be protected in SPU
         X_spu, y_spu = emulator.seal(X, y)

         # run
+        # A larger max_iter can give higher accuracy, but it will take more time to run
         proc = proc_wrapper(
-            quantile=0.3, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
+            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
         )
         start = time.time()
-        result = emulator.run(proc)(X_spu, y_spu)
+        result, coef, intercept = emulator.run(proc)(X_spu, y_spu)
         end = time.time()
-        score_encrypted = jnp.mean(y <= result)
+        rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))
         print(f"Running time in SPU: {end - start:.2f}s")
+        print(coef)
+        print(intercept)

-        # print accuracy
-        print(f"Accuracy in SKlearn: {score_plain:.2f}")
-        print(f"Accuracy in SPU: {score_encrypted:.2f}")
+        # print RMSE
+        print(f"RMSE in SKlearn: {rmse_plain:.2f}")
+        print(f"RMSE in SPU: {rmse_encrypted:.2f}")

     finally:
         emulator.down()
diff --git a/sml/linear_model/quantile.py b/sml/linear_model/quantile.py
index ad7a2c8ba..549e67aeb 100644
--- a/sml/linear_model/quantile.py
+++ b/sml/linear_model/quantile.py
@@ -12,10 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numbers
-import warnings
-from warnings import warn
-
 import jax
 import jax.numpy as jnp
 import pandas as pd
 from jax import grad
@@ -46,6 +42,8 @@ class QuantileRegressor:
         The maximum number of iterations for the optimization algorithm.
         This controls how long the model will continue to update the weights
         before stopping.
+    max_val : float, default=1e10
+        Tableau entries at or above this value are treated as infinite by the simplex solver's pivot ratio test.
 Attributes
 ----------
 coef_ : array-like of shape (n_features,)
@@ -57,13 +55,20 @@ class QuantileRegressor:
     """

     def __init__(
-        self, quantile=0.5, alpha=1.0, fit_intercept=True, lr=0.01, max_iter=1000
+        self,
+        quantile=0.5,
+        alpha=1.0,
+        fit_intercept=True,
+        lr=0.01,
+        max_iter=1000,
+        max_val=1e10,
     ):
         self.quantile = quantile
         self.alpha = alpha
         self.fit_intercept = fit_intercept
         self.lr = lr
         self.max_iter = max_iter
+        self.max_val = max_val

         self.coef_ = None
         self.intercept_ = None
@@ -94,7 +99,6 @@ def fit(self, X, y, sample_weight=None):
         n_samples, n_features = X.shape
         n_params = n_features

-        # sample_weight = jnp.ones((n_samples,))
         if sample_weight is None:
             sample_weight = jnp.ones((n_samples,))

@@ -141,9 +145,11 @@ def fit(self, X, y, sample_weight=None):

         b = y

-        result = _linprog_simplex(c, A, b, maxiter=self.max_iter, tol=1e-3)
+        result = _linprog_simplex(
+            c, A, b, maxiter=self.max_iter, tol=1e-3, max_val=self.max_val
+        )

-        solution = result[0]
+        solution = result

         params = solution[:n_params] - solution[n_params : 2 * n_params]

@@ -177,9 +183,14 @@ def predict(self, X):
         - If there is no intercept, the method simply computes the dot product between `X` and the coefficients.
         """

-        if self.fit_intercept:
-            X = jnp.column_stack((jnp.ones(X.shape[0]), X))
+        assert (
+            self.coef_ is not None and self.intercept_ is not None
+        ), "Model has not been fitted yet. Please fit the model before predicting."

-            return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_]))
-        else:
-            return jnp.dot(X, self.coef_)
+        n_features = len(self.coef_)
+        assert X.shape[1] == n_features, (
+            f"Input X must have {n_features} features, "
+            f"but got {X.shape[1]} features instead."
+        )

+        return jnp.dot(X, self.coef_) + self.intercept_
diff --git a/sml/linear_model/tests/quantile_test.py b/sml/linear_model/tests/quantile_test.py
index 54daa5ce4..5e693ede3 100644
--- a/sml/linear_model/tests/quantile_test.py
+++ b/sml/linear_model/tests/quantile_test.py
@@ -15,8 +15,6 @@
 import unittest

 import jax.numpy as jnp
-
-# import numpy as np
 from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

 import spu.spu_pb2 as spu_pb2  # type: ignore
@@ -71,20 +69,22 @@ def generate_data():
             quantile=0.2, alpha=0.1, fit_intercept=True, solver='revised simplex'
         )
         quantile_sklearn_fit = quantile_sklearn.fit(X, y)
-        acc_sklearn = jnp.mean(y <= quantile_sklearn_fit.predict(X))
-        print(f"Accuracy in SKlearn: {acc_sklearn:.2f}")
+        y_pred_plain = quantile_sklearn_fit.predict(X)
+        rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
+        print(f"RMSE in SKlearn: {rmse_plain:.2f}")
         print(quantile_sklearn_fit.coef_)
         print(quantile_sklearn_fit.intercept_)

         # run
+        # A larger max_iter can give higher accuracy, but it will take more time to run
         proc = proc_wrapper(
-            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
+            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
         )
         result, coef, intercept = spsim.sim_jax(sim, proc)(X, y)
-        acc_custom = jnp.mean(y <= result)
+        rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))

-        # print accuracy
-        print(f"Accuracy in SPU: {acc_custom:.2f}")
+        # print RMSE
+        print(f"RMSE in SPU: {rmse_encrypted:.2f}")
         print(coef)
         print(intercept)

diff --git a/sml/linear_model/utils/_linprog_simplex.py b/sml/linear_model/utils/_linprog_simplex.py
index a08c2bd17..a97a460f3 100644
--- a/sml/linear_model/utils/_linprog_simplex.py
+++ b/sml/linear_model/utils/_linprog_simplex.py
@@ -12,40 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import warnings
-from warnings import warn
-
 import jax
 import jax.numpy as jnp
 from jax import jit, lax


-def _pivot_col(T, tol=1e-5, bland=False):
+def _pivot_col(T, tol=1e-5):
     mask = T[-1, :-1] >= -tol

     all_masked = jnp.all(mask)

-    bland_first_col = jnp.argmin(jnp.where(mask, jnp.inf, jnp.arange(T.shape[1] - 1)))
     # select the column with the minimum value in the objective row
     ma = jnp.where(mask, jnp.inf, T[-1, :-1])
     min_col = jnp.argmin(ma)

-    result = jnp.where(bland, bland_first_col, min_col)
-
     valid = ~all_masked
-    result = jnp.where(all_masked, 0, result)
+    result = jnp.where(all_masked, 0, min_col)

     return valid, result


-def _pivot_row(T, basis, pivcol, phase, tol=1e-5, bland=False):
+def _pivot_row(T, basis, pivcol, phase, tol=1e-5, max_val=1e10):

     def true_mask_func(T, pivcol):
-        mask = T[:-2, pivcol] <= tol
-        ma = jnp.where(mask, jnp.inf, T[:-2, pivcol])
-        mb = jnp.where(mask, jnp.inf, T[:-2, -1])
+        if phase == 1:
+            k = 2
+        else:
+            k = 1
+
+        mask = T[:-k, pivcol] <= tol
+        ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
+        mb = jnp.where(mask, jnp.inf, T[:-k, -1])

-        q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)
+        q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

         # select the row with the minimum ratio
         min_rows = jnp.nanargmin(q)
@@ -53,11 +52,16 @@ def true_mask_func(T, pivcol):
         return min_rows, all_masked

     def false_mask_func(T, pivcol):
-        mask = T[:-1, pivcol] <= tol
-        ma = jnp.where(mask, jnp.inf, T[:-1, pivcol])
-        mb = jnp.where(mask, jnp.inf, T[:-1, -1])
+        if phase == 1:
+            k = 2
+        else:
+            k = 1

-        q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)
+        mask = T[:-k, pivcol] <= tol
+        ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
+        mb = jnp.where(mask, jnp.inf, T[:-k, -1])
+
+        q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

         # select the row with the minimum ratio
         min_rows = jnp.nanargmin(q)
@@ -69,20 +73,14 @@ def false_mask_func(T, pivcol):
     min_rows = jnp.where(phase == 1, true_min_rows, false_min_rows)
     all_masked = jnp.where(phase == 1, true_all_masked, false_all_masked)

-    # check whether the mask array is fully masked
-    has_valid_row = min_rows.size > 0
     row = min_rows

     # handle the case where all rows are masked
     row = jnp.where(all_masked, 0, row)

-    # handle the case where no row satisfies the condition
-    row = jnp.where(has_valid_row, row, 0)
+    return ~all_masked, row

-    return ~all_masked & has_valid_row, row

-def _apply_pivot(T, basis, pivrow, pivcol, tol=1e-5):
+def _apply_pivot(T, basis, pivrow, pivcol):
     pivrow = jnp.int32(pivrow)
     pivcol = jnp.int32(pivcol)
@@ -110,57 +108,45 @@ def _solve_simplex(
     T,
     n,
     basis,
-    maxiter=300,
+    maxiter=100,
     tol=1e-5,
     phase=2,
-    bland=False,
 ):
-    status = 0
     complete = False

     num = 0
     pivcol = 0
     pivrow = 0
     while num < maxiter:
-        pivcol_found, pivcol = _pivot_col(T, tol, bland)
+        pivcol_found, pivcol = _pivot_col(T, tol)

-        def cal_pivcol_found_True(
-            T, basis, pivcol, phase, tol, bland, status, complete
-        ):
-            pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol, bland)
+        def cal_pivcol_found_True(T, basis, pivcol, phase, tol, complete):
+            pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol)

             pivrow_isnot_found = pivrow_found == False
-            status = jnp.where(pivrow_isnot_found, 1, status)
             complete = jnp.where(pivrow_isnot_found, True, complete)

-            return pivrow, status, complete
-
-        pivcol_isnot_found = pivcol_found == False
-        pivcol = jnp.where(pivcol_isnot_found, 0, pivcol)
-        pivrow = jnp.where(pivcol_isnot_found, 0, pivrow)
-        status = jnp.where(pivcol_isnot_found, 0, status)
-        complete = jnp.where(pivcol_isnot_found, True, complete)
+            return pivrow, complete

         pivcol_is_found = pivcol_found == True
-        pivrow_True, status_True, complete_True = cal_pivcol_found_True(
-            T, basis, pivcol, phase, tol, bland, status, complete
-        )
+        pivrow_True, complete_True = cal_pivcol_found_True(
+            T, basis, pivcol, phase, tol, complete
+        )

-        pivrow = jnp.where(pivcol_is_found, pivrow_True, pivrow)
-        status = jnp.where(pivcol_is_found, status_True, status)
+        pivrow = jnp.where(pivcol_is_found, pivrow_True, 0)

         complete = jnp.where(pivcol_is_found, complete_True, complete)

         complete_is_False = complete == False
-        apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol, tol)
+        apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol)
         T = jnp.where(complete_is_False, apply_T, T)
         basis = jnp.where(complete_is_False, apply_basis, basis)
         num = num + 1

-    return T, basis, status
+    return T, basis


-def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, bland=False):
-    status = 0
+def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, max_val=1e10):
     n, m = A.shape

     # All constraints must have b >= 0.
@@ -178,21 +164,16 @@ def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, bland=False):
     T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective))

     # phase 1
-    T, basis, status = _solve_simplex(
-        T, n, basis, maxiter=maxiter, tol=tol, phase=1, bland=bland
-    )
-
-    status = jnp.where(jnp.abs(T[-1, -1]) < tol, status, 1)
+    T, basis = _solve_simplex(T, n, basis, maxiter=maxiter, tol=tol, phase=1)

     T_new = T[:-1, :]
-    jit_delete = jit(jnp.delete, static_argnames=['assume_unique_indices'])
     T = jnp.delete(T_new, av, 1, assume_unique_indices=True)

     # phase 2
-    T, basis, status = _solve_simplex(T, n, basis, maxiter, tol, 2, bland)
+    T, basis = _solve_simplex(T, n, basis, maxiter, tol, 2)

     solution = jnp.zeros(n + m)
     solution = solution.at[basis[:n]].set(T[:n, -1])
     x = solution[:m]

-    return x, status
+    return x
From d43ad9b4879699a229f1cf90326ce78b732a8e3b Mon Sep 17 00:00:00 2001
From: xbw886 <1042740841@qq.com>
Date: Fri, 25 Oct 2024 01:14:02 +0800
Subject: [PATCH 4/4] fix problem

---
 sml/linear_model/emulations/quantile_emul.py |  2 +-
 sml/linear_model/tests/quantile_test.py      |  2 +-
 sml/linear_model/utils/_linprog_simplex.py   | 64 +++++++-------------
 3 files changed, 24 insertions(+), 44 deletions(-)

diff --git a/sml/linear_model/emulations/quantile_emul.py b/sml/linear_model/emulations/quantile_emul.py
index c7c08d5f6..8f9e178dc 100644
--- a/sml/linear_model/emulations/quantile_emul.py
+++ b/sml/linear_model/emulations/quantile_emul.py
@@ -86,7 +86,7 @@ def generate_data():
         # run
         # A larger max_iter can give higher accuracy, but it will take more time to run
         proc = proc_wrapper(
-            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
+            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=200
         )
         start = time.time()
         result, coef, intercept = emulator.run(proc)(X_spu, y_spu)
diff --git a/sml/linear_model/tests/quantile_test.py b/sml/linear_model/tests/quantile_test.py
index 5e693ede3..4f0c68751 100644
--- a/sml/linear_model/tests/quantile_test.py
+++ b/sml/linear_model/tests/quantile_test.py
@@ -78,7 +78,7 @@ def generate_data():
         # run
         # A larger max_iter can give higher accuracy, but it will take more time to run
         proc = proc_wrapper(
-            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
+            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=200
         )
         result, coef, intercept = spsim.sim_jax(sim, proc)(X, y)
         rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))
diff --git a/sml/linear_model/utils/_linprog_simplex.py b/sml/linear_model/utils/_linprog_simplex.py
index a97a460f3..0bd57a22f 100644
--- a/sml/linear_model/utils/_linprog_simplex.py
+++ b/sml/linear_model/utils/_linprog_simplex.py
@@ -32,46 +32,21 @@ def _pivot_col(T, tol=1e-5):
     return valid, result


-def _pivot_row(T, basis, pivcol, phase, tol=1e-5, max_val=1e10):
+def _pivot_row(T, pivcol, phase, tol=1e-5, max_val=1e10):
+    if phase == 1:
+        k = 2
+    else:
+        k = 1

-    def true_mask_func(T, pivcol):
-        if phase == 1:
-            k = 2
-        else:
-            k = 1
+    mask = T[:-k, pivcol] <= tol
+    ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
+    mb = jnp.where(mask, jnp.inf, T[:-k, -1])

-        mask = T[:-k, pivcol] <= tol
-        ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
-        mb = jnp.where(mask, jnp.inf, T[:-k, -1])
+    q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

-        q = jnp.where(ma >= max_val, jnp.inf, mb / ma)
-
-        # select the row with the minimum ratio
-        min_rows = jnp.nanargmin(q)
-        all_masked = jnp.all(mask)
-        return min_rows, all_masked
-
-    def false_mask_func(T, pivcol):
-        if phase == 1:
-            k = 2
-        else:
-            k = 1
-
-        mask = T[:-k, pivcol] <= tol
-        ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
-        mb = jnp.where(mask, jnp.inf, T[:-k, -1])
-
-        q = jnp.where(ma >= max_val, jnp.inf, mb / ma)
-
-        # select the row with the minimum ratio
-        min_rows = jnp.nanargmin(q)
-        all_masked = jnp.all(mask)
-        return min_rows, all_masked
-
-    true_min_rows, true_all_masked = true_mask_func(T, pivcol)
-    false_min_rows, false_all_masked = false_mask_func(T, pivcol)
-    min_rows = jnp.where(phase == 1, true_min_rows, false_min_rows)
-    all_masked = jnp.where(phase == 1, true_all_masked, false_all_masked)
+    # select the row with the minimum ratio
+    min_rows = jnp.nanargmin(q)
+    all_masked = jnp.all(mask)

     row = min_rows
     # handle the case where all rows are masked
     row = jnp.where(all_masked, 0, row)
@@ -85,6 +60,7 @@ def _solve_simplex(
     basis,
     maxiter=100,
     tol=1e-5,
+    max_val=1e10,
     phase=2,
 ):
     complete = False
@@ -120,8 +96,8 @@ def _solve_simplex(
     while num < maxiter:
         pivcol_found, pivcol = _pivot_col(T, tol)

-        def cal_pivcol_found_True(T, basis, pivcol, phase, tol, complete):
-            pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol)
+        def cal_pivcol_found_True(T, pivcol, phase, tol, complete):
+            pivrow_found, pivrow = _pivot_row(T, pivcol, phase, tol, max_val)

             pivrow_isnot_found = pivrow_found == False
             complete = jnp.where(pivrow_isnot_found, True, complete)
@@ -130,7 +106,7 @@ def cal_pivcol_found_True(T, basis, pivcol, phase, tol, complete):

         pivcol_is_found = pivcol_found == True
         pivrow_True, complete_True = cal_pivcol_found_True(
-            T, basis, pivcol, phase, tol, complete
+            T, pivcol, phase, tol, complete
         )

@@ -164,13 +140,17 @@ def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, max_val=1e10):
     T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective))

     # phase 1
-    T, basis = _solve_simplex(T, n, basis, maxiter=maxiter, tol=tol, phase=1)
+    T, basis = _solve_simplex(
+        T, n, basis, maxiter=maxiter, tol=tol, max_val=max_val, phase=1
+    )

     T_new = T[:-1, :]
     T = jnp.delete(T_new, av, 1, assume_unique_indices=True)

     # phase 2
-    T, basis = _solve_simplex(T, n, basis, maxiter, tol, 2)
+    T, basis = _solve_simplex(
+        T, n, basis, maxiter=maxiter, tol=tol, max_val=max_val, phase=2
+    )

     solution = jnp.zeros(n + m)
     solution = solution.at[basis[:n]].set(T[:n, -1])
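
As a plaintext cross-check of the formulation in quantile.py: the sketch below rebuilds the same objective and equality constraints described in QuantileRegressor.fit and solves them with SciPy's linprog. It assumes NumPy and SciPy are available; scipy.optimize.linprog is not used anywhere in the patch itself, and the variable names simply mirror the comments in quantile.py.

    # Plaintext sanity check of the quantile-regression LP (assumes SciPy).
    import numpy as np
    from scipy.optimize import linprog

    rng = np.random.default_rng(42)
    X = rng.normal(size=(100, 2))
    y = 5 * X[:, 0] + 2 * X[:, 1] + rng.normal(size=100) * 0.1

    quantile, alpha = 0.2, 0.1
    n, p = X.shape
    n_params = p + 1  # intercept + coefficients
    reg = n * alpha  # fit() rescales alpha by sum(sample_weight)

    # x = (s0, s, t0, t, u, v) >= 0; intercept = s0 - t0, coef = s - t,
    # and u - v is the residual, weighted by the pinball loss.
    c = np.concatenate(
        [
            np.full(2 * n_params, reg),
            np.full(n, quantile),
            np.full(n, 1 - quantile),
        ]
    )
    c[0] = c[n_params] = 0  # the intercept is not regularized

    ones = np.ones((n, 1))
    eye = np.eye(n)
    A_eq = np.hstack([ones, X, -ones, -X, eye, -eye])

    res = linprog(c, A_eq=A_eq, b_eq=y, method="highs")  # bounds default to x >= 0
    params = res.x[:n_params] - res.x[n_params : 2 * n_params]
    print("intercept:", params[0])
    print("coef:", params[1:])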