Add perturber training
baskrahmer committed Feb 28, 2024
1 parent 56707c4 commit 8154b50
Showing 10 changed files with 288 additions and 99 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -5,8 +5,9 @@ library, you can easily train and deploy perturbation augmentation models

# Roadmap

- [ ] Add default perturber model
- [x] Add default perturber model
- [ ] Reproduce training of perturber model
- [ ] Pretrain small and medium perturber models
- [ ] Add training of unconditional perturber models (i.e. only get a sentence, no target word/attribute)
- [ ] Add self-training by pretraining perturber base model (e.g. BART) on self-perturbed data

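For context before the code changes below: a perturbation augmentation model takes an input sentence together with a target word and a demographic attribute, and rewrites the sentence accordingly. The sketch below only illustrates the conditional input format that this commit factors into `PerturberTemplate` (see the next file); the sentence, word, attribute, and expected output are invented, and the format shown is the one produced by `PerturberTemplate`'s defaults with `original=True` (the same comma plus `<PERT_SEP>` layout as the hard-coded line this commit removes from `generate`).

```python
# Illustrative only: sentence, word, attribute, and the expected output are invented;
# the template shape mirrors PerturberTemplate() with original=True ("facebook/perturber").
word = "his"                      # word in the source sentence to perturb
attribute = "woman"               # demographic attribute to perturb towards
sentence = "The chef shared his favourite recipe."

model_input = f"{word}, {attribute} <PERT_SEP> {sentence}"
print(model_input)
# his, woman <PERT_SEP> The chef shared his favourite recipe.

# A trained perturber would then be expected to generate something like:
# "The chef shared her favourite recipe."
```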
55 changes: 43 additions & 12 deletions perturbers/modeling/perturber.py
@@ -1,32 +1,50 @@
import random
from dataclasses import dataclass
from typing import Optional

from transformers import BartForConditionalGeneration, BartTokenizer
from transformers import BartForConditionalGeneration, AutoTokenizer

from perturbers.data.panda_dict import get_panda_dict


@dataclass
class PerturberConfig:
sep_token: str = '<SEP>'
pert_sep_token: str = '<PERT_SEP>'
max_length: int = 128


class Perturber:
SEP_TOKEN = '<PERT_SEP>'

def __init__(self):
model_name = "facebook/perturber"
self.model = BartForConditionalGeneration.from_pretrained(model_name)
self.model.config.max_length = 128
self.tokenizer = BartTokenizer.from_pretrained(model_name)
def __init__(self, model=None, config: Optional[PerturberConfig] = None):
self.config = config if config is not None else PerturberConfig()

if model is None:
model_name = "facebook/perturber"
self.model = BartForConditionalGeneration.from_pretrained(model_name)
else:
model_name = model.config.name_or_path
self.model = model
self.model.config.max_length = self.config.max_length
self.tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
self.tokenizer.add_tokens([self.config.pert_sep_token], special_tokens=True)

self.panda_dict = get_panda_dict()
self.input_template = PerturberTemplate(sep=self.config.sep_token, pert_sep=self.config.pert_sep_token,
original=model_name == "facebook/perturber")

def generate(self, input_txt: str, word: str = "", attribute: str = "", tokenizer_kwargs=None) -> str:
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
input_txt = f"{word}, {attribute} {Perturber.SEP_TOKEN} {input_txt}"
input_txt = self.input_template(input_txt, word, attribute)
output_tokens = self.model.generate(**self.tokenizer(input_txt, return_tensors='pt'), **tokenizer_kwargs)
return self.tokenizer.batch_decode(
output_tokens,
skip_special_tokens=True,
max_new_tokens=self.model.config.max_length
)[0]

def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None):
def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None, retry_unchanged=False):

if tokenizer_kwargs is None:
tokenizer_kwargs = {}
@@ -36,10 +54,23 @@ def __call__(self, input_txt, mode='word_list', tokenizer_kwargs=None):
elif mode == 'word_list':
targets = [w for w in input_txt.split(" ") if w in self.panda_dict]
perturbations = [(t, perturbed) for t in targets for perturbed in self.panda_dict[t]]
if perturbations:
word, attribute = random.choice(perturbations)
input_txt = self.generate(input_txt, word=word, attribute=attribute, tokenizer_kwargs=tokenizer_kwargs)
random.shuffle(perturbations)
for word, attribute in perturbations:
generated_txt = self.generate(input_txt, word=word, attribute=attribute,
tokenizer_kwargs=tokenizer_kwargs)
if generated_txt != input_txt or not retry_unchanged:
return generated_txt
else:
raise NotImplementedError

return input_txt


class PerturberTemplate:

def __init__(self, sep: str = ",", pert_sep: str = "<PERT_SEP>", original: bool = False):
self.sep = sep
self.pert_sep = pert_sep if not original else f" {pert_sep}"

def __call__(self, input_txt: str, word: str = "", attribute: str = "") -> str:
return f"{word}{self.sep} {attribute}{self.pert_sep} {input_txt}"
180 changes: 106 additions & 74 deletions perturbers/training/core.py
@@ -3,105 +3,111 @@

import lightning
import torch
import torchmetrics
from datasets import load_dataset
from lightning import Trainer, seed_everything
from lightning import pytorch as pl
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from torchmetrics.text import Perplexity, BLEUScore
from transformers import AutoModel, BartForConditionalGeneration # noqa 401
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

from perturbers.training.utils import Config
from perturbers.modeling.perturber import PerturberTemplate
from perturbers.training.utils import TrainingConfig, get_diff_indices


class LightningWrapper(lightning.LightningModule):

def __init__(self, c: Config):
def __init__(self, c: TrainingConfig, tokenizer):
super().__init__()
self.model = AutoModel.from_pretrained(c.model_name)
# self.model = AutoModel.from_pretrained(c.model_name)
self.model = BartForConditionalGeneration.from_pretrained(c.model_name)
self.model.resize_token_embeddings(len(tokenizer))
self.tokenizer = tokenizer

self._device = "cuda" if (c.use_gpu and torch.cuda.is_available()) else "cpu"

self.learning_rate = c.learning_rate
self.num_steps = c.num_steps
self.num_steps = c.train_steps
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)

self.train_metrics = self.get_metric_dict(c)
self.val_metrics = self.get_metric_dict(c)
self.test_metrics = self.get_metric_dict(c)
self.train_metrics = self.get_metric_dict(c, "train")
self.val_metrics = self.get_metric_dict(c, "val")
self.test_metrics = self.get_metric_dict(c, "test")

@staticmethod
def get_metric_dict(c: Config):
return {
'precision': torchmetrics.Precision(task="multiclass", num_classes=c.num_labels, ignore_index=-100),
'recall': torchmetrics.Recall(task="multiclass", num_classes=c.num_labels, ignore_index=-100),
'f1': torchmetrics.F1Score(task="multiclass", num_classes=c.num_labels, ignore_index=-100),
'confusion_matrix': torchmetrics.ConfusionMatrix(task="multiclass", num_classes=c.num_labels,
ignore_index=-100),
def get_metric_dict(self, c: TrainingConfig, split: str):
metrics = {
f'{split}_ppl': Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device),
f'{split}_ppl_perturbed': Perplexity(ignore_index=self.tokenizer.pad_token_id).to(self._device),
}
if split == "test":
metrics[f'{split}_bleu4'] = BLEUScore(n_gram=4).to(self._device)
return metrics

def update_metrics(self, batch, outputs, metrics):
preds = torch.argmax(outputs, dim=2)
for metric in metrics.values():
metric(preds=preds, target=batch['labels'])

def log_metrics(self, metrics, prefix):
def update_metrics(self, batch, outputs, metrics, generations=None):
for metric_key, metric in metrics.items():
if metric_key != "confusion_matrix":
self.log(f"{prefix}_{metric_key}", metric.compute())
else:
conf_matrix = metric.compute()

TP = conf_matrix[1, 1]
TN = conf_matrix[0, 0]
FP = conf_matrix[0, 1]
FN = conf_matrix[1, 0]
if "bleu" in metric_key and generations is not None:
value = metric(
preds=generations,
target=[[_] for _ in self.tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)],
)
elif "ppl" in metric_key:
if "perturbed" in metric_key:
idx = batch["perturbed_idx"]
value = metric(preds=outputs[idx].unsqueeze(0), target=batch['labels'][idx].unsqueeze(0))
else:
value = metric(preds=outputs, target=batch['labels'])
self.log(metric_key, value, on_step=metric_key.startswith("train"), on_epoch=True, prog_bar=True)

# Compute FPR and FNR
FPR = FP / (FP + TN)
FNR = FN / (FN + TP)

self.log(f"{prefix}_FPR", FPR)
self.log(f"{prefix}_FNR", FNR)

def clear_metrics(self, metrics):
@staticmethod
def clear_metrics(metrics):
for metric_key, metric in metrics.items():
metric.reset()

def training_step(self, batch, batch_idx):
outputs, loss = self.model.forward(**batch["inputs"], labels=batch["labels"])
outputs, loss = self.forward(batch)
self.log("train_loss", loss, on_step=True, on_epoch=True)
self.update_metrics(batch, outputs, self.train_metrics)
return loss

def validation_step(self, batch, batch_idx):
outputs, loss = self.model.forward(**batch["inputs"], labels=batch["labels"])
self.log("val_loss", loss, on_step=True, on_epoch=True)
outputs, loss = self.forward(batch)
self.log("val_loss", loss, on_step=False, on_epoch=True)
self.update_metrics(batch, outputs, self.val_metrics)
return loss

def test_step(self, batch, batch_idx):
outputs, loss = self.model.forward(**batch["inputs"], labels=batch["labels"])
self.log("test_loss", loss, on_step=True, on_epoch=True)
self.update_metrics(batch, outputs, self.test_metrics)
outputs, loss = self.forward(batch)
generations = self.generate(batch)
self.log("test_loss", loss, on_step=False, on_epoch=True)
self.update_metrics(batch, outputs, self.test_metrics, generations)
return loss

def on_train_epoch_end(self) -> None:
self.log_metrics(self.train_metrics, prefix="train")
self.clear_metrics(self.train_metrics)

def on_validation_epoch_end(self) -> None:
self.log_metrics(self.val_metrics, prefix="val")
self.clear_metrics(self.val_metrics)

def on_test_epoch_end(self) -> None:
self.log_metrics(self.test_metrics, prefix="test")
self.clear_metrics(self.test_metrics)

def forward(self, batch):
return self.model(**batch)
outputs = self.model(**{k: v for k, v in batch.items() if k in ["input_ids", "attention_mask", "labels"]})
return outputs.logits, outputs.loss

def generate(self, batch):
generations = self.model.generate(
**{k: v for k, v in batch.items() if k in ["input_ids", "attention_mask"]},
max_length=batch['input_ids'].shape[-1],
)
return self.tokenizer.batch_decode(generations, skip_special_tokens=True)

def configure_optimizers(self):
optimizer = torch.optim.AdamW(
params=[p for p in self.model.parameters()],
lr=self.learning_rate,
betas=(0.9, 0.95),
betas=(0.9, 0.999),
)
scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
@@ -117,41 +123,45 @@ def configure_optimizers(self):
return [optimizer], [scheduler]


def get_collate_fn(c: Config, tokenizer):
def get_collate_fn(c: TrainingConfig, tokenizer):
tokenizer_kwargs = {"padding": True, "truncation": True, "max_length": c.max_length}
input_template = PerturberTemplate(sep=c.sep_token, pert_sep=c.pert_sep_token,
original=c.model_name == "facebook/perturber")

def collate_fn(batch: List):
original, perturbed = [], []
for item in batch:
perturbed_x, perturbed_y = [], []
for i, item in enumerate(batch):
perturbed.append(item['perturbed'])
original.append(
f'{item["selected_word"]}, {item["target_attribute"]} {tokenizer.sep_token} {item["original"]}')

original = tokenizer(original, return_tensors='pt', padding=True, truncation=True, max_length=c.max_length)
perturbed = tokenizer(perturbed, return_tensors='pt', padding=True, truncation=True, max_length=c.max_length)

original_tokens = [
tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=c.max_length)
for text in original_texts
]
perturbed_tokens = [
tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=c.max_length)
for text in perturbed_texts
]
original.append(input_template(item["original"], item["selected_word"], item["target_attribute"]))
idx = get_diff_indices(
tokenizer(item['original'], **tokenizer_kwargs).data['input_ids'],
tokenizer(item['perturbed'], **tokenizer_kwargs).data['input_ids']
)
perturbed_x += [i] * len(idx)
perturbed_y += idx

target_attributes = torch.tensor(target_attributes, dtype=torch.long)
original = tokenizer(original, return_tensors='pt', **tokenizer_kwargs)
perturbed = tokenizer(perturbed, return_tensors='pt', **tokenizer_kwargs)

return original_tokens, perturbed_tokens, target_attributes
return {
"input_ids": original["input_ids"],
"attention_mask": original["attention_mask"],
"labels": perturbed["input_ids"],
"perturbed_idx": (perturbed_x, perturbed_y),
}

return collate_fn


def get_loggers(c: Config):
def get_loggers(c: TrainingConfig):
loggers = [CSVLogger(save_dir=c.save_path, name=c.version)]
if c.use_wandb:
loggers.append(WandbLogger(name="FairNER", save_dir=c.save_path, version=c.version))
loggers.append(WandbLogger(name="perturbers", save_dir=c.save_path, version=c.version))
return loggers


def get_callbacks(c: Config):
def get_callbacks(c: TrainingConfig):
return [
pl.callbacks.EarlyStopping(
monitor="val_loss",
Expand All @@ -171,16 +181,26 @@ def get_callbacks(c: Config):
]


def train_perturber(c: Config):
def train_perturber(c: TrainingConfig):
seed_everything(c.seed, workers=True)

model = LightningWrapper(c)
if c.debug:
c.train_steps = 10
c.val_steps = 5
c.accumulate_grad_batches = 1

tokenizer = AutoTokenizer.from_pretrained(c.model_name, add_prefix_space=True)
tokenizer.add_tokens([c.sep_token, c.pert_sep_token], special_tokens=True)
model = LightningWrapper(c, tokenizer)
dataset = load_dataset(c.dataset_name)

train_ds = dataset["train"]
val_ds = dataset["validation"]

if c.debug:
train_ds = train_ds.select(range(128))
val_ds = val_ds.select(range(128))

collate_fn = get_collate_fn(c, tokenizer)

train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=c.train_batch_size, collate_fn=collate_fn)
@@ -197,6 +217,8 @@ def train_perturber(c: Config):
precision=16 if c.fp16 else 32,
logger=get_loggers(c),
check_val_every_n_epoch=None,
gradient_clip_val=c.gradient_clipping_value,
accumulate_grad_batches=c.accumulate_grad_batches,
)

trainer.fit(
@@ -209,3 +231,13 @@
dataloaders=val_dataloader,
ckpt_path='best',
)

if c.output_path:
model.model.save_pretrained(c.output_path)
tokenizer.save_pretrained(c.output_path)

if c.push_to_hub:
model.model.push_to_hub(c.hub_repo_id)
tokenizer.push_to_hub(c.hub_repo_id)

return model.model
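To make the entry point concrete, here is a hedged sketch of invoking `train_perturber`. `TrainingConfig` is defined in `perturbers/training/utils.py`, which is not part of the diff shown here, so the field values below (including the dataset identifier) are assumptions chosen only to illustrate fields that `core.py` references.

```python
# Assumption-heavy sketch: TrainingConfig's real fields/defaults live in
# perturbers/training/utils.py (not shown in this commit view); the values below are
# illustrative settings for fields that core.py references.
from perturbers.training.core import train_perturber
from perturbers.training.utils import TrainingConfig

config = TrainingConfig(
    model_name="facebook/bart-base",  # base seq2seq checkpoint to fine-tune
    dataset_name="facebook/panda",    # dataset id passed to datasets.load_dataset
    debug=True,                       # caps train/val steps and truncates both splits
    output_path="perturber-base",     # save_pretrained target for model and tokenizer
    push_to_hub=False,
)

trained_model = train_perturber(config)  # returns the underlying transformers model
```

One design note visible in the diff: the collate function returns a `perturbed_idx` tuple produced by `get_diff_indices`, marking the token positions where the original and perturbed sequences differ, which is what allows the `*_ppl_perturbed` metric to score perplexity on just the rewritten tokens rather than the whole sequence.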