Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Latest changes to Flex #7

Merged
merged 12 commits into from
Feb 1, 2024
59 changes: 58 additions & 1 deletion fomo/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
from pymoo.algorithms.base.genetic import GeneticAlgorithm

def get_parent(pop):

if not hasattr(get_parent_WeightedCoinFlip, "_called"):
print("Default flex")
get_parent_WeightedCoinFlip._called = True

fng = pop.get("fng")
fn = pop.get("fn")
Expand Down Expand Up @@ -84,6 +88,10 @@ def get_parent(pop):

def get_parent_noCoinFlip(pop):

if not hasattr(get_parent_WeightedCoinFlip, "_called"):
print("Flex with no coin flip")
get_parent_WeightedCoinFlip._called = True

fng = pop.get("fng")
fng = np.tile(fng, 2)
fn = pop.get("fn")
Expand Down Expand Up @@ -116,6 +124,55 @@ def get_parent_noCoinFlip(pop):
return random.choice(S)


def get_parent_WeightedCoinFlip(pop):

if not hasattr(get_parent_WeightedCoinFlip, "_called"):
print("Flex with weighted coin flip")
get_parent_WeightedCoinFlip._called = True

samples_fnr = pop.get("samples_fnr")
fng = pop.get("fng")
fn = pop.get("fn")
gp_lens = pop.get('gp_lens')
G = np.arange(fng.shape[1])
S = np.arange(len(pop))
loss = []
weight = random.random()

while (len(G) > 0 and len(S) > 1):

g = random.choice(G)
loss = []

if (random.random() > weight):
#look at fairness
loss = fng[:, g]
G = G[np.where(G != g)]
else:
#look at accuracy
num_rows, num_cols = np.shape(samples_fnr)
indices = np.random.choice(num_cols, size = int(gp_lens[0, g]), replace = False)
fnr_sum = np.sum(samples_fnr[:, indices], axis=1)
pos_count = np.sum(samples_fnr[:, indices].astype(bool), axis=1)
for i in range (len(pos_count)):
if pos_count[i]:
loss.append(fnr_sum[i]/pos_count[i])
else:
loss.append(0)


L = min(loss)
epsilon = np.median(np.abs(loss - np.median(loss)))
survivors = np.where(loss <= L + epsilon)
S = S[survivors]
fng = fng[survivors]
fn = fn[survivors]
samples_fnr = samples_fnr[survivors]
gp_lens = gp_lens[survivors]

S = S[:, None].astype(int, copy=False)
return random.choice(S)

class FLEX(Selection):

