-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[linear-ridge-cholesky]使用 SPU 实现线性回归模型 (#297)
### What problem does this PR solve? I have read the CLA Document and I hereby sign the CLA Issue Number: Fixed #274 使用cholesky分解法实现ridge,在spu镜像中完成emul和test单元测试 ### Possible side effects? - Performance: 在diabetes数据集上,与sklearn的误差在0.01-0.02之间 - Backward compatibility: --------- Signed-off-by: magic-hya <huangya@asiainfo.com>
- Loading branch information
Showing
6 changed files
with
312 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,3 +31,8 @@ py_library( | |
"//sml/utils:fxp_approx", | ||
], | ||
) | ||
|
||
py_library( | ||
name = "ridge", | ||
srcs = ["ridge.py"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import jax.numpy as jnp | ||
from sklearn.linear_model import Ridge as skRidge | ||
|
||
import examples.python.utils.dataset_utils as dsutil | ||
import sml.utils.emulation as emulation | ||
from sml.linear_model.ridge import Ridge | ||
|
||
|
||
# TODO: design the enumation framework, just like py.unittest | ||
# all emulation action should begin with `emul_` (for reflection) | ||
def emul_Ridge(mode: emulation.Mode.MULTIPROCESS): | ||
def proc(x1, x2, y): | ||
model = Ridge(alpha=1.0, solver="cholesky") | ||
|
||
x = jnp.concatenate((x1, x2), axis=1) | ||
y = y.reshape((y.shape[0], 1)) | ||
|
||
return model.fit(x, y).predict(x) | ||
|
||
def load_data(): | ||
dataset_config = { | ||
"use_mock_data": False, | ||
"problem_type": "regression", | ||
"builtin_dataset_name": "diabetes", | ||
"left_slice_feature_ratio": 0.5, | ||
} | ||
|
||
x1, x2, y = dsutil.load_dataset_by_config(dataset_config) | ||
|
||
return x1, x2, y | ||
|
||
try: | ||
# bandwidth and latency only work for docker mode | ||
emulator = emulation.Emulator( | ||
emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20 | ||
) | ||
emulator.up() | ||
|
||
# load mock data | ||
x1, x2, y = load_data() | ||
|
||
# sklearn test | ||
x = jnp.concatenate((x1, x2), axis=1) | ||
sklearn_result = ( | ||
skRidge(alpha=1, solver='cholesky', fit_intercept=True).fit(x, y).predict(x) | ||
) | ||
print("[sklearn_result]---------------------------------------------") | ||
print(sklearn_result[:10]) | ||
|
||
# mark these data to be protected in SPU | ||
x1, x2, y = emulator.seal(x1, x2, y) | ||
|
||
# run | ||
result = emulator.run(proc)(x1, x2, y) | ||
print("[emul_result]------------------------------------------------") | ||
print(result[:10]) | ||
|
||
# absolute_error | ||
print("[absolute_error]---------------------------------------------") | ||
print(jnp.round(jnp.abs(result - sklearn_result)[:20], 5)) | ||
finally: | ||
emulator.down() | ||
|
||
|
||
if __name__ == "__main__": | ||
emul_Ridge(emulation.Mode.MULTIPROCESS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from enum import Enum | ||
|
||
import jax.numpy as jnp | ||
import jax.scipy as jsci | ||
|
||
|
||
class Solver(Enum): | ||
SVD = 'svd' # not supported | ||
CHOLESKY = 'cholesky' | ||
|
||
|
||
class Ridge: | ||
"""Linear least squares with l2 regularization. | ||
Minimizes the objective function:: | ||
||y - Xw||^2_2 + alpha * ||w||^2_2 | ||
This model solves a regression model where the loss function is | ||
the linear least squares function and regularization is given by | ||
the l2-norm. Also known as Ridge Regression or Tikhonov regularization. | ||
Parameters | ||
---------- | ||
alpha : {float}, default=1.0 | ||
Constant that multiplies the L2 term, controlling regularization | ||
strength. `alpha` must be a non-negative float i.e. in `[0, inf)`. | ||
fit_bias : bool, default=True | ||
Whether to fit the bias for this model. If set | ||
to false, no bias will be used in calculations | ||
solver : {'svd', 'cholesky'}, default='cholesky' | ||
Solver to use in the computational routines: | ||
- 'svd' uses a Singular Value Decomposition of X to compute the Ridge | ||
coefficients. | ||
- 'cholesky' uses the standard jax.scipy.linalg.solve function to | ||
obtain a closed-form solution via a Cholesky decomposition of | ||
dot(X.T, X) | ||
""" | ||
|
||
def __init__(self, alpha=1.0, fit_bias=True, solver="cholesky") -> None: | ||
self.alpha = alpha | ||
self.solver = solver | ||
self.fit_bias = fit_bias | ||
|
||
def fit(self, x, y): | ||
"""Fit Ridge regression model. | ||
Parameters | ||
---------- | ||
X : {array-like}, shape (n_samples, n_features) | ||
Training data. | ||
y : ndarray of shape (n_samples,) | ||
Target values. | ||
Returns | ||
------- | ||
self : object | ||
Returns an instance of self. | ||
""" | ||
if y.ndim == 1: | ||
y = y.reshape(-1, 1) | ||
alpha = float(self.alpha) | ||
|
||
x, y, x_offset, y_offset = self.preprocess_data(x, y) | ||
|
||
if self.solver == Solver.CHOLESKY.value: | ||
self.coef = _solve_cholesky(x, y, alpha) | ||
self.coef = self.coef.ravel() | ||
|
||
self.set_bias(x_offset, y_offset) | ||
|
||
return self | ||
|
||
def predict(self, x): | ||
""" | ||
Predict using the linear model. | ||
Parameters | ||
---------- | ||
X : {array-like}, shape (n_samples, n_features) | ||
Input data for prediction. | ||
Returns | ||
------- | ||
C : array, shape (n_samples,) | ||
Returns predicted values. | ||
""" | ||
a = x | ||
b = self.coef.T | ||
ret = jnp.dot(a, b) + self.bias | ||
return ret | ||
|
||
def preprocess_data(self, x, y): | ||
# Center and scale data. | ||
if self.fit_bias: | ||
x_offset = jnp.average(x, axis=0) | ||
x -= x_offset | ||
y_offset = jnp.average(y, axis=0) | ||
y -= y_offset | ||
else: | ||
x_offset = None | ||
y_offset = None | ||
return x, y, x_offset, y_offset | ||
|
||
def set_bias(self, x_offset, y_offset): | ||
if self.fit_bias: | ||
self.bias = y_offset - jnp.dot(x_offset, self.coef.T) | ||
else: | ||
self.bias = 0.0 | ||
|
||
|
||
def _solve_cholesky(x, y, alpha): | ||
# w = inv(X^t X + alpha*Id) * X.T y | ||
n_features = x.shape[1] | ||
|
||
A = jnp.dot(x.T, x) | ||
Xy = jnp.dot(x.T, y) | ||
|
||
A += jnp.diag(jnp.ones(n_features) * alpha) | ||
|
||
coefs = jsci.linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T | ||
return coefs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright 2023 Ant Group Co., Ltd. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import unittest | ||
|
||
import jax.numpy as jnp | ||
from sklearn.linear_model import Ridge as skRidge | ||
|
||
import examples.python.utils.dataset_utils as dsutil | ||
import spu.spu_pb2 as spu_pb2 # type: ignore | ||
import spu.utils.simulation as spsim | ||
from sml.linear_model.ridge import Ridge | ||
|
||
|
||
class UnitTests(unittest.TestCase): | ||
def test_ridge(self): | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 | ||
) | ||
|
||
def proc(x1, x2, y): | ||
model = Ridge(alpha=1.0, solver="cholesky") | ||
|
||
x = jnp.concatenate((x1, x2), axis=1) | ||
y = y.reshape((y.shape[0], 1)) | ||
|
||
result = model.fit(x, y).predict(x) | ||
return result | ||
|
||
dataset_config = { | ||
"use_mock_data": False, | ||
"problem_type": "regression", | ||
"builtin_dataset_name": "diabetes", | ||
"left_slice_feature_ratio": 0.5, | ||
} | ||
|
||
x1, x2, y = dsutil.load_dataset_by_config(dataset_config) | ||
|
||
# sklearn test | ||
x = jnp.concatenate((x1, x2), axis=1) | ||
sklearn_result = ( | ||
skRidge(alpha=1, solver='cholesky', fit_intercept=True).fit(x, y).predict(x) | ||
) | ||
print("[sklearn_result]---------------------------------------------") | ||
print(sklearn_result[:10]) | ||
|
||
result = spsim.sim_jax(sim, proc)(x1, x2, y) | ||
print("[spsim_result]-----------------------------------------------") | ||
print(result[:10]) | ||
|
||
# absolute_error | ||
print("[absolute_error]---------------------------------------------") | ||
print(jnp.round(jnp.abs(result - sklearn_result)[:20], 5)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |