[linear-ridge-cholesky] Implement a linear regression model with SPU (#297)
### What problem does this PR solve?

I have read the CLA Document and I hereby sign the CLA
Issue Number: Fixed #274

Implement ridge regression with the Cholesky decomposition method, and complete the `emul` and `test` unit tests in the SPU image.
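
For reference, a minimal sketch of the closed-form solve that the Cholesky path performs (the committed implementation is `_solve_cholesky` in `sml/linear_model/ridge.py` below; the helper name here is illustrative only):

```python
# Sketch only: ridge closed form w = (X^T X + alpha * I)^-1 X^T y.
# assume_a="pos" lets the solver exploit a Cholesky factorization of A.
import jax.numpy as jnp
import jax.scipy as jsci


def ridge_closed_form(x, y, alpha=1.0):
    n_features = x.shape[1]
    A = jnp.dot(x.T, x) + alpha * jnp.eye(n_features)  # SPD for alpha > 0
    b = jnp.dot(x.T, y)
    return jsci.linalg.solve(A, b, assume_a="pos")
```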
### Possible side effects?
- Performance:
On the diabetes dataset, the error compared with sklearn is between 0.01 and 0.02.
- Backward compatibility:

---------

Signed-off-by: magic-hya <huangya@asiainfo.com>
magic-hya authored Aug 14, 2023
1 parent 25ae05e commit 9857c82
Showing 6 changed files with 312 additions and 0 deletions.
5 changes: 5 additions & 0 deletions sml/linear_model/BUILD.bazel
@@ -31,3 +31,8 @@ py_library(
"//sml/utils:fxp_approx",
],
)

py_library(
name = "ridge",
srcs = ["ridge.py"],
)
10 changes: 10 additions & 0 deletions sml/linear_model/emulations/BUILD.bazel
@@ -34,3 +34,13 @@ py_binary(
"//sml/utils:emulation",
],
)

py_binary(
name = "ridge_emul",
srcs = ["ridge_emul.py"],
deps = [
"//examples/python/utils:dataset_utils", # FIXME: remove examples dependency
"//sml/linear_model:ridge",
"//sml/utils:emulation",
],
)
79 changes: 79 additions & 0 deletions sml/linear_model/emulations/ridge_emul.py
@@ -0,0 +1,79 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp
from sklearn.linear_model import Ridge as skRidge

import examples.python.utils.dataset_utils as dsutil
import sml.utils.emulation as emulation
from sml.linear_model.ridge import Ridge


# TODO: design the emulation framework, just like py.unittest
# all emulation actions should begin with `emul_` (for reflection)
def emul_Ridge(mode: emulation.Mode.MULTIPROCESS):
    def proc(x1, x2, y):
        model = Ridge(alpha=1.0, solver="cholesky")

        x = jnp.concatenate((x1, x2), axis=1)
        y = y.reshape((y.shape[0], 1))

        return model.fit(x, y).predict(x)

    def load_data():
        dataset_config = {
            "use_mock_data": False,
            "problem_type": "regression",
            "builtin_dataset_name": "diabetes",
            "left_slice_feature_ratio": 0.5,
        }

        x1, x2, y = dsutil.load_dataset_by_config(dataset_config)

        return x1, x2, y

    try:
        # bandwidth and latency only work for docker mode
        emulator = emulation.Emulator(
            emulation.CLUSTER_ABY3_3PC, mode, bandwidth=300, latency=20
        )
        emulator.up()

        # load the diabetes dataset
        x1, x2, y = load_data()

        # sklearn test
        x = jnp.concatenate((x1, x2), axis=1)
        sklearn_result = (
            skRidge(alpha=1, solver='cholesky', fit_intercept=True).fit(x, y).predict(x)
        )
        print("[sklearn_result]---------------------------------------------")
        print(sklearn_result[:10])

        # mark these data to be protected in SPU
        x1, x2, y = emulator.seal(x1, x2, y)

        # run
        result = emulator.run(proc)(x1, x2, y)
        print("[emul_result]------------------------------------------------")
        print(result[:10])

        # absolute_error
        print("[absolute_error]---------------------------------------------")
        print(jnp.round(jnp.abs(result - sklearn_result)[:20], 5))
    finally:
        emulator.down()


if __name__ == "__main__":
    emul_Ridge(emulation.Mode.MULTIPROCESS)
140 changes: 140 additions & 0 deletions sml/linear_model/ridge.py
@@ -0,0 +1,140 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum

import jax.numpy as jnp
import jax.scipy as jsci


class Solver(Enum):
    SVD = 'svd'  # not supported
    CHOLESKY = 'cholesky'


class Ridge:
    """Linear least squares with l2 regularization.
    Minimizes the objective function::
        ||y - Xw||^2_2 + alpha * ||w||^2_2
    This model solves a regression model where the loss function is
    the linear least squares function and regularization is given by
    the l2-norm. Also known as Ridge Regression or Tikhonov regularization.

    Parameters
    ----------
    alpha : float, default=1.0
        Constant that multiplies the L2 term, controlling regularization
        strength. `alpha` must be a non-negative float, i.e. in `[0, inf)`.
    fit_bias : bool, default=True
        Whether to fit the bias for this model. If set
        to False, no bias will be used in calculations.
    solver : {'svd', 'cholesky'}, default='cholesky'
        Solver to use in the computational routines:
        - 'svd' uses a Singular Value Decomposition of X to compute the Ridge
          coefficients.
        - 'cholesky' uses the standard jax.scipy.linalg.solve function to
          obtain a closed-form solution via a Cholesky decomposition of
          dot(X.T, X).
    """

    def __init__(self, alpha=1.0, fit_bias=True, solver="cholesky") -> None:
        self.alpha = alpha
        self.solver = solver
        self.fit_bias = fit_bias

    def fit(self, x, y):
        """Fit Ridge regression model.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Training data.
        y : ndarray of shape (n_samples,)
            Target values.

        Returns
        -------
        self : object
            Returns an instance of self.
        """
        if y.ndim == 1:
            y = y.reshape(-1, 1)
        alpha = float(self.alpha)

        x, y, x_offset, y_offset = self.preprocess_data(x, y)

        if self.solver == Solver.CHOLESKY.value:
            self.coef = _solve_cholesky(x, y, alpha)
            self.coef = self.coef.ravel()

        self.set_bias(x_offset, y_offset)

        return self

    def predict(self, x):
        """Predict using the linear model.

        Parameters
        ----------
        x : {array-like}, shape (n_samples, n_features)
            Input data for prediction.

        Returns
        -------
        C : array, shape (n_samples,)
            Returns predicted values.
        """
        return jnp.dot(x, self.coef.T) + self.bias

    def preprocess_data(self, x, y):
        # Center the data (no scaling) so the bias can be recovered afterwards.
        if self.fit_bias:
            x_offset = jnp.average(x, axis=0)
            x -= x_offset
            y_offset = jnp.average(y, axis=0)
            y -= y_offset
        else:
            x_offset = None
            y_offset = None
        return x, y, x_offset, y_offset

    def set_bias(self, x_offset, y_offset):
        if self.fit_bias:
            self.bias = y_offset - jnp.dot(x_offset, self.coef.T)
        else:
            self.bias = 0.0


def _solve_cholesky(x, y, alpha):
    # w = inv(X^T X + alpha * I) * X^T y
    n_features = x.shape[1]

    A = jnp.dot(x.T, x)
    Xy = jnp.dot(x.T, y)

    A += jnp.diag(jnp.ones(n_features) * alpha)

    coefs = jsci.linalg.solve(A, Xy, assume_a="pos", overwrite_a=True).T
    return coefs
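
A quick plaintext usage sketch of the new `Ridge` class (toy, illustrative values; the emulation and test files in this commit exercise the same API inside SPU):

```python
import jax.numpy as jnp

from sml.linear_model.ridge import Ridge

# toy data: 8 samples, 3 features (values are arbitrary)
x = jnp.arange(24.0).reshape(8, 3)
y = jnp.arange(8.0).reshape(8, 1)

model = Ridge(alpha=1.0, solver="cholesky")
pred = model.fit(x, y).predict(x)  # bias is fitted since fit_bias=True by default
print(pred.shape)  # (8,)
```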
11 changes: 11 additions & 0 deletions sml/linear_model/tests/BUILD.bazel
@@ -39,3 +39,14 @@ py_test(
"//spu/utils:simulation",
],
)

py_test(
name = "ridge_test",
srcs = ["ridge_test.py"],
deps = [
"//examples/python/utils:dataset_utils", # FIXME: remove examples dependency
"//sml/linear_model:ridge",
"//spu:init",
"//spu/utils:simulation",
],
)
67 changes: 67 additions & 0 deletions sml/linear_model/tests/ridge_test.py
@@ -0,0 +1,67 @@
# Copyright 2023 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import jax.numpy as jnp
from sklearn.linear_model import Ridge as skRidge

import examples.python.utils.dataset_utils as dsutil
import spu.spu_pb2 as spu_pb2 # type: ignore
import spu.utils.simulation as spsim
from sml.linear_model.ridge import Ridge


class UnitTests(unittest.TestCase):
    def test_ridge(self):
        sim = spsim.Simulator.simple(
            3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
        )

        def proc(x1, x2, y):
            model = Ridge(alpha=1.0, solver="cholesky")

            x = jnp.concatenate((x1, x2), axis=1)
            y = y.reshape((y.shape[0], 1))

            result = model.fit(x, y).predict(x)
            return result

        dataset_config = {
            "use_mock_data": False,
            "problem_type": "regression",
            "builtin_dataset_name": "diabetes",
            "left_slice_feature_ratio": 0.5,
        }

        x1, x2, y = dsutil.load_dataset_by_config(dataset_config)

        # sklearn test
        x = jnp.concatenate((x1, x2), axis=1)
        sklearn_result = (
            skRidge(alpha=1, solver='cholesky', fit_intercept=True).fit(x, y).predict(x)
        )
        print("[sklearn_result]---------------------------------------------")
        print(sklearn_result[:10])

        result = spsim.sim_jax(sim, proc)(x1, x2, y)
        print("[spsim_result]-----------------------------------------------")
        print(result[:10])

        # absolute_error
        print("[absolute_error]---------------------------------------------")
        print(jnp.round(jnp.abs(result - sklearn_result)[:20], 5))


if __name__ == "__main__":
    unittest.main()
