Implement supervised fine-tuning #31

Merged 22 commits on Dec 13, 2023

Changes from 18 commits
8 changes: 4 additions & 4 deletions pdm.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -62,6 +62,8 @@ single-line-exclusions = ["typing"]
ban-relative-imports = "all"

[tool.pyright]
venvPath = "."
venv = ".venv"
include = ["repepo"]
exclude = ["**/node_modules", "**/__pycache__", "repepo/repe"]

@@ -89,4 +91,5 @@ dev = [
"jupyter>=1.0.0",
"pre-commit>=3.5.0",
"syrupy>=4.6.0",
"pyright>=1.1.339",
]
7 changes: 5 additions & 2 deletions repepo/algorithms/base.py
@@ -1,11 +1,14 @@
import abc

from repepo.core import Dataset, Pipeline
from typing import Dict, Any


class BaseAlgorithm(abc.ABC):
class Algorithm(abc.ABC):
@abc.abstractmethod
def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
def run(
self, pipeline: Pipeline, dataset: Dataset, **kwargs: Dict[str, Any]
) -> Pipeline:
raise NotImplementedError()


46 changes: 46 additions & 0 deletions repepo/algorithms/callback.py
@@ -0,0 +1,46 @@
import evaluate

from typing import Callable, Dict, List
from .base import Pipeline
from repepo.core.types import Dataset
from .utils import AverageMeter

Callback = Callable[[Pipeline], Dict[str, float]]


class Metrics:
def __init__(self):
# TODO: make configurable
self.bleu = evaluate.load("bleu")
self.rouge = evaluate.load("rouge")

def compute_metrics(
self, predictions: List[str], references: List[str]
) -> Dict[str, float]:
bleu_results = self.bleu.compute(predictions=predictions, references=references)
rouge_results = self.rouge.compute(
predictions=predictions, references=references
)
assert bleu_results is not None
assert rouge_results is not None
return {
"bleu": bleu_results["bleu"],
"rouge1": rouge_results["rouge1"],
}


class EvalCallback:
def __init__(self, val_datasets: Dict[str, Dataset]):
self.metric_fns = Metrics()
Collaborator:
Can we use the same Evaluator objects that we use for benchmarking? The val_dataset should respond to the same metrics as the test_dataset, I would think. Unless the idea is that we can set a specific validator that SFT should use to pick the best-performing result? Regardless, the Evaluator type already returns a float, so it would be suited to this.

Owner Author:
Sounds good, I haven't looked closely at the Evaluator class yet but will do so.

self.meter = AverageMeter()
# TODO: eval dataloader

def __call__(self, pipeline: Pipeline) -> Dict[str, float]:
self.meter.reset()
model = pipeline.model
tokenizer = pipeline.tokenizer
log_dict = {}

# TODO: implement

return log_dict
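
Following up on the review thread above about reusing the benchmarking Evaluator objects: below is a minimal sketch of how EvalCallback could delegate to them. The Evaluator type alias here is a hypothetical stand-in based on the comment that evaluators already return a float; the actual interface in repepo.core may differ.

from typing import Callable, Dict

from repepo.core import Pipeline
from repepo.core.types import Dataset

# Hypothetical stand-in for the benchmarking Evaluator interface: anything
# that scores a pipeline on a dataset and returns a single float.
Evaluator = Callable[[Pipeline, Dataset], float]


class EvaluatorCallback:
    """Callback that reuses benchmarking evaluators on validation datasets."""

    def __init__(
        self,
        val_datasets: Dict[str, Dataset],
        evaluators: Dict[str, Evaluator],
    ):
        self.val_datasets = val_datasets
        self.evaluators = evaluators

    def __call__(self, pipeline: Pipeline) -> Dict[str, float]:
        log_dict: Dict[str, float] = {}
        for ds_name, dataset in self.val_datasets.items():
            for eval_name, evaluator in self.evaluators.items():
                log_dict[f"val/{ds_name}/{eval_name}"] = evaluator(pipeline, dataset)
        return log_dict
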
4 changes: 2 additions & 2 deletions repepo/algorithms/icl.py
@@ -1,10 +1,10 @@
from dataclasses import replace
from repepo.core.prompt import FewShotPrompter
from .base import BaseAlgorithm
from .base import Algorithm
from repepo.core import Pipeline, Dataset


class InContextLearning(BaseAlgorithm):
class InContextLearning(Algorithm):
def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
"""Uses an in-context learning prefix to prompts"""

4 changes: 2 additions & 2 deletions repepo/algorithms/repe.py
@@ -4,15 +4,15 @@
from repepo.repe.rep_reading_pipeline import RepReadingPipeline
from repepo.repe.rep_control_pipeline import RepControlPipeline

from repepo.algorithms.base import BaseAlgorithm
from repepo.algorithms.base import Algorithm

from repepo.core.prompt import IdentityPrompter
from repepo.core.format import InstructionFormatter

import torch


class Repe(BaseAlgorithm):
class Repe(Algorithm):
# TODO: linting

def __init__(self):
193 changes: 187 additions & 6 deletions repepo/algorithms/sft.py
@@ -1,15 +1,196 @@
from repepo.core import Dataset
from repepo.core import Pipeline

from .base import BaseAlgorithm
from repepo.algorithms.base import Algorithm
from overrides import override

from dataclasses import dataclass
from dataclasses import field
from typing import List, Optional

class SupervisedFineTuning(BaseAlgorithm):
def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
import torch
import pprint
from torch.utils.data import DataLoader
from transformers.optimization import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from repepo.core.types import Completion

from repepo.data import make_dataset, DatasetSpec
from repepo.data import utils
from repepo.data.dataset import sft
from repepo.utils.metrics import Metrics, AverageMeter
from repepo.variables import Environ
from repepo.variables import Model


@dataclass
class SupervisedFineTuningConfig:
# Training config
batch_size: int = 256 # Training batch size
shuffle: bool = True # Whether to shuffle the dataset
num_train_epochs: int = 10
learning_rate: float = 5e-5

# Experiment config
device: str = "cuda"


@dataclass
class WandbConfig:
project: str = field(default=Environ.WandbProject)
entity: str = field(default=Environ.WandbEntity)
name: str = field(default="sft-simple")
track: bool = field(default=False)


class WandbLogger:
def __init__(self, config: WandbConfig):
self.config = config
if self.config.track:
import wandb

self.wandb = wandb

def __enter__(self):
if self.config.track:
self.wandb.init(
project=self.config.project,
entity=self.config.entity,
name=self.config.name,
)
return self

def log(self, *args, **kwargs):
if self.config.track:
self.wandb.log(*args, **kwargs)
# Else no-op

def __exit__(self, exc_type, exc_val, exc_tb):
if self.config.track:
self.wandb.finish()


class SupervisedFineTuning(Algorithm):
@override
def __init__(self, config: SupervisedFineTuningConfig):
self.config = config

def run(
self, pipeline: Pipeline, dataset: Dataset, logger: Optional[WandbLogger] = None
) -> Pipeline:
"""Modifies the base model weights"""

# Make supervised data module
# Run training, with optional WandB eval
# Load model and tokenizer
model = pipeline.model
tokenizer = pipeline.tokenizer

# Add new tokens to tokenizer
# This is because many tokenizers don't have a padding token
special_tokens_dict = utils.get_pad_token(tokenizer)
special_tokens_dict.update(utils.get_special_tokens(tokenizer))
utils.smart_tokenizer_and_embedding_resize(
special_tokens_dict=special_tokens_dict,
tokenizer=tokenizer,
model=model,
)

# Make train dataloader
completions: List[Completion] = pipeline.formatter.apply_list(dataset)
_ds = sft.SupervisedDataset(completions, tokenizer=tokenizer)
data_collator = sft.DataCollatorForSupervisedDataset(tokenizer=tokenizer)
train_dataloader = DataLoader(
_ds,
batch_size=self.config.batch_size,
shuffle=self.config.shuffle,
collate_fn=data_collator,
)

# Set device
device = self.config.device if torch.cuda.is_available() else "cpu"
model.to(device) # type: ignore

# Set up optimizer and scheduler
num_training_steps = self.config.num_train_epochs * len(train_dataloader)
optimizer = AdamW(model.parameters(), lr=self.config.learning_rate)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

# Set up metrics, meters
metric_fns = Metrics()
meter = AverageMeter()

# Training loop
global_step = 0
model.train()
for epoch in range(int(self.config.num_train_epochs)):
# epoch_iterator = tqdm(train_dataloader, desc="Training")
epoch_iterator = train_dataloader
for step, batch in enumerate(epoch_iterator):
global_step += self.config.batch_size
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(
input_ids=batch["input_ids"],
labels=batch["labels"],
attention_mask=batch["attention_mask"],
)
loss = outputs.loss
optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
print(f"epoch : {epoch} | step: {step} | loss: {loss}")
Collaborator:
nit: it would be better to use tqdm to get an updating progress bar rather than printing directly. It would also be good to add a way to disable this output to the screen, maybe with a param to run() called verbose: bool? We can figure that out later though, as it's more polish than core functionality.

Owner Author (dtch1997, Dec 6, 2023):
yeah let's do this later. I also want to figure out how to separate the display from the core functionality.

if logger is not None:
logger.log({"train/loss": loss}, step=global_step)

# TODO: Evaluation callback?

# keep pyright happy for now
return pipeline


if __name__ == "__main__":
import pyrallis
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer

@dataclass
class TrainSFTConfig:
sft: SupervisedFineTuningConfig = SupervisedFineTuningConfig()
dataset: DatasetSpec = DatasetSpec(name="truthfulqa")
wandb: WandbConfig = WandbConfig()

model_name_or_path: str = Model.Pythia70m
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
cache_dir: str = Environ.HuggingfaceCacheDir
output_dir: str = Environ.OutputDir

config = pyrallis.parse(TrainSFTConfig)
pprint.pprint(config)

model = AutoModelForCausalLM.from_pretrained(
config.model_name_or_path,
cache_dir=config.cache_dir,
)

# TODO: figure out typing for this
tokenizer = AutoTokenizer.from_pretrained(
config.model_name_or_path,
cache_dir=config.cache_dir,
model_max_length=config.model_max_length,
padding_side="right",
use_fast=True,
)

pipeline = Pipeline(model, tokenizer)
dataset = make_dataset(config.dataset)

with WandbLogger(config.wandb) as logger:
algorithm = SupervisedFineTuning(config.sft)
algorithm.run(pipeline, dataset, logger=logger)
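
On the tqdm and verbosity suggestion in the review thread above, here is a rough sketch of what the inner training loop could look like with a progress bar and a way to silence output. The function name train_one_epoch and the verbose parameter are illustrative and not part of the merged code.

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def train_one_epoch(
    model: torch.nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: str,
    verbose: bool = True,
) -> float:
    """Run one SFT epoch, showing a tqdm progress bar only when verbose is True."""
    model.train()
    iterator = tqdm(dataloader, desc="Training", disable=not verbose)
    last_loss = 0.0
    for batch in iterator:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(
            input_ids=batch["input_ids"],
            labels=batch["labels"],
            attention_mask=batch["attention_mask"],
        )
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        last_loss = loss.item()
        # tqdm's postfix replaces the direct print without flooding stdout
        iterator.set_postfix(loss=f"{last_loss:.4f}")
    return last_loss
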
46 changes: 46 additions & 0 deletions repepo/algorithms/utils.py
@@ -0,0 +1,46 @@
class AverageMeter:
def __init__(self):
self.metrics = {}

def update(self, name, value, n=1):
"""
Update a named metric with a new value.

Parameters:
- name: The name of the metric.
- value: The new value to incorporate.
- n: The weight/count of the value. Default is 1.
"""
if name not in self.metrics:
self.metrics[name] = {"val": 0, "sum": 0, "count": 0, "avg": 0}

metric = self.metrics[name]
metric["val"] = value
metric["sum"] += value * n
metric["count"] += n
metric["avg"] = metric["sum"] / metric["count"]

def get_avg(self, name):
"""
Get the running average of a named metric.

Parameters:
- name: The name of the metric.
"""
return self.metrics[name]["avg"] if name in self.metrics else None

def reset(self, name=None):
"""
Resets statistics of a named metric or all metrics if name is None.

Parameters:
- name: The name of the metric.
"""
if name:
self.metrics[name] = {"val": 0, "sum": 0, "count": 0, "avg": 0}
else:
for metric in self.metrics.values():
metric["val"] = 0
metric["sum"] = 0
metric["count"] = 0
metric["avg"] = 0