Commit 23f628a

WIP
dtch1997 committed Dec 20, 2023
1 parent e8eb636 commit 23f628a
Showing 2 changed files with 31 additions and 2 deletions.
6 changes: 6 additions & 0 deletions repepo/algorithms/config/sft_pythia2.8b.yaml
@@ -0,0 +1,6 @@
model_name_or_path: "EleutherAI/pythia-2.8b"
sft:
num_train_epochs: 10
gradient_checkpointing: True
gradient_accumulation_steps: 1
mixed_precision: bf16
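
For context, a minimal sketch of how a config like this could be loaded into the SupervisedFineTuningConfig dataclass from repepo/algorithms/sft.py (the loader below is illustrative and assumes the "sft" block mirrors the dataclass fields; the repo may wire configs up through a dedicated config framework instead):

    # Hypothetical loader: shown only to illustrate how the YAML maps onto the dataclass.
    import yaml

    from repepo.algorithms.sft import SupervisedFineTuningConfig

    with open("repepo/algorithms/config/sft_pythia2.8b.yaml") as f:
        raw = yaml.safe_load(f)

    # The nested "sft" block supplies the dataclass fields.
    sft_config = SupervisedFineTuningConfig(**raw["sft"])
    print(sft_config.num_train_epochs)  # 10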
27 changes: 25 additions & 2 deletions repepo/algorithms/sft.py
@@ -43,6 +43,20 @@ class SupervisedFineTuningConfig:
mixed_precision: str = 'bf16'


def inspect_batch(batch: dict[str, torch.Tensor], tokenizer):
    """Decode the first example of a batch into its prompt and prediction parts.

    Positions where labels == -100 are ignored by the loss (the prompt);
    the remaining positions are the tokens the model is trained to predict.
    """
    input_ids = batch['input_ids'][0]
    labels = batch['labels'][0]
    loss_active_idx = (labels != -100).nonzero()
    loss_inactive_idx = (labels == -100).nonzero()
    prompt_str = tokenizer.decode(input_ids[loss_inactive_idx].squeeze(-1), skip_special_tokens=True)
    prediction_str = tokenizer.decode(input_ids[loss_active_idx].squeeze(-1), skip_special_tokens=True)
    return {
        'prompt_str': prompt_str,
        'prediction_str': prediction_str,
    }


class SupervisedFineTuning(Algorithm):
@override
def __init__(self, config: SupervisedFineTuningConfig):
@@ -137,10 +151,19 @@ def run(
torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
print(f"epoch: {epoch} | step: {step} | loss: {loss.item():.4f}")

if logger is not None:
logger.log({"train/loss": loss.item()}, step=global_step)

# Log the global gradient norm for debugging
grads = [
    param.grad.detach().flatten()
    for param in model.parameters()
    if param.grad is not None
]
norm = torch.cat(grads).norm()
logger.log({"train/grad_norm": norm.item()}, step=global_step)

del outputs
del batch

@@ -176,7 +199,7 @@ def run(
@dataclass
class TrainSFTConfig:
sft: SupervisedFineTuningConfig = SupervisedFineTuningConfig()
-   dataset: DatasetSpec = DatasetSpec(name="truthfulqa", split = ":80%")
+   dataset: DatasetSpec = DatasetSpec(name="truthfulqa", split = ":1%")
val_dataset: DatasetSpec = DatasetSpec(name = "truthfulqa", split = "80:100%")
wandb: WandbConfig = WandbConfig()
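
For reference, a minimal, self-contained sketch of how the new inspect_batch helper behaves (this assumes HuggingFace-style batches where prompt positions in labels are masked with -100; the hand-built batch below is illustrative only):

    import torch
    from transformers import AutoTokenizer

    from repepo.algorithms.sft import inspect_batch

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b")

    prompt_ids = tokenizer("Q: What is the capital of France?\nA:")["input_ids"]
    completion_ids = tokenizer(" Paris.")["input_ids"]

    input_ids = torch.tensor([prompt_ids + completion_ids])
    # Prompt positions are masked with -100 so only the completion is trained on.
    labels = torch.tensor([[-100] * len(prompt_ids) + completion_ids])

    out = inspect_batch({"input_ids": input_ids, "labels": labels}, tokenizer)
    print(out["prompt_str"])      # roughly the prompt text
    print(out["prediction_str"])  # roughly " Paris."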

