Implement supervised fine-tuning #31
Conversation
SFT overfits a small dataset. To reproduce:

python repepo/algorithms/sft.py --sft.batch_size 4 --wandb.track True

Reference: https://wandb.ai/ucl-dark/Rep-Eng/runs/zuuj7mvq?workspace=user-dtch1997

Notes:
Rebased onto
This could use some basic test coverage as well to ensure it works as expected
class EvalCallback:
    def __init__(self, val_datasets: Dict[str, BaseDataset]):
        self.metric_fns = Metrics()
Can we use the same Evaluator objects that we use for benchmarking? The val_dataset should respond to the same metrics as the test_dataset, I would think. Unless the idea is that we can set a specific validator that should be used by SFT to pick the best-performing result? Regardless, the Evaluator type already returns a float, so it would be suited to this.
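A rough sketch of what that could look like, reusing the benchmarking evaluators inside the callback (the Evaluator call signature, pipeline.generate, and the BaseDataset iteration here are assumptions for illustration, not the actual repepo APIs):

```python
from typing import Dict, List


class EvalCallback:
    """Hypothetical sketch: run the benchmark Evaluators on each validation set."""

    def __init__(self, val_datasets: Dict[str, "BaseDataset"], evaluators: List["Evaluator"]):
        self.val_datasets = val_datasets
        self.evaluators = evaluators

    def __call__(self, pipeline: "Pipeline") -> Dict[str, float]:
        # One scalar per (dataset, evaluator) pair, suitable for wandb logging
        # or for picking the best checkpoint.
        results: Dict[str, float] = {}
        for name, dataset in self.val_datasets.items():
            predictions = [pipeline.generate(example) for example in dataset]  # assumed API
            for evaluator in self.evaluators:
                key = f"{name}/{type(evaluator).__name__}"
                results[key] = evaluator(predictions)  # Evaluator assumed to return a float
        return results
```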
Sounds good, I haven't looked closely at the Evaluator class yet but will do so.
torch.nn.utils.clip_grad.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
print(f"epoch : {epoch} | step: {step} | loss: {loss}")
nit: it would be better to use tqdm to get an updating progress bar rather than printing directly. It would also be good to add a way to disable this output to the screen, maybe with a param to run() called verbose: bool? We can figure that out later though, as it's more polish than core functionality.
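Something along these lines, for instance (a sketch only: the train function signature and the compute_loss helper are hypothetical, not the actual sft.py code):

```python
import torch
from tqdm import tqdm


def train(model, optimizer, scheduler, dataloader, num_epochs: int, verbose: bool = True):
    # disable=not verbose silences the progress bar (and all per-step output)
    progress = tqdm(total=num_epochs * len(dataloader), disable=not verbose)
    for epoch in range(num_epochs):
        for step, batch in enumerate(dataloader):
            loss = compute_loss(model, batch)  # hypothetical loss helper
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress.update(1)
            progress.set_postfix(epoch=epoch, step=step, loss=float(loss))
    progress.close()
```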
yeah let's do this later. I also want to figure out how to separate the display from the core functionality.
@chanind do you have any idea why …
Will think about how to integrate …
new_pipeline = algorithm.run(pipeline, dataset=dataset)

# Skip testing outputs as they will be gibberish
You can try testing with larger_model instead of model - that's pythia 360m and seems to generate real answers, e.g. https://github.com/dtch1997/repepo/blob/main/tests/core/test_benchmark.py#L10. Alternatively, we could try overfitting on a few examples and just asserting that it does output the stuff we overfit on. For instance, we could overfit on a single wrong example, like "Paris is in" completed with "Germany", just to verify that it is training.
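A rough sketch of that kind of test (the Example fields, fixture names, and the Pipeline / SupervisedFineTuning interfaces below are assumptions for illustration, not the repo's actual API):

```python
from repepo.core.types import Example          # field names below are assumed
from repepo.core.pipeline import Pipeline      # assumed module path
from repepo.algorithms.sft import SupervisedFineTuning  # assumed class name


def test_sft_overfits_single_example(model, tokenizer) -> None:
    # Repeat one deliberately wrong example so even a tiny model can memorize it.
    dataset = [Example(instruction="", input="Paris is in", output="Germany")] * 32
    pipeline = Pipeline(model, tokenizer)
    algorithm = SupervisedFineTuning(num_epochs=20)
    new_pipeline = algorithm.run(pipeline, dataset=dataset)
    # If training is working at all, the wrong fact should now be reproduced.
    assert "Germany" in new_pipeline.generate("Paris is in")
```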
Okay, but not gonna do it immediately. Will create an issue in the backlog.
- return AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
+ return AutoTokenizer.from_pretrained(
+     "EleutherAI/pythia-70m",
+     model_max_length=128,  # Required to avoid overflow error in SFT
👍 this is a good call, I feel like this has come up in other places too
The test could be improved, but otherwise LGTM
repepo/algorithms/sft.py:
- WandbLogger instance
- run loop, though I think maybe this should be derived from the Benchmark and Evaluator classes
- make_dataset from a DatasetSpec, which allows using a custom split of the dataset.

Notes
- Modified the algorithm.run API to allow for logging and eval callbacks. Maybe we can just allow algorithm.run to accept arbitrary kwargs.
- Added to_dict and from_dict methods to Example and Completion in repepo.core.types.
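For reference, a minimal sketch of what those serialization helpers could look like, assuming Example and Completion are plain dataclasses (the field names are illustrative, not the actual repepo.core.types definitions):

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class Completion:
    prompt: str
    response: str

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Completion":
        return cls(**data)


@dataclass
class Example:
    instruction: str
    input: str
    output: str

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Example":
        return cls(**data)
```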