Skip to content

Commit

Permalink
Merge pull request #21 from HanXudong/dataset
Browse files Browse the repository at this point in the history
Version 0.1.0
  • Loading branch information
HanXudong authored Nov 26, 2022
2 parents 8fd6f2f + 0279dec commit bce46b5
Show file tree
Hide file tree
Showing 23 changed files with 564 additions and 57 deletions.
2 changes: 1 addition & 1 deletion fairlib/datasets/bios/bios.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def bert_encoding(self):
avg_data, cls_data = self.encoder.encode(text_data)
split_df["bert_avg_SE"] = list(avg_data)
split_df["bert_cls_SE"] = list(cls_data)
split_df["gender_class"] = split_df["g"]
split_df["gender_class"] = split_df["g"].map(gender2id)
split_df["profession_class"] = split_df["p"].map(professions2id)

split_df.to_pickle(Path(self.dest_folder) / "bios_{}_df.pkl".format(split))
Expand Down
3 changes: 1 addition & 2 deletions fairlib/datasets/utils/bert_encoding.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import torch
from transformers import *
import pickle
from transformers import BertModel, BertTokenizer
from tqdm.auto import tqdm, trange

class BERT_encoder:
Expand Down
2 changes: 1 addition & 1 deletion fairlib/src/analysis/tables_and_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,4 @@ def make_zoom_plot(
ax.indicate_inset_zoom(axins, edgecolor="black")

if figure_name is not None:
fig.savefig(figure_name, dpi=960, bbox_inches="tight")
fig.savefig(figure_name+".pdf", format="pdf", dpi=960, bbox_inches="tight")
28 changes: 26 additions & 2 deletions fairlib/src/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .networks import adv
from .networks import FairCL
from .networks.DyBT import Group_Difference_Loss
from .networks.ARL import ARL

class State(object):

Expand Down Expand Up @@ -185,6 +186,8 @@ def __getattr__(self, name):
help='how many batches to wait before logging training status')
parser.add_argument('--save_batch_results', action='store_true', default=False,
help='if saving batch evaluation results')
parser.add_argument('--save_models', action='store_true', default=False,
help='if saving model parameters')
parser.add_argument('--checkpoint_interval', type=int, default=1, metavar='N',
help='checkpoint interval (epoch)')
parser.add_argument('--dataset', type=str, default='Moji',
Expand Down Expand Up @@ -282,6 +285,8 @@ def __getattr__(self, name):
# Gated adv
parser.add_argument('--adv_gated', action='store_true', default=False,
help='gated discriminator for augmented inputs given target labels')
parser.add_argument('--adv_gated_type', type=str, default="Augmentation",
help='Augmentation | Inputs | Separate')
parser.add_argument('--adv_BT', type=str, default=None, help='instacne reweighting for adv')
parser.add_argument('--adv_BTObj', type=str, default=None, help='instacne reweighting for adv')

Expand Down Expand Up @@ -338,6 +343,13 @@ def __getattr__(self, name):
parser.add_argument('--GBT_N', type=nonneg_int, default=None, help='size of the manipulated dataset')
parser.add_argument("--GBT_alpha", type=float, default=1, help="interpolation for generalized BT")

# ARL
parser.add_argument('--ARL', action='store_true', default=False,
help='Perform adversarial reweighted learning (ARL)')
parser.add_argument('--ARL_n',type=pos_int, default=1,
help='Update the adversary n times per main model update')


def get_dummy_state(self, *cmdargs, yaml_file=None, **opt_pairs):
if yaml_file is None:
# Use default Namespace (not UniqueNamespace) because dummy state may
Expand Down Expand Up @@ -515,8 +527,15 @@ def set_state(self, state, dummy=False, silence=False):

# Init discriminator for adversarial training
if state.adv_debiasing:

state.opt.discriminator = networks.adv.Discriminator(state)
# if state.adv_decoupling:
# raise NotImplementedError

if state.adv_gated and (state.adv_gated_type == "Separate"):
# Train a set of discriminators for each class
state.opt.discriminator = [networks.adv.Discriminator(state) for _ in range(state.num_classes)]
else:
# All other adv settings
state.opt.discriminator = networks.adv.Discriminator(state)
logging.info('Discriminator built!')
# adv.utils.print_network(state.opt.discriminator.subdiscriminators[0])

Expand All @@ -530,6 +549,11 @@ def set_state(self, state, dummy=False, silence=False):
if (state.DyBT is not None) and (state.DyBT == "GroupDifference"):
state.opt.group_difference_loss = Group_Difference_Loss(state)

# Init the ARL for unsupervised training
if state.ARL:
assert not state.adv_debiasing, "ARL is unsupervised bias mitigation, which cannot be used together with adversarial training"
state.opt.ARL_loss = ARL(state)

return state


Expand Down
10 changes: 5 additions & 5 deletions fairlib/src/dataloaders/BT.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
method (str, optional): Downsampling | Resampling. Defaults to "Downsampling".
Returns:
list: a list of indices of selected instacnes.
list: a list of indices of selected instances.
"""

# init a dict for storing the index of each group.
Expand Down Expand Up @@ -133,15 +133,15 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
weighting_counter = Counter(y)

# a list of (weights, actual length)
condidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_y] for (_y, _g) in group_idx.keys()])
candidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_y] for (_y, _g) in group_idx.keys()])

distinct_y_label = set(y)
distinct_g_label = set(protected_label)

# iterate each main task class
for y in distinct_y_label:
if method == "Downsampling":
selected = int(condidate_selected * weighting_counter[y])
selected = int(candidate_selected * weighting_counter[y])
elif method == "Resampling":
selected = int(weighting_counter[y] / len(distinct_g_label))
for g in distinct_g_label:
Expand All @@ -157,7 +157,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
weighting_counter = Counter(protected_label)
# a list of (weights, actual length)
# Noticing that if stratified_g, the order within the key has been changed.
condidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_g] for (_y, _g) in group_idx.keys()])
candidate_selected = min([len(group_idx[(_y, _g)])/weighting_counter[_g] for (_y, _g) in group_idx.keys()])

distinct_y_label = set(y)
distinct_g_label = set(protected_label)
Expand All @@ -166,7 +166,7 @@ def get_sampled_indices(BTObj, y, protected_label, method = "Downsampling"):
# for y in distinct_y_label:
for g in distinct_g_label:
if method == "Downsampling":
selected = int(condidate_selected * weighting_counter[g])
selected = int(candidate_selected * weighting_counter[g])
elif method == "Resampling":
selected = int(weighting_counter[g] / len(distinct_y_label))
# for g in distinct_g_label:
Expand Down
2 changes: 1 addition & 1 deletion fairlib/src/dataloaders/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ def __init__(self, args) -> None:

def encoder(self, sample):
encodings = self.tokenizer(sample, truncation=True, padding=True)
return encodings["input_ids"]
return encodings["input_ids"], encodings['token_type_ids'], encodings['attention_mask']
6 changes: 5 additions & 1 deletion fairlib/src/dataloaders/loaders/Adult.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ def load_data(self):
if self.args.protected_task == "gender":
self.protected_label =np.array(list(data["sex"])).astype(np.int32) # Gender
elif self.args.protected_task == "race":
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[_r+_s*5 for _r,_s in zip(list(data["race"]), list(data["sex"]))]
).astype(np.int32) # Intersectional
12 changes: 10 additions & 2 deletions fairlib/src/dataloaders/loaders/Bios.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@ def load_data(self):

data = pd.read_pickle(Path(self.args.data_dir) / self.filename)

if self.args.protected_task in ["economy", "both"] and self.args.full_label:
# if self.args.protected_task in ["economy", "both"] and self.args.full_label:
if self.args.protected_task in ["gender", "economy", "both", "intersection"] and self.args.full_label:
selected_rows = (data["economy_label"] != "Unknown")
data = data[selected_rows]

if self.args.encoder_architecture == "Fixed":
self.X = list(data[self.embedding_type])
elif self.args.encoder_architecture == "BERT":
self.X = self.args.text_encoder.encoder(list(data[self.text_type]))
_input_ids, _token_type_ids, _attention_mask = self.args.text_encoder.encoder(list(data[self.text_type]))
self.X = _input_ids
self.addition_values["input_ids"] = _input_ids
self.addition_values['attention_mask'] = _attention_mask
else:
raise NotImplementedError

Expand All @@ -28,5 +32,9 @@ def load_data(self):
self.protected_label = data["gender_class"].astype(np.int32) # Gender
elif self.args.protected_task == "economy":
self.protected_label = data["economy_class"].astype(np.int32) # Economy
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[2*_e+_g for _e,_g in zip(list(data["economy_class"]), list(data["gender_class"]))]
).astype(np.int32) # Intersection
else:
self.protected_label = data["intersection_class"].astype(np.int32) # Intersection
6 changes: 5 additions & 1 deletion fairlib/src/dataloaders/loaders/COMPAS.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ def load_data(self):
if self.args.protected_task == "gender":
self.protected_label =np.array(list(data["sex"])).astype(np.int32) # Gender
elif self.args.protected_task == "race":
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
self.protected_label = np.array(list(data["race"])).astype(np.int32) # Race
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[_r+_s*3 for _r,_s in zip(list(data["race"]), list(data["sex"]))]
).astype(np.int32) # Intersectional
8 changes: 6 additions & 2 deletions fairlib/src/dataloaders/loaders/Trustpilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ def load_data(self):
if self.args.protected_task == "gender":
self.protected_label = data["gender_label"].astype(np.int32) # Gender
elif self.args.protected_task == "age":
self.protected_label = data["age_label"].astype(np.int32) # Economy
self.protected_label = data["age_label"].astype(np.int32) # Age
elif self.args.protected_task == "country":
self.protected_label = data["country_label"].astype(np.int32) # Economy
self.protected_label = data["country_label"].astype(np.int32) # Country
elif self.args.protected_task == "intersection":
self.protected_label = np.array(
[4*_g+2*_a+_c for _g,_a,_c in zip(list(data["gender_label"]), list(data["age_label"]), data["country_label"])]
).astype(np.int32) # Intersection
else:
raise NotImplementedError
30 changes: 27 additions & 3 deletions fairlib/src/dataloaders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, args, split):
self.instance_weights = []
self.adv_instance_weights = []
self.regression_label = []
self.addition_values = {}

self.load_data()

Expand All @@ -50,7 +51,6 @@ def __init__(self, args, split):
if self.split == "train":
self.adv_decoupling()


print("Loaded data shapes: {}, {}, {}".format(self.X.shape, self.y.shape, self.protected_label.shape))

def __len__(self):
Expand All @@ -59,7 +59,25 @@ def __len__(self):

def __getitem__(self, index):
'Generates one sample of data'
return self.X[index], self.y[index], self.protected_label[index], self.instance_weights[index], self.adv_instance_weights[index], self.regression_label[index]
_X = self.X[index]
_y = self.y[index]
_protected_label = self.protected_label[index]
_instance_weights = self.instance_weights[index]
_adv_instance_weights = self.adv_instance_weights[index]
_regression_label = self.regression_label[index]

data_dict = {
0:_X,
1:_y,
2:_protected_label,
3:_instance_weights,
4:_adv_instance_weights,
5:_regression_label,
}
for _k in self.addition_values.keys():
if _k not in data_dict.keys():
data_dict[_k] = self.addition_values[_k][index]
return data_dict

def load_data(self):
pass
Expand All @@ -79,6 +97,9 @@ def manipulate_data_distribution(self):
self.y = self.y[selected_index]
self.protected_label = self.protected_label[selected_index]

for _k in self.addition_values.keys():
self.addition_values[_k] = [self.addition_values[_k][index] for index in selected_index]

def balanced_training(self):
if (self.args.BT is None) or (self.split != "train"):
# Without balanced training
Expand Down Expand Up @@ -112,6 +133,9 @@ def balanced_training(self):
self.protected_label = np.array(_protected_label)
self.instance_weights = np.array([1 for _ in range(len(self.protected_label))])

for _k in self.addition_values.keys():
self.addition_values[_k] = [self.addition_values[_k][index] for index in selected_index]

else:
raise NotImplementedError
return None
Expand Down Expand Up @@ -152,7 +176,7 @@ def adv_decoupling(self):
else:
pass
return None

def regression_init(self):
if not self.args.regression:
self.regression_label = np.array([0 for _ in range(len(self.protected_label))])
Expand Down
44 changes: 43 additions & 1 deletion fairlib/src/evaluators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,46 @@ def present_evaluation_scores(
validation_results = ["{}: {:2.2f}\t".format(k, 100.*valid_scores[k]) for k in valid_scores.keys()]
logging.info(('Validation {}').format("".join(validation_results)))
Test_results = ["{}: {:2.2f}\t".format(k, 100.*test_scores[k]) for k in test_scores.keys()]
logging.info(('Test {}').format("".join(Test_results)))
logging.info(('Test {}').format("".join(Test_results)))


def validation_is_best(
    valid_preds, valid_labels, valid_private_labels,
    model, epoch_valid_loss, selection_criterion="DTO",
    performance_metric="accuracy", fairness_metric="TPR_GAP",
):
    """Decide whether the current epoch is the best one seen so far.

    Validation scores are obtained from ``gap_eval_scores`` and reduced to a
    single "lower is better" number according to ``selection_criterion``.
    When that number is strictly smaller than ``model.best_valid_loss``, the
    attribute is overwritten with it (a side effect on ``model``).

    Args:
        valid_preds: validation predictions, forwarded as ``y_pred``.
        valid_labels: validation target labels, forwarded as ``y_true``.
        valid_private_labels: validation protected labels, forwarded as
            ``protected_attribute``.
        model: object exposing ``args`` (forwarded to ``gap_eval_scores``)
            and a mutable ``best_valid_loss`` attribute.
        epoch_valid_loss: validation loss of the current epoch; only used
            when ``selection_criterion`` is ``"Loss"``.
        selection_criterion: one of ``"DTO"``, ``"Loss"``, ``"Performance"``
            or ``"Fairness"``.
        performance_metric: key into the validation scores whose value is
            turned into ``1 - value`` so that lower is better.
        fairness_metric: key into the validation scores that is used as-is
            (lower is treated as better).

    Returns:
        bool: True if ``model.best_valid_loss`` was improved, else False.

    Raises:
        NotImplementedError: if ``selection_criterion`` is not recognised.
    """
    # Scores are computed up front for every criterion, including "Loss".
    scores, _ = gap_eval_scores(
        y_pred=valid_preds,
        y_true=valid_labels,
        protected_attribute=valid_private_labels,
        args=model.args,
    )

    # Map the chosen criterion onto one value where smaller means better.
    if selection_criterion == "DTO":
        # Euclidean norm of (performance shortfall, fairness score).
        candidate = ((1-scores[performance_metric])**2 + scores[fairness_metric]**2)**0.5
    elif selection_criterion == "Loss":
        candidate = epoch_valid_loss
    elif selection_criterion == "Performance":
        candidate = 1-scores[performance_metric]
    elif selection_criterion == "Fairness":
        candidate = scores[fairness_metric]
    else:
        raise NotImplementedError

    # Strict improvement only; ties keep the earlier best.
    if candidate < model.best_valid_loss:
        model.best_valid_loss = candidate
        return True
    return False
Loading

0 comments on commit bce46b5

Please sign in to comment.