Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OSCP]利用SPU实现分位数回归算法 #865

Merged
merged 5 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion libspu/dialect/pphlo/IR/type_inference.cc
xbw886 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,13 @@ LogicalResult inferDynamicUpdateSliceOp(
}

// dynamic_update_slice_c1
TypeTools tools(operand.getContext());
auto common_vis =
tools.computeCommonVisibility({tools.getTypeVisibility(operandType),
tools.getTypeVisibility(updateType)});
inferredReturnTypes.emplace_back(RankedTensorType::get(
operandType.getShape(), operandType.getElementType()));
operandType.getShape(),
tools.getType(operandType.getElementType(), common_vis)));
return success();
}

Expand Down
8 changes: 8 additions & 0 deletions sml/linear_model/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,11 @@ py_binary(
"//sml/linear_model/utils:solver",
],
)

py_library(
name = "quantile",
srcs = ["quantile.py"],
deps = [
"//sml/linear_model/utils:_linprog_simplex",
],
)
9 changes: 9 additions & 0 deletions sml/linear_model/emulations/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ py_binary(
"//sml/utils:emulation",
],
)

py_binary(
name = "quantile_emul",
srcs = ["quantile_emul.py"],
deps = [
"//sml/linear_model:quantile",
"//sml/utils:emulation",
],
)
102 changes: 102 additions & 0 deletions sml/linear_model/emulations/quantile_emul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import jax.numpy as jnp
from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

import sml.utils.emulation as emulation
from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor

CONFIG_FILE = emulation.CLUSTER_ABY3_3PC


def emul_quantile(mode=emulation.Mode.MULTIPROCESS):
def proc_wrapper(
quantile,
alpha,
fit_intercept,
lr,
max_iter,
):
quantile_custom = SmlQuantileRegressor(
quantile=quantile,
alpha=alpha,
fit_intercept=fit_intercept,
lr=lr,
max_iter=max_iter,
)

def proc(X, y):
quantile_custom_fit = quantile_custom.fit(X, y)
result = quantile_custom_fit.predict(X)
return result

return proc

def generate_data():
from jax import random

# 设置随机种子
key = random.PRNGKey(42)
# 生成 X 数据
key, subkey = random.split(key)
X = random.normal(subkey, (100, 2))
# 生成 y 数据
y = (
5 * X[:, 0] + 2 * X[:, 1] + random.normal(key, (100,)) * 0.1
) # 高相关性,带有小噪声
return X, y

try:
# bandwidth and latency only work for docker mode
emulator = emulation.Emulator(CONFIG_FILE, mode, bandwidth=300, latency=20)
emulator.up()

# load mock data
X, y = generate_data()

# compare with sklearn
quantile_sklearn = SklearnQuantileRegressor(
quantile=0.3, alpha=0.1, fit_intercept=True, solver='highs'
)
start = time.time()
quantile_sklearn_fit = quantile_sklearn.fit(X, y)
score_plain = jnp.mean(y <= quantile_sklearn_fit.predict(X))
xbw886 marked this conversation as resolved.
Show resolved Hide resolved
end = time.time()
print(f"Running time in SKlearn: {end - start:.2f}s")

# mark these data to be protected in SPU
X_spu, y_spu = emulator.seal(X, y)

# run
proc = proc_wrapper(
quantile=0.3, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
xbw886 marked this conversation as resolved.
Show resolved Hide resolved
)
start = time.time()
result = emulator.run(proc)(X_spu, y_spu)
end = time.time()
score_encrpted = jnp.mean(y <= result)
xbw886 marked this conversation as resolved.
Show resolved Hide resolved
print(f"Running time in SPU: {end - start:.2f}s")

# print acc
print(f"Accuracy in SKlearn: {score_plain:.2f}")
print(f"Accuracy in SPU: {score_encrpted:.2f}")

finally:
emulator.down()


if __name__ == "__main__":
emul_quantile(emulation.Mode.MULTIPROCESS)
185 changes: 185 additions & 0 deletions sml/linear_model/quantile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright 2024 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import warnings
xbw886 marked this conversation as resolved.
Show resolved Hide resolved
from warnings import warn

import jax
import jax.numpy as jnp
import pandas as pd
from jax import grad

from sml.linear_model.utils._linprog_simplex import _linprog_simplex


