
Implement supervised fine-tuning #31

Merged: 22 commits merged into main from sft on Dec 13, 2023
Conversation

@dtch1997 (Owner) commented Dec 6, 2023

  • Implement supervised fine-tuning within our framework in repepo/algorithms/sft.py
  • Add example usage in the same script under a main guard
  • Add some other utilities, such as a WandbLogger instance
  • Lay some groundwork for eval callbacks within the run loop, though I think this should perhaps be derived from the Benchmark and Evaluator classes
  • Implement make_dataset from a DatasetSpec, which allows using a custom split of the dataset (see the sketch after this list)
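
A rough sketch of the intended DatasetSpec / make_dataset usage follows. The field names, the split syntax, and the fact that the sketch takes already-loaded examples are assumptions for illustration, not the exact repepo.data API.

```python
from dataclasses import dataclass
from typing import Any, List

# Hypothetical sketch of the DatasetSpec idea described above; the real
# definitions live in repepo.data and may use different fields and parsing.
@dataclass
class DatasetSpec:
    name: str             # dataset identifier
    split: str = ":100%"  # slice expression, e.g. ":80%" or "80%:"

def _parse_split(split: str, n: int) -> slice:
    """Turn a ':80%'-style expression into a slice over n examples."""
    start_s, end_s = split.split(":")
    start = int(n * float(start_s.rstrip("%")) / 100) if start_s else 0
    end = int(n * float(end_s.rstrip("%")) / 100) if end_s else n
    return slice(start, end)

def make_dataset(spec: DatasetSpec, examples: List[Any]) -> List[Any]:
    """Return the requested split of an already-loaded dataset."""
    return examples[_parse_split(spec.split, len(examples))]

# Example: use the first 80% of a toy dataset as the training split
data = list(range(10))
train = make_dataset(DatasetSpec(name="toy", split=":80%"), data)
assert train == list(range(8))
```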

Notes

  • Training configuration can be updated from the command line; see the example in the comment below.
  • Had to slightly modify the algorithm.run API to allow for logging and eval callbacks. Maybe we can just allow algorithm.run to accept arbitrary kwargs.
  • Added to_dict and from_dict methods to Example and Completion in repepo.core.types (sketched below).
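
For reference, a minimal sketch of the to_dict / from_dict round-trip on the core types; the field names below are assumptions and may not match repepo.core.types exactly.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict

# Hedged sketch: field names are assumed for illustration and may differ
# from the actual definitions in repepo.core.types.
@dataclass
class Completion:
    prompt: str
    response: str

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "Completion":
        return cls(**d)

@dataclass
class Example:
    instruction: str
    input: str
    output: str

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "Example":
        return cls(**d)

# Round-trip check
ex = Example(instruction="Translate to English", input="bonjour", output="hello")
assert Example.from_dict(ex.to_dict()) == ex
```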


dtch1997 commented Dec 6, 2023

SFT overfits a small dataset.

[W&B chart, 06/12/2023 14:57:27]

To reproduce:

python repepo/algorithms/sft.py --sft.batch_size 4 --wandb.track True

Reference: https://wandb.ai/ucl-dark/Rep-Eng/runs/zuuj7mvq?workspace=user-dtch1997

Notes:

  • Memory requirements seem absurdly high. On an RTX 3080, I was only able to run Pythia-70m with a batch size of 4. I'm sure something is wrong, as I've run with much higher batch sizes in the past.


dtch1997 commented Dec 6, 2023

Rebased onto main, which has the updated README.

@chanind (Collaborator) left a comment:

This could use some basic test coverage as well to ensure it works as expected


class EvalCallback:
    def __init__(self, val_datasets: Dict[str, BaseDataset]):
        self.metric_fns = Metrics()

chanind (Collaborator):

Can we use the same Evaluator objects that we use for benchmarking? The val_dataset should respond to the same metrics as the test_dataset, I would think. Unless the idea is that we can set a specific validator that SFT should use to pick the best-performing result? Regardless, the Evaluator type already returns a float, so it would be suited to this.
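
For concreteness, here is a minimal sketch of what wiring an Evaluator into the callback could look like. The Evaluator signature (predictions to float) and the pipeline.generate(example) call are assumptions for illustration, not the actual repepo interfaces.

```python
from typing import Callable, Dict, List

# Sketch only: the Evaluator signature and pipeline.generate(...) are assumed
# here and may not match the real repepo classes.
Evaluator = Callable[[List[str]], float]

class EvalCallback:
    def __init__(
        self,
        val_datasets: Dict[str, List[dict]],
        evaluators: Dict[str, Evaluator],
    ):
        self.val_datasets = val_datasets
        self.evaluators = evaluators

    def __call__(self, pipeline) -> Dict[str, float]:
        """Run every evaluator on every validation split; return named scores."""
        scores: Dict[str, float] = {}
        for ds_name, dataset in self.val_datasets.items():
            predictions = [pipeline.generate(example) for example in dataset]
            for ev_name, evaluator in self.evaluators.items():
                scores[f"{ds_name}/{ev_name}"] = evaluator(predictions)
        return scores
```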

dtch1997 (Owner, Author):

Sounds good. I haven't looked closely at the Evaluator class yet, but will do so.

[Resolved review threads on repepo/algorithms/sft.py, repepo/core/types.py, and repepo/data/__init__.py]
torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
print(f"epoch : {epoch} | step: {step} | loss: {loss}")

chanind (Collaborator):

nit: it would be better to use tqdm to get an updating progress bar rather than printing directly. It would also be good to add a way to disable this output entirely, maybe with a param to run() called verbose: bool? We can figure that out later though, as it's more polish than core functionality.
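
A sketch of what the suggestion could look like; the loader, optimizer, and loss function are passed in as parameters only to keep the sketch self-contained, since in the PR they live on the algorithm object.

```python
import torch
from tqdm import tqdm

# Sketch: wrap the inner loop in tqdm and gate console output on `verbose`,
# replacing the explicit print(...) call with an in-place progress bar.
def train_epoch(model, train_loader, optimizer, scheduler, loss_fn,
                epoch: int, verbose: bool = True) -> None:
    progress = tqdm(train_loader, disable=not verbose, desc=f"epoch {epoch}")
    for batch in progress:
        loss = loss_fn(model, batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progress.set_postfix(loss=float(loss))  # shown on the bar instead of printed
```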

@dtch1997 (Owner, Author) commented Dec 6, 2023:

Yeah, let's do this later. I also want to figure out how to separate the display from the core functionality.


dtch1997 commented Dec 6, 2023

@chanind do you have any idea why tests/run/test_run_repe.py is failing in CI? I can't reproduce the failure locally from this branch

[Resolved review thread on tests/conftest.py]

@dtch1997 (Owner, Author):

Will think about how to integrate Evaluator into a callback in a subsequent PR


new_pipeline = algorithm.run(pipeline, dataset=dataset)

# Skip testing outputs as they will be gibberish

chanind (Collaborator):

You can try testing with larger_model instead of model; that's pythia 360m and seems to generate real answers, e.g. https://github.com/dtch1997/repepo/blob/main/tests/core/test_benchmark.py#L10. Alternatively, we could try overfitting on a few examples and just asserting that the model outputs the stuff we overfit on. For instance, we could overfit on a single wrong example, like "Paris is in" "Germany", just to verify that it is training.
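
A sketch of the overfitting-based test idea; the fixture names, training hyperparameters, and the SupervisedFineTuning / Pipeline APIs used here are assumptions for illustration, not the actual repepo code.

```python
# Assumed imports (module paths and class names are guesses based on the PR text):
# from repepo.core.types import Example
# from repepo.core.pipeline import Pipeline
# from repepo.algorithms.sft import SupervisedFineTuning

# Sketch of the suggested test: overfit on one deliberately wrong fact and
# assert the model reproduces it. Fixture names (model, tokenizer) and the
# constructors below are assumed, not the real API.
def test_sft_overfits_single_wrong_example(model, tokenizer):
    example = Example(instruction="", input="Paris is in", output="Germany")
    pipeline = Pipeline(model, tokenizer)
    algorithm = SupervisedFineTuning(num_epochs=50, batch_size=1)
    pipeline = algorithm.run(pipeline, dataset=[example] * 8)
    output = pipeline.generate(example)
    # After overfitting, the (wrong) target string should appear in the output
    assert "Germany" in output
```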

dtch1997 (Owner, Author):

Okay, but I'm not going to do it immediately. Will create an issue in the backlog.

return AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
return AutoTokenizer.from_pretrained(
"EleutherAI/pythia-70m",
model_max_length=128, # Required to avoid overflow error in SFT

chanind (Collaborator):

👍 this is a good call, I feel like this has come up in other places too

@chanind (Collaborator) left a comment:

The test could be improved, but otherwise LGTM

@dtch1997 merged commit 1d68251 into main on Dec 13, 2023 (2 checks passed)
@dtch1997 deleted the sft branch on January 31, 2024