def __init__(self,
Expand All @@ -138,7 +195,7 @@ def _do(self, _, pop, n_select, n_parents=1, flag = 0, **kwargs):

for i in range(n_select * n_parents):
#get pop_size parents
p = get_parent_noCoinFlip(pop)
p = get_parent(pop)
parents.append(p)

return np.reshape(parents, (n_select, n_parents))
Expand Down
2 changes: 1 addition & 1 deletion fomo/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def _init_metrics(self):
self.accuracy_metrics_ = self.accuracy_metrics
self.fairness_metrics_ = self.fairness_metrics
if self.accuracy_metrics is None:
self.accuracy_metrics_ = [make_scorer(roc_auc_score, greater_is_better=False)]
self.accuracy_metrics_ = [make_scorer(roc_auc_score, greater_is_better=False, needs_proba=True)]
if self.fairness_metrics is None:
self.fairness_metrics_ = [metrics.multicalibration_loss]

Expand Down
49 changes: 32 additions & 17 deletions fomo/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,11 @@ def subgroup_loss(y_true, y_pred, X_protected, metric, grouping = 'intersectiona
mask = X_protected[col] == val
indices = X_protected[mask].index
categories[category_key] = indices
# print('#intersectional groups: ', len(categories))
# singles = 0
# gp_lens = [len(lst) for lst in categories.values()]
# singles = gp_lens.count(1)
# avg_len = sum(gp_lens) / len(gp_lens) if gp_lens else 0

if isinstance(metric,str):
loss_fn = FPR if metric=='FPR' else FNR
Expand Down Expand Up @@ -357,12 +362,14 @@ def subgroup_MSE_scorer(estimator, X, y_true, **kwargs):
return subgroup_scorer( estimator, X, y_true, mean_squared_error, **kwargs)


def loss(estimator, X, y_true, metric, flag = 1, **kwargs):
def flex_loss(estimator, X, y_true, metric, **kwargs):
"""
returns
----------
fn: overall loss of all samples
fng: loss over group for every group in the training data
samples_fnr: False negative rate of every sample in the training data
gp_lens: length of each protected group

Parameters
----------
Expand All @@ -381,7 +388,9 @@ def loss(estimator, X, y_true, metric, flag = 1, **kwargs):
groups = kwargs['groups']
X_protected = X[groups]
categories = {}
group_losses = []
fng = []
samples_fnr = []
gp_lens = []

y_pred = estimator.predict_proba(X)[:,1]
y_pred = pd.Series(y_pred, index=X_protected.index)
Expand All @@ -393,30 +402,36 @@ def loss(estimator, X, y_true, metric, flag = 1, **kwargs):
else:
raise ValueError(f'metric={metric} must be "FPR", "FNR", or a callable')


if (flag == 1): #marginal grouping
categories = {}
for col in X_protected.columns:
unique_values = X_protected[col].unique()
for val in unique_values:
category_key = f'{col}_{val}'
mask = X_protected[col] == val
indices = X_protected[mask].index
categories[category_key] = indices
else: #intersectional grouping (flag is not 0 for now according to paper)
categories = X_protected.groupby(groups).groups
categories = {}
for col in X_protected.columns:
unique_values = X_protected[col].unique()
for val in unique_values:
category_key = f'{col}_{val}'
mask = X_protected[col] == val
indices = X_protected[mask].index
categories[category_key] = indices

for c, idx in categories.items():

category_loss = loss_fn(
y_true.loc[idx].values,
y_pred.loc[idx].values
)
group_losses.append(category_loss)
fng.append(category_loss)
gp_lens.append(len(y_true.loc[idx].values))

# print('#marginal groups: ', len(categories))
# singles = 0
# singles = gp_lens.count(1)
# avg_len = sum(gp_lens) / len(gp_lens) if gp_lens else 0

#Calculate FNR of each sample
for idx in y_true.index:
fnr = loss_fn(y_true[idx], y_pred[idx])
samples_fnr.append(fnr)

fn = loss_fn(y_true, y_pred)
fng = group_losses
return fn, fng
return fn, fng, samples_fnr, gp_lens


def mce(estimator, X, y_true, num_bins=10):
Expand Down
16 changes: 16 additions & 0 deletions fomo/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import inspect
import fomo.metrics as metrics
from .surrogate_models import MLP, Linear, InterLinear
from fomo.algorithm import Lexicase, Lexicase_NSGA2

class BasicProblem(ElementwiseProblem):
""" The evaluation function for each candidate sample weights. """
Expand Down Expand Up @@ -95,6 +96,14 @@ def _evaluate(self, sample_weight, out, *args, **kwargs):

out['F'] = np.asarray(f)

if isinstance(self.fomo_estimator.algorithm, (Lexicase, Lexicase_NSGA2)):
fn, fng, samples_fnr, gp_lens = metrics.flex_loss(est, X, y, 'FNR', **self.metric_kwargs)
out['fn'] = fn #FNR of all samples to be used in Flex
out['fng'] = fng #FNR of every group to be used in Flex
out['samples_fnr'] = samples_fnr #FNR of each sample to be used in Flex with weighted coin flip
out['gp_lens'] = gp_lens #Length of each protected group to be used in Flex with weighted coin flip


class SurrogateProblem(ElementwiseProblem):
""" The evaluation function for each candidate weights.

Expand Down Expand Up @@ -173,6 +182,13 @@ def _evaluate(self, x, out, *args, **kwargs):

out['F'] = np.asarray(f)

if isinstance(self.fomo_estimator.algorithm, (Lexicase, Lexicase_NSGA2)):
fn, fng, samples_fnr, gp_lens = metrics.flex_loss(est, X, y, 'FNR', **self.metric_kwargs)
out['fn'] = fn #FNR of all samples to be used in Flex
out['fng'] = fng #FNR of every group to be used in Flex
out['samples_fnr'] = samples_fnr #FNR of each sample to be used in Flex with weighted coin flip
out['gp_lens'] = gp_lens #Length of each protected group to be used in Flex with weighted coin flip

class MLPProblem(SurrogateProblem):
""" The evaluation function for each candidate weights.

Expand Down
2 changes: 1 addition & 1 deletion fomo/surrogate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _one_hot_encode(self, X):
return self.ohc.transform(X)
else:
binary_columns = [col for col in X.columns if X[col].isin([0, 1]).all()]
categorical_features = [c for c in X.columns if (X[c].nunique() < 8 and c not in binary_columns)] #Do not one-hot-encode binary columns and columns with more than 8 categories.
categorical_features = [c for c in X.columns if (X[c].nunique() <= 8 and c not in binary_columns)] #Do not one-hot-encode binary columns and columns with more than 8 categories.
self.ohc = ColumnTransformer(
[
(
Expand Down
Loading