class QuantileRegressor:
"""
Initialize the quantile regression model.
Parameters
----------
quantile : float, default=0.5
The quantile to be predicted. Must be between 0 and 1.
A quantile of 0.5 corresponds to the median (50th percentile).
alpha : float, default=1.0
Regularization strength; must be a positive float.
Larger values specify stronger regularization, reducing model complexity.
fit_intercept : bool, default=True
Whether to calculate the intercept for the model.
If False, no intercept will be used in calculations, meaning the model will
assume that the data is already centered.
lr : float, default=0.01
Learning rate for the optimization process. This controls the size of
the steps taken in each iteration towards minimizing the objective function.
max_iter : int, default=1000
The maximum number of iterations for the optimization algorithm.
This controls how long the model will continue to update the weights
before stopping.
Attributes
----------
coef_ : array-like of shape (n_features,)
The coefficients (weights) assigned to the input features. These will be
learned during model fitting.
intercept_ : float
The intercept (bias) term. If `fit_intercept=True`, this will be
learned during model fitting.
"""

def __init__(
self, quantile=0.5, alpha=1.0, fit_intercept=True, lr=0.01, max_iter=1000
):
self.quantile = quantile
self.alpha = alpha
self.fit_intercept = fit_intercept
self.lr = lr
self.max_iter = max_iter

self.coef_ = None
self.intercept_ = None

def fit(self, X, y, sample_weight=None):
"""
Fit the quantile regression model using linear programming.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
y : array-like of shape (n_samples,)
Target values.
sample_weight : array-like of shape (n_samples,), optional
Individual weights for each sample. If not provided, all samples
are assumed to have equal weight.
Returns
-------
self : object
Returns an instance of self.
Steps:
1. Determine the number of parameters (`n_params`), accounting for the intercept if needed.
2. Define the objective function `c`, incorporating both the L1 regularization and the pinball loss.
3. Set up the equality constraint matrix `A_eq` and vector `b_eq` based on the input data `X` and `y`.
4. Solve the linear programming problem using `_linprog_simplex`.
5. Extract the model parameters (intercept and coefficients) from the solution.
"""
n_samples, n_features = X.shape
n_params = n_features

# sample_weight = jnp.ones((n_samples,))
if sample_weight is None:
sample_weight = jnp.ones((n_samples,))

if self.fit_intercept:
n_params += 1

alpha = jnp.sum(sample_weight) * self.alpha

# After rescaling alpha, the minimization problem is
# min sum(pinball loss) + alpha * L1
# Use linear programming formulation of quantile regression
# min_x c x
# A_eq x = b_eq
# 0 <= x
# x = (s0, s, t0, t, u, v) = slack variables >= 0
# intercept = s0 - t0
# coef = s - t
# c = (0, alpha * 1_p, 0, alpha * 1_p, quantile * 1_n, (1-quantile) * 1_n)
# residual = y - X@coef - intercept = u - v
# A_eq = (1_n, X, -1_n, -X, diag(1_n), -diag(1_n))
# b_eq = y
# p = n_features
# n = n_samples
# 1_n = vector of length n with entries equal one
# see https://stats.stackexchange.com/questions/384909/
c = jnp.concatenate(
[
jnp.full(2 * n_params, fill_value=alpha),
sample_weight * self.quantile,
sample_weight * (1 - self.quantile),
]
)

if self.fit_intercept:
c = c.at[0].set(0)
c = c.at[n_params].set(0)

eye = jnp.eye(n_samples)
if self.fit_intercept:
ones = jnp.ones((n_samples, 1))
A = jnp.concatenate([ones, X, -ones, -X, eye, -eye], axis=1)
else:
A = jnp.concatenate([X, -X, eye, -eye], axis=1)

b = y

result = _linprog_simplex(c, A, b, maxiter=self.max_iter, tol=1e-3)

solution = result[0]

params = solution[:n_params] - solution[n_params : 2 * n_params]

if self.fit_intercept:
self.coef_ = params[1:]
self.intercept_ = params[0]
else:
self.coef_ = params
self.intercept_ = 0.0
return self

def predict(self, X):
"""
Predict target values using the fitted quantile regression model.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Input data for which predictions are to be made.
Returns
-------
y_pred : array-like of shape (n_samples,)
Predicted target values.
Notes
-----
The predict method computes the predicted target values using the model's
learned coefficients and intercept (if fit_intercept=True).
- If the model includes an intercept, a column of ones is added to the input data `X` to account
for the intercept in the linear combination.
- The method then computes the dot product between the modified `X` and the stacked vector of
intercept and coefficients.
- If there is no intercept, the method simply computes the dot product between `X` and the coefficients.
"""

xbw886 marked this conversation as resolved.
Show resolved Hide resolved
if self.fit_intercept:
xbw886 marked this conversation as resolved.
Show resolved Hide resolved
X = jnp.column_stack((jnp.ones(X.shape[0]), X))

return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_]))
else:
return jnp.dot(X, self.coef_)
10 changes: 10 additions & 0 deletions sml/linear_model/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,13 @@ py_test(
"//spu/utils:simulation",
],
)

py_test(
name = "quantile_test",
srcs = ["quantile_test.py"],
deps = [
"//sml/linear_model:quantile",
"//spu:init",
"//spu/utils:simulation",
],
)
Loading
Loading