From 977e93244fb8fb0d6c24750d8291351d23ebc8d4 Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:17:21 +0800 Subject: [PATCH] Resolve format issue (#352) # Pull Request ## What problem does this PR solve? Issue Number: Fixed # ## Possible side effects? - Performance: - Backward compatibility: --- sml/glm/glm.py | 90 ++++++++++++++++++++------------ sml/glm/glm_emul.py | 15 ++++-- sml/glm/glm_test.py | 109 ++++++++++++++++++++++---------------- sml/glm/utils/func.py | 10 ++-- sml/glm/utils/link.py | 1 + sml/glm/utils/loss.py | 20 +++---- sml/glm/utils/solver.py | 113 ++++++++++++++++++++++++++-------------- 7 files changed, 225 insertions(+), 133 deletions(-) diff --git a/sml/glm/glm.py b/sml/glm/glm.py index 0d76310e..c356281f 100644 --- a/sml/glm/glm.py +++ b/sml/glm/glm.py @@ -6,20 +6,23 @@ from utils.link import * import warnings import os + DEBUG = 0 + # Define the _GeneralizedLinearRegressor class using JAX class _GeneralizedLinearRegressor: - def __init__(self, - fit_intercept=True, # Whether to fit the intercept term, default is True - alpha=0, # L2 regularization strength, default is 0 (no regularization) - solver="newton-cholesky", # Optimization algorithm, default is Newton-Cholesky - max_iter=20, # Maximum number of iterations, default is 20 - warm_start=False, # Whether to use warm start, default is False - n_threads=None, # Deprecated parameter (no longer used) - tol=None, # Deprecated parameter (no longer used) - verbose=0 # Level of verbosity, default is 0 (no output) - ): + def __init__( + self, + fit_intercept=True, # Whether to fit the intercept term, default is True + alpha=0, # L2 regularization strength, default is 0 (no regularization) + solver="newton-cholesky", # Optimization algorithm, default is Newton-Cholesky + max_iter=20, # Maximum number of iterations, default is 20 + warm_start=False, # Whether to use warm start, default is False + n_threads=None, # Deprecated parameter (no longer used) + tol=None, # Deprecated parameter (no longer used) + verbose=0, # Level of verbosity, default is 0 (no output) + ): """ Initialize the generalized linear regression model. @@ -51,11 +54,19 @@ def __init__(self, self.warm_start = warm_start self.verbose = verbose if n_threads: - warnings.warn("SPU does not need n_threads.", category=DeprecationWarning, stacklevel=2) + warnings.warn( + "SPU does not need n_threads.", + category=DeprecationWarning, + stacklevel=2, + ) if warm_start: warnings.warn("Using minibatch in the second optimizer may cause problems.") if tol: - warnings.warn("SPU does not support early stop.", category=DeprecationWarning, stacklevel=2) + warnings.warn( + "SPU does not support early stop.", + category=DeprecationWarning, + stacklevel=2, + ) def fit(self, X, y, sample_weight=None): if sample_weight is None: @@ -70,7 +81,10 @@ def fit(self, X, y, sample_weight=None): if not self.warm_start or not hasattr(self, "coef_"): self.coef_ = None if self.solver == "lbfgs": - warnings.warn("LBFGS algorithm cannot be accurately implemented on SPU platform, only approximate implementation is available.", UserWarning) + warnings.warn( + "LBFGS algorithm cannot be accurately implemented on SPU platform, only approximate implementation is available.", + UserWarning, + ) self._fit_lbfgs(X, y) elif self.solver == "newton-cholesky": self._fit_newton_cholesky(X, y) @@ -85,22 +99,26 @@ def _get_link(self): def _fit_newton_cholesky(self, X, y): # Use the NewtonCholeskySolver class to implement the Newton-Cholesky optimization algorithm - solver = NewtonCholeskySolver(loss_model=self.loss_model, - l2_reg_strength=self.l2_reg_strength, - max_iter=self.max_iter, - verbose=self.verbose, - link=self.link_model, - coef=self.coef_) + solver = NewtonCholeskySolver( + loss_model=self.loss_model, + l2_reg_strength=self.l2_reg_strength, + max_iter=self.max_iter, + verbose=self.verbose, + link=self.link_model, + coef=self.coef_, + ) self.coef_ = solver.solve(X, y) def _fit_lbfgs(self, X, y): # Use the LBFGSSolver class to implement the Newton-Cholesky optimization algorithm - solver = LBFGSSolver(loss_model=self.loss_model, - max_iter=self.max_iter, - l2_reg_strength=self.l2_reg_strength, - verbose=self.verbose, - link=self.link_model, - coef=self.coef_) + solver = LBFGSSolver( + loss_model=self.loss_model, + max_iter=self.max_iter, + l2_reg_strength=self.l2_reg_strength, + verbose=self.verbose, + link=self.link_model, + coef=self.coef_, + ) self.coef_ = solver.solve(X, y) def predict(self, X): @@ -117,20 +135,22 @@ def score(self, X, y, sample_weight=None): # Calculate the model's predictions prediction = self.predict(X) - squared_error = lambda y_true, prediction: jnp.mean( - (y_true - prediction)**2) + squared_error = lambda y_true, prediction: jnp.mean((y_true - prediction) ** 2) # Calculate the model's deviance deviance = squared_error(y_true=y, prediction=prediction) # Calculate the null deviance - deviance_null = squared_error(y_true=y, - prediction=jnp.tile( - jnp.average(y), y.shape[0])) + deviance_null = squared_error( + y_true=y, prediction=jnp.tile(jnp.average(y), y.shape[0]) + ) # Calculate D^2 d2 = 1 - (deviance) / (deviance_null) return d2 def _check_solver_support(self): - supported_solvers = ["lbfgs", "newton-cholesky"] # List of supported optimization algorithms + supported_solvers = [ + "lbfgs", + "newton-cholesky", + ] # List of supported optimization algorithms if self.solver not in supported_solvers: raise ValueError( f"Invalid solver={self.solver}. Supported solvers are {supported_solvers}." @@ -143,6 +163,7 @@ class PoissonRegressor(_GeneralizedLinearRegressor): This regressor uses the 'log' link function. """ + def _get_loss(self): return HalfPoissonLoss() @@ -158,6 +179,7 @@ def _get_loss(self): def _get_link(self): return LogLink() + # The TweedieRegressor class represents a generalized linear model with Tweedie distribution using JAX. class TweedieRegressor(_GeneralizedLinearRegressor): def __init__( @@ -166,11 +188,13 @@ def __init__( ): super().__init__() # Ensure that the power is within the valid range for the Tweedie distribution - assert(power>=0 and power<=3) + assert power >= 0 and power <= 3 self.power = power def _get_loss(self): - return HalfTweedieLoss(self.power, ) + return HalfTweedieLoss( + self.power, + ) def _get_link(self): if self.power > 0: diff --git a/sml/glm/glm_emul.py b/sml/glm/glm_emul.py index 884dd1fb..05deaf43 100644 --- a/sml/glm/glm_emul.py +++ b/sml/glm/glm_emul.py @@ -7,9 +7,16 @@ sys.path.append('../../') import sml.utils.emulation as emulation import spu.utils.distributed as ppd -from glm import _GeneralizedLinearRegressor, PoissonRegressor, GammaRegressor, TweedieRegressor +from glm import ( + _GeneralizedLinearRegressor, + PoissonRegressor, + GammaRegressor, + TweedieRegressor, +) n_samples, n_features = 100, 5 + + def generate_data(noise=False): """ Generate random data for testing. @@ -39,8 +46,10 @@ def generate_data(noise=False): sample_weight = np.random.rand(n_samples) return X, y, coef, sample_weight + X, y, coef, sample_weight = generate_data() + def emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS, num=10): """ Execute the encrypted SGD classifier in a simulation environment and output the results. @@ -85,9 +94,7 @@ def proc_ncSolver(X, y): # Specify the file paths for cluster and dataset CLUSTER_ABY3_3PC = os.path.join('../../', emulation.CLUSTER_ABY3_3PC) # Create the emulator with specified mode and bandwidth/latency settings - emulator = emulation.Emulator( - CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 - ) + emulator = emulation.Emulator(CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20) emulator.up() # Run the proc_ncSolver function using both plaintext and encrypted data diff --git a/sml/glm/glm_test.py b/sml/glm/glm_test.py index 9caca85f..467487af 100644 --- a/sml/glm/glm_test.py +++ b/sml/glm/glm_test.py @@ -3,15 +3,25 @@ import jax.numpy as jnp import spu.spu_pb2 as spu_pb2 import spu.utils.simulation as spsim -from glm import _GeneralizedLinearRegressor, PoissonRegressor, GammaRegressor, TweedieRegressor +from glm import ( + _GeneralizedLinearRegressor, + PoissonRegressor, + GammaRegressor, + TweedieRegressor, +) import numpy as np import scipy.stats as stats -from sklearn.linear_model._glm import _GeneralizedLinearRegressor as std__GeneralizedLinearRegressor -from sklearn.linear_model._glm import PoissonRegressor as std_PoissonRegressor +from sklearn.linear_model._glm import ( + _GeneralizedLinearRegressor as std__GeneralizedLinearRegressor, +) +from sklearn.linear_model._glm import PoissonRegressor as std_PoissonRegressor from sklearn.linear_model._glm import GammaRegressor as std_GammaRegressor from sklearn.linear_model._glm import TweedieRegressor as std_TweedieRegressor + verbose = 0 n_samples, n_features = 100, 5 + + def generate_data(noise=False): """ Generate random data for testing. @@ -41,49 +51,54 @@ def generate_data(noise=False): sample_weight = np.random.rand(n_samples) return X, y, coef, sample_weight + X, y, coef, sample_weight = generate_data() exp_y = jnp.exp(y) round_exp_y = jnp.round(exp_y) sim = spsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128) -def accuracy_test(model,std_model, y, coef, num=5): - """ - Test the fitting, prediction, and scoring functionality of the generalized linear regression model. - - Parameters: - ---------- - model : object - Generalized linear regression model object. - X : array-like, shape (n_samples, n_features) - Feature data. - y : array-like, shape (n_samples,) - Target data. - coef : array-like, shape (n_features + 1,) - True coefficients, including the intercept term and feature weights. - num : int, optional (default=5) - Number of coefficients to display. - - Returns: - ------- - None - - """ - model.fit(X, y, sample_weight) - std_model.fit(X,y,sample_weight) - norm_diff = jnp.linalg.norm(model.predict(X)[:num]-jnp.array(std_model.predict(X)[:num])) - if verbose: - print('True Coefficients:', coef[:num]) - print("Fitted Coefficients:", model.coef_[:num]) - print("std Fitted Coefficients:", std_model.coef_[:num]) - print("D^2 Score:", model.score(X[:num], y[:num])) - print("X:", X[:num]) - print("Samples:", y[:num]) - print("Predictions:", model.predict(X[:num])) - print("std Predictions:", std_model.predict(X[:num])) - print("norm of predict between ours and std: %f" %norm_diff) - print("_________________________________") - print() - assert norm_diff < 1e-2 + +def accuracy_test(model, std_model, y, coef, num=5): + """ + Test the fitting, prediction, and scoring functionality of the generalized linear regression model. + + Parameters: + ---------- + model : object + Generalized linear regression model object. + X : array-like, shape (n_samples, n_features) + Feature data. + y : array-like, shape (n_samples,) + Target data. + coef : array-like, shape (n_features + 1,) + True coefficients, including the intercept term and feature weights. + num : int, optional (default=5) + Number of coefficients to display. + + Returns: + ------- + None + + """ + model.fit(X, y, sample_weight) + std_model.fit(X, y, sample_weight) + norm_diff = jnp.linalg.norm( + model.predict(X)[:num] - jnp.array(std_model.predict(X)[:num]) + ) + if verbose: + print('True Coefficients:', coef[:num]) + print("Fitted Coefficients:", model.coef_[:num]) + print("std Fitted Coefficients:", std_model.coef_[:num]) + print("D^2 Score:", model.score(X[:num], y[:num])) + print("X:", X[:num]) + print("Samples:", y[:num]) + print("Predictions:", model.predict(X[:num])) + print("std Predictions:", std_model.predict(X[:num])) + print("norm of predict between ours and std: %f" % norm_diff) + print("_________________________________") + print() + assert norm_diff < 1e-2 + def proc_test(proc): """ @@ -111,6 +126,7 @@ def proc_test(proc): # Assert that the difference is within the tolerance assert norm_diff < 1e-4 + def proc_ncSolver(): """ Fit Generalized Linear Regression model using Newton-Cholesky algorithm and return the model coefficients. @@ -125,6 +141,7 @@ def proc_ncSolver(): model.fit(X, y) return model.coef_ + def proc_lbfgsSolver(): """ Fit Generalized Linear Regression model using Newton-Cholesky algorithm and return the model coefficients. @@ -139,6 +156,7 @@ def proc_lbfgsSolver(): model.fit(X, y) return model.coef_ + def proc_Poisson(): """ Fit Generalized Linear Regression model using PoissonRegressor and return the model coefficients. @@ -153,6 +171,7 @@ def proc_Poisson(): model.fit(X, round_exp_y) return model.coef_ + def proc_Gamma(): """ Fit Generalized Linear Regression model using GammaRegressor and return the model coefficients. @@ -167,6 +186,7 @@ def proc_Gamma(): model.fit(X, exp_y) return model.coef_ + def proc_Tweedie(): """ Fit Generalized Linear Regression model using TweedieRegressor and return the model coefficients. @@ -204,10 +224,10 @@ def test_gamma_accuracy(self): accuracy_test(model, std_model, exp_y, coef) print('test_gamma_accuracy: OK') - def test_Tweedie_accuracy(self,power=0): + def test_Tweedie_accuracy(self, power=0): # Test the accuracy of the TweedieRegressor model model = TweedieRegressor(power=power) - std_model = std_TweedieRegressor(alpha=0,power=power) + std_model = std_TweedieRegressor(alpha=0, power=power) accuracy_test(model, std_model, exp_y, coef) print('test_Tweedie_accuracy: OK') @@ -231,11 +251,12 @@ def test_gamma_encrypted(self): proc_test(proc_Gamma) print('test_gamma_encrypted: OK') - def test_Tweedie_encrypted(self,power=0): + def test_Tweedie_encrypted(self, power=0): # Test if the results of the TweedieRegressor model are correct after encryption proc_test(proc_Tweedie) print('test_Tweedie_encrypted: OK') + if __name__ == '__main__': # Run the unit tests unittest.main() diff --git a/sml/glm/utils/func.py b/sml/glm/utils/func.py index 35f2e611..eed00fd5 100644 --- a/sml/glm/utils/func.py +++ b/sml/glm/utils/func.py @@ -1,4 +1,6 @@ import numpy as np + + def _check_sample_weight(X, sample_weight): """Set sample_weight if None, and check for correct dtype""" n_samples = X.shape[0] @@ -7,8 +9,10 @@ def _check_sample_weight(X, sample_weight): else: sample_weight = np.asarray(sample_weight) if n_samples != len(sample_weight): - raise ValueError("n_samples=%d should be == len(sample_weight)=%d" - % (n_samples, len(sample_weight))) + raise ValueError( + "n_samples=%d should be == len(sample_weight)=%d" + % (n_samples, len(sample_weight)) + ) # normalize the weights to sum up to n_samples scale = n_samples / sample_weight.sum() - return (sample_weight * scale).astype(X.dtype) \ No newline at end of file + return (sample_weight * scale).astype(X.dtype) diff --git a/sml/glm/utils/link.py b/sml/glm/utils/link.py index 42adc5ac..83e9ec79 100644 --- a/sml/glm/utils/link.py +++ b/sml/glm/utils/link.py @@ -3,6 +3,7 @@ from dataclasses import dataclass import jax.numpy as jnp + # Define a dataclass to represent an interval @dataclass class Interval: diff --git a/sml/glm/utils/loss.py b/sml/glm/utils/loss.py index 5fc2da77..e7d8a3a0 100644 --- a/sml/glm/utils/loss.py +++ b/sml/glm/utils/loss.py @@ -1,6 +1,7 @@ import jax import jax.numpy as jnp + class BaseLoss: def get_sampleweight(self, sample_weight): """ @@ -28,9 +29,9 @@ def __call__(self, y_true, y_pred): Average half squared loss value. """ - y_t,y_p=y_true,y_pred + y_t, y_p = y_true, y_pred w = self.sample_weight - return jnp.sum(((y_t - y_p) ** 2 / 2) *w) + return jnp.sum(((y_t - y_p) ** 2 / 2) * w) class HalfPoissonLoss(BaseLoss): @@ -51,9 +52,10 @@ def __call__(self, y_true, y_pred): Average half Poisson loss value. """ - y_t,y_p=y_true,y_pred + y_t, y_p = y_true, y_pred w = self.sample_weight - return jnp.sum((y_p - y_t * jnp.log(y_p))*w) + return jnp.sum((y_p - y_t * jnp.log(y_p)) * w) + class HalfGammaLoss(BaseLoss): def __call__(self, y_true, y_pred): @@ -73,9 +75,10 @@ def __call__(self, y_true, y_pred): Average half Gamma loss value. """ - y_t,y_p=y_true,y_pred + y_t, y_p = y_true, y_pred w = self.sample_weight - return jnp.sum(w *(jnp.log(y_p / y_t) + y_t / y_p - 1)) + return jnp.sum(w * (jnp.log(y_p / y_t) + y_t / y_p - 1)) + class HalfTweedieLoss(BaseLoss): def __init__(self, power=0.5): @@ -112,7 +115,6 @@ def __call__(self, y_true, y_pred): """ p = self.power - y_t,y_p=y_true,y_pred + y_t, y_p = y_true, y_pred w = self.sample_weight - return jnp.sum((y_p ** (2 - p) / (2 - p) - y_t * y_p ** (1 - p) / (1 - p))*w) - + return jnp.sum((y_p ** (2 - p) / (2 - p) - y_t * y_p ** (1 - p) / (1 - p)) * w) diff --git a/sml/glm/utils/solver.py b/sml/glm/utils/solver.py index a5f13ca5..f8240aec 100644 --- a/sml/glm/utils/solver.py +++ b/sml/glm/utils/solver.py @@ -6,16 +6,19 @@ DEBUG = 0 + class Solver(ABC): - def __init__(self, - loss_model, - link, - max_iter=100, - l2_reg_strength=1, - n_threads=None, - fit_intercept=True, - verbose=0, - coef=None): + def __init__( + self, + loss_model, + link, + max_iter=100, + l2_reg_strength=1, + n_threads=None, + fit_intercept=True, + verbose=0, + coef=None, + ): self.loss_model = loss_model self.max_iter = max_iter self.n_threads = n_threads @@ -36,12 +39,18 @@ def solve(self, X, y, sample_weight=None): if self.fit_intercept: X = jnp.hstack([jnp.ones((n_samples, 1)), X]) # Add the intercept term if not self.coef: - self.coef = jnp.full((n_features + 1, ), 0.5) # Initialize coef using np.random.rand (uniform distribution between 0 and 1) + self.coef = jnp.full( + (n_features + 1,), 0.5 + ) # Initialize coef using np.random.rand (uniform distribution between 0 and 1) else: if not self.coef: - self.coef = jnp.full((n_features, ), 0.5) # Initialize coef using np.random.rand (uniform distribution between 0 and 1) - self.objective = lambda coef: self.loss_model( - y, self.link.inverse(X @ coef)) + jnp.linalg.norm(coef) * self.l2_reg_strength / 2 + self.coef = jnp.full( + (n_features,), 0.5 + ) # Initialize coef using np.random.rand (uniform distribution between 0 and 1) + self.objective = ( + lambda coef: self.loss_model(y, self.link.inverse(X @ coef)) + + jnp.linalg.norm(coef) * self.l2_reg_strength / 2 + ) self.objective_grad = jit(jax.grad(self.objective)) self.hessian_fn = jit(jax.hessian(self.objective)) return X @@ -53,15 +62,17 @@ def iteration(self): # Define NewtonCholeskySolver class using JAX class NewtonCholeskySolver(Solver): - def __init__(self, - loss_model, - link, - l2_reg_strength=1.0, - max_iter=100, - n_threads=None, - fit_intercept=True, - verbose=0, - coef=None): + def __init__( + self, + loss_model, + link, + l2_reg_strength=1.0, + max_iter=100, + n_threads=None, + fit_intercept=True, + verbose=0, + coef=None, + ): """ Solver for Newton-Cholesky optimization algorithm. @@ -85,8 +96,16 @@ def __init__(self, Initial coefficient values. Default is None. """ - super().__init__(loss_model, link, max_iter, l2_reg_strength, n_threads, fit_intercept, - verbose, coef) + super().__init__( + loss_model, + link, + max_iter, + l2_reg_strength, + n_threads, + fit_intercept, + verbose, + coef, + ) def solve(self, X, y, sample_weight=None): """ @@ -124,15 +143,17 @@ def cho_solve_wrapper(a, b): class LBFGSSolver(Solver): - def __init__(self, - loss_model, - link, - max_iter=100, - l2_reg_strength=1.0, - n_threads=None, - fit_intercept=True, - verbose=0, - coef=None): + def __init__( + self, + loss_model, + link, + max_iter=100, + l2_reg_strength=1.0, + n_threads=None, + fit_intercept=True, + verbose=0, + coef=None, + ): """ Implementation of LBFGS optimization algorithm for generalized linear regression. @@ -165,7 +186,16 @@ def __init__(self, A parameter in BFGS algorithm. """ - super().__init__(loss_model, link, max_iter, l2_reg_strength, n_threads, fit_intercept, verbose, coef) + super().__init__( + loss_model, + link, + max_iter, + l2_reg_strength, + n_threads, + fit_intercept, + verbose, + coef, + ) self.maxcor = 10 self.maxls = 3 self.gamma = 1 @@ -194,7 +224,7 @@ def solve(self, X, y, sample_weight=None): d = len(self.coef) self.s_history = jnp.zeros((self.maxcor, d)) self.y_history = jnp.zeros((self.maxcor, d)) - self.rho_history = jnp.zeros((self.maxcor, )) + self.rho_history = jnp.zeros((self.maxcor,)) f_k, g_k = jax.value_and_grad(self.objective)(self.coef) for self.i in range(self.max_iter): @@ -218,7 +248,7 @@ def _two_loop_recursion(self, g_k): his_size = len(self.rho_history) curr_size = his_size q = -jnp.conj(g_k) - a_his = jnp.zeros((self.maxcor, )) + a_his = jnp.zeros((self.maxcor,)) for j in range(his_size): i = his_size - 1 - j @@ -231,16 +261,19 @@ def _two_loop_recursion(self, g_k): for j in range(his_size): b_i = self.rho_history[j] * (self.y_history[j] @ q) q = q + (a_his[j] - b_i) * self.s_history[j] - return q + return q def _line_search(self, p_k, f_k, g_k): - a_k = 0.96 ** self.i + a_k = 0.96**self.i + # Build a local quadratic model using quasi-Newton method def quadratic_model(a): f_a = a * p_k @ g_k - return jnp.abs(jnp.abs(f_a) - jnp.abs(f_k)) / max(jnp.abs(f_a), jnp.abs(f_k)) + return jnp.abs(jnp.abs(f_a) - jnp.abs(f_k)) / max( + jnp.abs(f_a), jnp.abs(f_k) + ) alpha = 0.9 # alpha = quadratic_model(a_k) - a_k *= alpha ** self.maxls + a_k *= alpha**self.maxls return a_k