Implement supervised fine-tuning #31
@@ -0,0 +1,46 @@
import evaluate

from typing import Callable, Dict, List
from .base import Pipeline
from repepo.core.types import Dataset
from .utils import AverageMeter

Callback = Callable[[Pipeline], Dict[str, float]]


class Metrics:
    def __init__(self):
        # TODO: make configurable
        self.bleu = evaluate.load("bleu")
        self.rouge = evaluate.load("rouge")

    def compute_metrics(
        self, predictions: List[str], references: List[str]
    ) -> Dict[str, float]:
        bleu_results = self.bleu.compute(predictions=predictions, references=references)
        rouge_results = self.rouge.compute(
            predictions=predictions, references=references
        )
        assert bleu_results is not None
        assert rouge_results is not None
        return {
            "bleu": bleu_results["bleu"],
            "rouge1": rouge_results["rouge1"],
        }


class EvalCallback:
    def __init__(self, val_datasets: Dict[str, Dataset]):
        self.val_datasets = val_datasets  # keep the datasets for the eventual eval loop
        self.metric_fns = Metrics()
        self.meter = AverageMeter()
        # TODO: eval dataloader

    def __call__(self, pipeline: Pipeline) -> Dict[str, float]:
        self.meter.reset()
        # placeholders for the evaluation loop to be implemented below
        model = pipeline.model
        tokenizer = pipeline.tokenizer
        log_dict = {}

        # TODO: implement

        return log_dict
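For reference, a minimal usage sketch of the Metrics helper above (illustrative only, not part of the PR; it assumes the `evaluate` BLEU and ROUGE metrics can be loaded):

# Illustrative only; relies on the Metrics class defined above.
metrics = Metrics()
scores = metrics.compute_metrics(
    predictions=["the cat sat on the mat"],
    references=["the cat sat on the mat"],
)
print(scores)  # identical strings score perfectly, e.g. {'bleu': 1.0, 'rouge1': 1.0}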
@@ -1,15 +1,196 @@
(Removed by this hunk: the previous `from .base import BaseAlgorithm` import and the earlier `class SupervisedFineTuning(BaseAlgorithm)` definition with its `run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline` method. The new version follows.)

from repepo.core import Dataset
from repepo.core import Pipeline

from repepo.algorithms.base import Algorithm
from overrides import override

from dataclasses import dataclass
from dataclasses import field
from typing import List, Optional

import torch
import pprint
from torch.utils.data import DataLoader
from transformers.optimization import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from repepo.core.types import Completion

from repepo.data import make_dataset, DatasetSpec
from repepo.data import utils
from repepo.data.dataset import sft
from repepo.utils.metrics import Metrics, AverageMeter
from repepo.variables import Environ
from repepo.variables import Model
@dataclass
class SupervisedFineTuningConfig:
    # Training config
    batch_size: int = 256  # Training batch size
    shuffle: bool = True  # Whether to shuffle the dataset
    num_train_epochs: int = 10
    learning_rate: float = 5e-5

    # Experiment config
    device: str = "cuda"


@dataclass
class WandbConfig:
    project: str = field(default=Environ.WandbProject)
    entity: str = field(default=Environ.WandbEntity)
    name: str = field(default="sft-simple")
    track: bool = field(default=False)


class WandbLogger:
    def __init__(self, config: WandbConfig):
        self.config = config
        if self.config.track:
            import wandb

            self.wandb = wandb

    def __enter__(self):
        if self.config.track:
            self.wandb.init(
                project=self.config.project,
                entity=self.config.entity,
                name=self.config.name,
            )
        return self

    def log(self, *args, **kwargs):
        if self.config.track:
            self.wandb.log(*args, **kwargs)
        # Else no-op

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.config.track:
            self.wandb.finish()


class SupervisedFineTuning(Algorithm):
    @override
    def __init__(self, config: SupervisedFineTuningConfig):
        self.config = config

    def run(
        self, pipeline: Pipeline, dataset: Dataset, logger: Optional[WandbLogger] = None
    ) -> Pipeline:
        """Modifies the base model weights"""

        # Make supervised data module
        # Run training, with optional WandB eval
        # Load model and tokenizer
        model = pipeline.model
        tokenizer = pipeline.tokenizer

        # Add new tokens to tokenizer
        # This is because many tokenizers don't have a padding token
        special_tokens_dict = utils.get_pad_token(tokenizer)
        special_tokens_dict.update(utils.get_special_tokens(tokenizer))
        utils.smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=tokenizer,
            model=model,
        )

        # Make train dataloader
        completions: List[Completion] = pipeline.formatter.apply_list(dataset)
        _ds = sft.SupervisedDataset(completions, tokenizer=tokenizer)
        data_collator = sft.DataCollatorForSupervisedDataset(tokenizer=tokenizer)
        train_dataloader = DataLoader(
            _ds,
            batch_size=self.config.batch_size,
            shuffle=self.config.shuffle,
            collate_fn=data_collator,
        )

        # Set device
        device = self.config.device if torch.cuda.is_available() else "cpu"
        model.to(device)  # type: ignore

        # Set up optimizer and scheduler
        num_training_steps = self.config.num_train_epochs * len(train_dataloader)
        optimizer = AdamW(model.parameters(), lr=self.config.learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
        )

        # Set up metrics, meters
        metric_fns = Metrics()
        meter = AverageMeter()

        # Training loop
        global_step = 0
        model.train()
        for epoch in range(int(self.config.num_train_epochs)):
            # epoch_iterator = tqdm(train_dataloader, desc="Training")
            epoch_iterator = train_dataloader
            for step, batch in enumerate(epoch_iterator):
                global_step += self.config.batch_size
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(
                    input_ids=batch["input_ids"],
                    labels=batch["labels"],
                    attention_mask=batch["attention_mask"],
                )
                loss = outputs.loss
                optimizer.zero_grad()
                loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                print(f"epoch : {epoch} | step: {step} | loss: {loss}")
Review comment: nit: would be better to use […]
Reply: yeah, let's do this later. I also want to figure out how to separate the display from the core functionality.
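The thread does not show which utility was suggested; as one possibility (an assumption, not necessarily the reviewer's proposal), the bare print could be swapped for a tqdm progress bar, which also keeps display concerns at the edge of the loop:

# Illustrative sketch only, not part of this PR.
from tqdm import tqdm

for epoch in range(int(self.config.num_train_epochs)):
    epoch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch}")
    for step, batch in enumerate(epoch_iterator):
        ...  # forward pass, loss, backward, optimizer/scheduler steps as above
        epoch_iterator.set_postfix(loss=loss.item())  # show the running loss in the bar instead of printing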
                if logger is not None:
                    logger.log({"train/loss": loss}, step=global_step)

        # TODO: Evaluation callback?

        # keep pyright happy for now
        return pipeline


if __name__ == "__main__":
    import pyrallis
    from transformers.models.auto.modeling_auto import AutoModelForCausalLM
    from transformers.models.auto.tokenization_auto import AutoTokenizer

    @dataclass
    class TrainSFTConfig:
        sft: SupervisedFineTuningConfig = SupervisedFineTuningConfig()
        dataset: DatasetSpec = DatasetSpec(name="truthfulqa")
        wandb: WandbConfig = WandbConfig()

        model_name_or_path: str = Model.Pythia70m
        model_max_length: int = field(
            default=512,
            metadata={
                "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
            },
        )
        cache_dir: str = Environ.HuggingfaceCacheDir
        output_dir: str = Environ.OutputDir

    config = pyrallis.parse(TrainSFTConfig)
    pprint.pprint(config)

    model = AutoModelForCausalLM.from_pretrained(
        config.model_name_or_path,
        cache_dir=config.cache_dir,
    )

    # TODO: figure out typing for this
    tokenizer = AutoTokenizer.from_pretrained(
        config.model_name_or_path,
        cache_dir=config.cache_dir,
        model_max_length=config.model_max_length,
        padding_side="right",
        use_fast=True,
    )

    pipeline = Pipeline(model, tokenizer)
    dataset = make_dataset(config.dataset)

    with WandbLogger(config.wandb) as logger:
        algorithm = SupervisedFineTuning(config.sft)
        algorithm.run(pipeline, dataset, logger=logger)
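Since configuration is parsed with pyrallis, the nested config fields should be overridable from the command line with dot notation. Assuming this file lives at repepo/algorithms/sft.py (a guess based on the imports above), an invocation might look like:

    python -m repepo.algorithms.sft --sft.batch_size 8 --sft.num_train_epochs 1 --wandb.track true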
@@ -0,0 +1,46 @@
class AverageMeter:
    def __init__(self):
        self.metrics = {}

    def update(self, name, value, n=1):
        """
        Update a named metric with a new value.

        Parameters:
        - name: The name of the metric.
        - value: The new value to incorporate.
        - n: The weight/count of the value. Default is 1.
        """
        if name not in self.metrics:
            self.metrics[name] = {"val": 0, "sum": 0, "count": 0, "avg": 0}

        metric = self.metrics[name]
        metric["val"] = value
        metric["sum"] += value * n
        metric["count"] += n
        metric["avg"] = metric["sum"] / metric["count"]

    def get_avg(self, name):
        """
        Get the running average of a named metric.

        Parameters:
        - name: The name of the metric.
        """
        return self.metrics[name]["avg"] if name in self.metrics else None

    def reset(self, name=None):
        """
        Resets statistics of a named metric or all metrics if name is None.

        Parameters:
        - name: The name of the metric.
        """
        if name:
            self.metrics[name] = {"val": 0, "sum": 0, "count": 0, "avg": 0}
        else:
            for metric in self.metrics.values():
                metric["val"] = 0
                metric["sum"] = 0
                metric["count"] = 0
                metric["avg"] = 0
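An illustrative usage sketch of AverageMeter (not part of the PR):

meter = AverageMeter()
meter.update("loss", 2.0, n=4)  # a batch of 4 with mean loss 2.0
meter.update("loss", 1.0, n=4)  # a second batch of 4 with mean loss 1.0
print(meter.get_avg("loss"))    # 1.5, i.e. (2.0 * 4 + 1.0 * 4) / 8
meter.reset("loss")             # zero out just this metric; reset() with no name clears all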
Review comment: Can we use the same `Evaluator` objects that we use for benchmarking? The `val_dataset` should respond to the same metrics as the `test_dataset`, I would think. Unless the idea is that we can set a specific validator that should be used by SFT to pick the best-performing result? Regardless, the `Evaluator` type already returns a float, so it would be suited to this.
Reply: Sounds good, I haven't looked closely at the Evaluator class yet but will do so.
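To make the suggestion concrete, here is a minimal sketch of how evaluators that return a float could back the eval callback. The Evaluator signature below is an assumption for illustration; the actual repepo Evaluator interface may differ.

# Hypothetical sketch; the real Evaluator interface in repepo may look different.
from typing import Callable, Dict

from repepo.core import Dataset, Pipeline

# Assumption: an evaluator scores a pipeline on a dataset and returns a float.
Evaluator = Callable[[Pipeline, Dataset], float]


def make_eval_callback(
    evaluators: Dict[str, Evaluator], val_dataset: Dataset
) -> Callable[[Pipeline], Dict[str, float]]:
    """Build a callback with the same shape as the Callback alias in this PR."""

    def callback(pipeline: Pipeline) -> Dict[str, float]:
        return {name: evaluator(pipeline, val_dataset) for name, evaluator in evaluators.items()}

    return callback

During training, such a callback could be invoked every N steps and its dict merged into the same logger.log call used for the training loss.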