Conversation

@cdoern (Contributor) commented May 9, 2025

What does this PR do?

Adds an inline HF SFTTrainer provider. Alongside torchtune, this is a super popular option for running training jobs. The config allows a user to specify some key fields such as a model, chat_template, device, etc.

The provider comes with one recipe, finetune_single_device, which works both with and without LoRA.

Any model that is a valid HF identifier can be given and the model will be pulled.

This has been tested so far with CPU and MPS device types, but it should be compatible with CUDA out of the box.

The provider processes the given dataset into the proper format, establishes the various steps per epoch, steps per save, and steps per eval, sets a sane SFTConfig, and runs n_epochs of training.

If checkpoint_dir is none, no model is saved. If there is a checkpoint dir, a model is saved every save_steps steps and at the end of training.
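
For illustration, the recipe roughly amounts to the following (a sketch, not the provider's exact code; the SFTConfig values and the tiny inline dataset are assumptions for this example):

```python
# Rough sketch of the single-device SFT flow described above -- not the provider's
# exact code. Config values and the inline dataset are illustrative assumptions.
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

model_id = "ibm-granite/granite-3.3-2b-instruct"  # any valid HF identifier gets pulled
checkpoint_dir = "./checkpoints"                  # if None, nothing is saved

# The real provider converts the configured dataset into the trainer's expected format;
# here we just use a tiny in-memory dataset with a "text" column for illustration.
train_dataset = Dataset.from_list(
    [{"text": "Q: What is the capital of France? A: Paris."}]
)

config = SFTConfig(
    output_dir=checkpoint_dir or "/tmp/sft-scratch",
    num_train_epochs=1,                  # n_epochs
    per_device_train_batch_size=1,       # batch_size from the data config
    save_strategy="steps" if checkpoint_dir else "no",
    save_steps=100,                      # steps per save, derived from the dataset size
)

trainer = SFTTrainer(model=model_id, args=config, train_dataset=train_dataset)
trainer.train()

if checkpoint_dir:
    trainer.save_model(checkpoint_dir)   # final save at the end of training
```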

Test Plan

Re-enabled the post_training integration test suite with a single test that loads the simpleqa dataset (https://huggingface.co/datasets/llamastack/simpleqa) and a tiny Granite model (https://huggingface.co/ibm-granite/granite-3.3-2b-instruct). The test now uses the llama stack client and the proper post_training API.

It runs one step with a batch_size of 1. The test runs on CPU on the Ubuntu runner, so it needs a small batch and a single step.
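
Roughly, the test flow looks like this (the client method names and the training_config shape below are assumptions about the client surface, not the exact test code):

```python
# Hedged sketch of the CPU-friendly integration test described above.
# The methods (post_training.supervised_fine_tune, post_training.job.status)
# and the training_config schema are assumptions, not the exact API.
import time
import uuid

def run_tiny_sft_job(client):
    job_uuid = f"test-sft-{uuid.uuid4()}"
    client.post_training.supervised_fine_tune(
        job_uuid=job_uuid,
        model="ibm-granite/granite-3.3-2b-instruct",
        algorithm_config=None,
        checkpoint_dir=None,           # no checkpoints needed for a one-step run
        training_config={
            "n_epochs": 1,
            "max_steps_per_epoch": 1,  # a single step keeps the CPU runner happy
            "data_config": {"dataset_id": "simpleqa", "batch_size": 1, "shuffle": False},
        },
        hyperparam_search_config={},
        logger_config={},
    )
    # Poll until the job reaches a terminal status.
    while True:
        status = client.post_training.job.status(job_uuid=job_uuid)
        if status.status in ("completed", "failed"):
            return status
        time.sleep(5)
```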

@facebook-github-bot added the CLA Signed label May 9, 2025
@cdoern force-pushed the hf branch 2 times, most recently from 42e0c06 to c8fe49c on May 12, 2025 01:24
@cdoern force-pushed the hf branch 3 times, most recently from e323876 to 380b4b2 on May 12, 2025 21:27
@ashwinb (Contributor) commented May 13, 2025

This is great!

@booxter (Contributor) commented May 14, 2025

This is very promising overall; once we have it in-tree, we should be able to revive the integration test suite for the post-training API, which was blocked because llama consolidated model builds were not available for fetch. #1786

@cdoern force-pushed the hf branch 8 times, most recently from c7e8fd4 to 6cd0a75 on May 14, 2025 15:49
finally:
    # Clean up resources
    if hasattr(trainer, "model"):
        if device.type != "cpu":
Contributor:

lines 408-414 are implemented in torchtune. Please move the code to a common function (evacuate_model_from_device?)

Contributor Author:

created common/utils.py for post_training and clear_model

Contributor Author:

I am lazy importing torch so we don't hit issues within this method.

Contributor:

I like the name @booxter suggested more than clear_model

Contributor Author:

sure I can change it

Contributor Author:

changed
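
For reference, a minimal sketch of what such a shared helper could look like, assuming the lazy torch import discussed above (not necessarily the exact code that landed in common/utils.py):

```python
# Hedged sketch of a shared cleanup helper for post_training providers.
# The signature and cache-clearing details are assumptions; the actual
# implementation in common/utils.py may differ.
import gc

def evacuate_model_from_device(model, device: str) -> None:
    """Move the model off an accelerator and release its memory."""
    if device == "cpu":
        return
    # Lazy import so importing this module never requires torch at load time.
    import torch

    model.to("cpu")
    del model
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()
```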

@booxter (Contributor) commented May 14, 2025

I've reviewed the code; it looks reasonable enough to merge after some basic cleanup. I don't expect you to fix the blocking problem as part of this PR, but making a note in the code about it would be nice. Thanks a lot for enabling (some) integration tests! I will follow up with enabling other test cases in #1786 once this PR lands.

🚀

@cdoern changed the title from "[WIP] feat: add huggingface post_training impl" to "feat: add huggingface post_training impl" May 14, 2025
@cdoern marked this pull request as ready for review May 14, 2025 20:02
@cdoern requested a review from ashwinb May 15, 2025 19:59
@cdoern force-pushed the hf branch 3 times, most recently from 787c440 to 9d7b77c on May 15, 2025 21:11
@cdoern requested a review from booxter May 16, 2025 12:42
@booxter (Contributor) commented May 16, 2025

One question about proper procedure for model evacuation to CPU.

@booxter (Contributor) left a comment:

Thank you for the contribution!

@booxter (Contributor) commented May 16, 2025

@ashwinb I think this is ready to go. I'd like to get this in so that we can enable the rest of the tests for the API: #1786 🙏

@cdoern (Contributor Author) commented May 16, 2025

rebased

assert job_artifacts.checkpoints[0].epoch == 0
assert "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0" in job_artifacts.checkpoints[0].path

while True:
Contributor:

we should add a timeout here so the CI action can die appropriately

@cdoern (Contributor Author) commented May 16, 2025:

Added a timeout using pytest's pytest-timeout package. This can be applied to other tests as well in the future for a succinct timeout.
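
For example (the timeout value and fixture name here are illustrative, not the exact values used in the test):

```python
# Hedged example of pytest-timeout usage; the actual timeout and fixture name may differ.
import pytest

@pytest.mark.timeout(600)  # fail the test (and let the CI job die) after 10 minutes
def test_supervised_fine_tune_single_step(llama_stack_client):
    ...
```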

@cdoern requested a review from ashwinb May 16, 2025 19:16
@cdoern force-pushed the hf branch 3 times, most recently from 9d1739d to c63efdc on May 16, 2025 20:00
@cdoern (Contributor Author) commented May 16, 2025

rebased

cdoern added 3 commits May 16, 2025 16:37
adds an inline HF SFTTrainer provider. Alongside torchtune -- this is a super popular option for running training jobs. The config allows a user to specify some key fields such as a model, chat_template, device, etc

the provider comes with one recipe `finetune_single_device` which works both with and without LoRA.

any model that is a valid HF identifier can be given and the model will be pulled.

this has been tested so far with CPU and MPS device types, but should be compatible with CUDA out of the box

The provider processes the given dataset into the proper format, establishes the various steps per epoch, steps per save, steps per eval, sets a sane SFTConfig, and runs n_epochs of training

if checkpoint_dir is none, no model is saved. If there is a checkpoint dir, a model is saved every `save_steps` and at the end of training.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
the experimental_post_training template now uses HF post_training and dataset providers

Signed-off-by: Charlie Doern <cdoern@redhat.com>
set inline::huggingface as the default post_training provider for the ollama distribution and add integration tests for post_training

Signed-off-by: Charlie Doern <cdoern@redhat.com>
currently this impl hangs because of `trainer.train()` blocking.

Re-write the implementation to kick off the model download, device instantiation, dataset processing, and training in a monitored subprocess.

All of these steps need to be in a subprocess or else different devices are used which causes torch errors.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
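
A minimal sketch of the subprocess pattern this commit describes (the spawn context and queue-based error reporting are assumptions about the shape of the fix, not the exact code):

```python
# Hedged sketch: run the blocking training work in a monitored subprocess so the
# post_training API stays responsive. Details (spawn context, error queue) are
# assumptions about the approach, not the exact implementation.
import asyncio
import multiprocessing

def _run_training(config, error_queue):
    try:
        # Model download, device setup, dataset processing, and trainer.train()
        # all happen here, inside one subprocess, so they share a single device.
        ...
    except Exception as exc:  # report failures back to the parent process
        error_queue.put(str(exc))

async def launch_training(config):
    ctx = multiprocessing.get_context("spawn")
    error_queue = ctx.Queue()
    proc = ctx.Process(target=_run_training, args=(config, error_queue))
    proc.start()
    # Monitor the child without blocking the event loop.
    while proc.is_alive():
        await asyncio.sleep(1)
    proc.join()
    if not error_queue.empty():
        raise RuntimeError(f"Training subprocess failed: {error_queue.get()}")
```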

@ashwinb (Contributor) left a comment:

let's go 🚀

@ashwinb merged commit f02f7b2 into llamastack:main May 16, 2025 (24 checks passed)
@cdoern mentioned this pull request May 28, 2025
@reluctantfuturist mentioned this pull request Jun 3, 2025
Labels: CLA Signed, new-in-tree-provider