Add tune run and refactor CLI #586

Merged
joecummings merged 39 commits into main from fixup-cli on Mar 29, 2024

Conversation

@joecummings (Contributor) commented Mar 25, 2024

Context

Once again, I fail to limit the scope of a PR. As I was adding the tune run CLI command, which we'll use to launch all our recipes, I noticed how clunky our current CLI code was and decided to take on the refactor. Please take a look at all the code, as there are some changes to the UX of the CLI as well. This should help us in the long run: it makes our code more debuggable, adds better error handling, adds more tests, and makes the CLI more extensible.

Look through the changelog and the code for all the details; some FAQs are right below.

Why is the CLI so slow?

Startup time for the tune CLI is slowwww:

(torchtune-2) [jrcummings@devvm050.nha0 ~/projects/torchtune (fixup-cli)]$ time tune
usage: tune [-h] {download,ls,cp,run,validate} ...

...

real    0m5.995s
user    0m5.557s
sys     0m3.480s

Compare this to HuggingFace's accelerate:

(torchtune-2) [jrcummings@devvm050.nha0 ~/projects/torchtune (fixup-cli)]$ time accelerate
usage: accelerate <command> [<args>]

...

real    0m4.474s
user    0m4.176s
sys     0m3.298s

And the classic git:

(torchtune-2) [jrcummings@devvm050.nha0 ~/projects/torchtune (fixup-cli)]$ time git
usage: git [-v | --version] [-h | --help] [-C <path>] [-c <name>=<value>]

...

real    0m0.024s
user    0m0.006s
sys     0m0.018s

Okay, we have no chance of being as fast as git, which is written in C and has been optimized over years of development. Python itself is slow, and argparse, as part of the stdlib, is no exception. We are roughly on par with accelerate, which also uses argparse under the hood, with our added overhead likely coming from incorporating all the commands from torchrun (which also uses argparse). There are some small optimizations available, like copying over all the torchrun commands rather than loading them from torchrun's argparser, but the best way to speed up our CLI would be to take a dependency on another library like click or to write our CLI in a language like Go. Neither of those options seems likely in the short term, so really I'm just calling this out as a known issue in case users are concerned.
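(Not something this PR does — just a sketch of one common mitigation: deferring the heavy torch/torchrun imports into the subcommand that actually needs them, so commands like tune ls don't pay for them at startup. The names below are illustrative.)

import argparse

def _run_cmd(args: argparse.Namespace) -> None:
    # Deferring the heavy import means `tune ls` / `tune cp` never pay for it.
    from torch.distributed.run import get_args_parser  # deliberately imported late
    print(get_args_parser().format_usage())

parser = argparse.ArgumentParser(prog="tune")
subparsers = parser.add_subparsers()
run_parser = subparsers.add_parser("run", help="Run a recipe")
run_parser.set_defaults(func=_run_cmd)

if __name__ == "__main__":
    parsed = parser.parse_args()
    parsed.func(parsed)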

Why did you add new dataclasses to represent recipes and configs? Isn't that overkill?

Attributes of our recipes are getting more and more complex as the days progress, so I think it's natural to start loosely formalizing this information. Specifically for this PR, tune run needed to know when a recipe was able to be run in a distributed fashion. I could check the string for "distributed", but that's a weak assertion. With recipes having an attribute called supports_distributed, there's no guesswork. This adds very little to the overhead needed to add a new "core" recipe and gives us stronger reasoning about the recipes we're running.
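For illustration, a minimal sketch of what these dataclasses look like (the Config fields match the snippet reviewed below; the exact Recipe fields beyond supports_distributed are assumptions):

from dataclasses import dataclass, field
from typing import List

@dataclass
class Config:
    name: str       # e.g. "llama2/7B_lora_single_device"
    file_path: str  # YAML file shipped with torchtune

@dataclass
class Recipe:
    name: str
    file_path: str
    configs: List[Config] = field(default_factory=list)
    supports_distributed: bool = False  # tune run checks this instead of string-matching

lora_single_device = Recipe(
    name="lora_finetune_single_device",
    file_path="lora_finetune_single_device.py",
    configs=[Config("llama2/7B_lora_single_device", "llama2/7B_lora_single_device.yaml")],
    supports_distributed=False,
)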

Changelog

  • Remove convert_checkpoint command. cc @kartikayk and @ebsmothers for the approval here
  • Add better error handling for download and cp
  • Add dataclasses for recipes and configs
  • Add tune run command
  • Add tune run command tests
  • Update README
  • Update tests

Testing

  1. CI green
  2. Unit tests
  3. Run all single recipes (minus alpaca generate)
(torchtune-2) [jrcummings@devvm050.nha0 ~/projects/torchtune (fixup-cli)]$ tune run lora_finetune_single_device --config llama2/7B_lora_single_device
Running recipe_main with parameters {'model': {'_component_': 'torchtune.models.llama2.lora_llama2_7b', 'lora_attn_modules': ['q_proj', 'v_proj'], 'apply_lora_to_mlp': False, 'apply_lora_to_output': False, 'lora_rank': 8, 'lora_alpha': 16}, 'checkpointer': {'_component_': 'torchtune.utils.FullModelMetaCheckpointer', 'checkpoint_dir': './model', 'checkpoint_files': ['consolidated.00.pth'], 'adapter_checkpoint': None, 'recipe_checkpoint': None, 'output_dir': './', 'model_type': 'LLAMA2'}, 'resume_from_checkpoint': False, 'tokenizer': {'_component_': 'torchtune.models.llama2.llama2_tokenizer', 'path': './model/tokenizer.model'}, 'dataset': {'_component_': 'torchtune.datasets.alpaca_dataset', 'train_on_input': True, 'use_clean': True}, 'seed': None, 'shuffle': True, 'batch_size': 2, 'optimizer': {'_component_': 'torch.optim.AdamW', 'weight_decay': 0.01, 'lr': 0.0003}, 'lr_scheduler': {'_component_': 'torchtune.modules.get_cosine_schedule_with_warmup', 'num_warmup_steps': 100}, 'loss': {'_component_': 'torch.nn.CrossEntropyLoss'}, 'epochs': 1, 'max_steps_per_epoch': None, 'gradient_accumulation_steps': 1, 'output_dir': '/tmp/lora_finetune_output', 'metric_logger': {'_component_': 'torchtune.utils.metric_logging.DiskLogger', 'log_dir': '${output_dir}'}, 'log_every_n_steps': None, 'device': 'cuda', 'dtype': 'bf16', 'enable_activation_checkpointing': True}
Setting manual seed to local seed 1696546973. Local seed is seed + rank = 1696546973 + 0
Writing logs to /tmp/lora_finetune_output/log_1711657264.txt
Model is initialized with precision torch.bfloat16.
Memory Stats after model init::
GPU peak memory allocation: 13.96 GB
GPU peak memory reserved: 13.98 GB
GPU peak memory active: 13.96 GB

Tokenizer is initialized from file.
Optimizer and loss are initialized.
Loss is initialized.
Dataset and Sampler are initialized.
Learning rate scheduler is initialized.
1|11|Loss: 1.597842812538147: 0%|

(torchtune-2) [jrcummings@devvm050.nha0 ~/projects/torchtune (fixup-cli)]$ tune run full_finetune_single_device --config llama2/7B_full_single_device
Running recipe_main with parameters {'tokenizer': {'_component_': 'torchtune.models.llama2.llama2_tokenizer', 'path': './tokenizer.model'}, 'dataset': {'_component_': 'torchtune.datasets.alpaca_dataset', 'train_on_input': True}, 'seed': None, 'shuffle': True, 'model': {'_component_': 'torchtune.models.llama2.llama2_7b'}, 'checkpointer': {'_component_': 'torchtune.utils.FullModelMetaCheckpointer', 'checkpoint_dir': './model', 'checkpoint_files': ['consolidated.00.pth'], 'recipe_checkpoint': None, 'output_dir': './', 'model_type': 'LLAMA2'}, 'resume_from_checkpoint': False, 'batch_size': 2, 'epochs': 3, 'optimizer': {'_component_': 'torch.optim.SGD', 'lr': 2e-05}, 'loss': {'_component_': 'torch.nn.CrossEntropyLoss'}, 'max_steps_per_epoch': None, 'gradient_accumulation_steps': 1, 'device': 'cuda', 'enable_activation_checkpointing': True, 'dtype': 'bf16', 'metric_logger': {'_component_': 'torchtune.utils.metric_logging.DiskLogger', 'log_dir': '${output_dir}'}, 'output_dir': '/tmp/alpaca-llama2-finetune', 'log_every_n_steps': None}
Setting manual seed to local seed 1384747597. Local seed is seed + rank = 1384747597 + 0
Writing logs to /tmp/alpaca-llama2-finetune/log_1711657490.txt
Model is initialized with precision torch.bfloat16.
Memory Stats after model init::
GPU peak memory allocation: 13.95 GB
GPU peak memory reserved: 13.97 GB
GPU peak memory active: 13.95 GB

Tokenizer is initialized from file.
Optimizer is initialized.
Loss is initialized.

(torchtune-2) [jrcummings@devvm050.nha0 ~/projects/torchtune (fixup-cli)]$ tune run eleuther_eval --config eleuther_eval
Running recipe_main with parameters {'model': {'_component_': 'torchtune.models.llama2.llama2_7b'}, 'checkpointer': {'_component_': 'torchtune.utils.FullModelTorchTuneCheckpointer', 'checkpoint_dir': './', 'checkpoint_files': ['qlora_trained.pt'], 'output_dir': './', 'model_type': 'LLAMA2'}, 'tokenizer': {'_component_': 'torchtune.models.llama2.llama2_tokenizer', 'path': './tokenizer.model'}, 'device': 'cuda', 'dtype': 'bf16', 'seed': 217, 'tasks': ['truthfulqa_mc2'], 'limit': None, 'max_seq_length': 4096}
2024-03-28:13:27:18,513 INFO [_parse.py:52] Running recipe_main with parameters {'model': {'_component_': 'torchtune.models.llama2.llama2_7b'}, 'checkpointer': {'_component_': 'torchtune.utils.FullModelTorchTuneCheckpointer', 'checkpoint_dir': './', 'checkpoint_files': ['qlora_trained.pt'], 'output_dir': './', 'model_type': 'LLAMA2'}, 'tokenizer': {'_component_': 'torchtune.models.llama2.llama2_tokenizer', 'path': './tokenizer.model'}, 'device': 'cuda', 'dtype': 'bf16', 'seed': 217, 'tasks': ['truthfulqa_mc2'], 'limit': None, 'max_seq_length': 4096}
Setting manual seed to local seed 217. Local seed is seed + rank = 217 + 0
2024-03-28:13:27:19,173 DEBUG [seed.py:59] Setting manual seed to local seed 217. Local seed is seed + rank = 217 + 0
Model is initialized with precision torch.bfloat16.
2024-03-28:13:27:40,285 INFO [eleuther_eval.py:160] Model is initialized with precision torch.bfloat16.
Tokenizer is initialized from file.
2024-03-28:13:27:40,314 INFO [eleuther_eval.py:146] Tokenizer is initialized from file.
2024-03-28:13:27:41,786 INFO [huggingface.py:148] Using device 'cuda:0'
2024-03-28:13:27:48,865 WARNING [__init__.py:194] Some tasks could not be loaded due to missing dependencies. Run with `--verbosity DEBUG` for full details.
2024-03-28:13:27:52,458 WARNING [__init__.py:194] Some tasks could not be loaded due to missing dependencies. Run with `--verbosity DEBUG` for full details.
Running evaluation on ['truthfulqa_mc2'] tasks.
2024-03-28:13:28:08,859 INFO [eleuther_eval.py:181] Running evaluation on ['truthfulqa_mc2'] tasks.
2024-03-28:13:28:08,861 INFO [task.py:363] Building contexts for task on rank 0...
2024-03-28:13:28:10,287 INFO [evaluator.py:324] Running loglikelihood requests  2%|███▍
  4. Testing snapshot_download with ignore_patterns (correctly ignores downloading safetensors)
[Screenshot 2024-03-29 at 12:32 PM: download output with *.safetensors files ignored]

pytorch-bot bot commented Mar 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/586

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 67bde68 with merge base f6f6855:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 25, 2024
@@ -135,8 +135,7 @@ def setup(self) -> None:
self._limit = self._cfg.limit
self._tasks = list(self._cfg.tasks)

-seed = utils.set_seed(seed=self._cfg.seed)
-logger.info(f"Random seed set to {seed}.")
+utils.set_seed(seed=self._cfg.seed)
Contributor Author

utils.set_seed already logs the resulting seed; didn't need to do it twice :)

$ tune cp full_finetune_distributed.py ./my_custom_full_finetune.py
$ tune cp full_finetune_distributed.py ./new_dir/my_custom_full_finetune.py --make-parents
# Attach proper suffix if needed
if destination.name != "" and destination.suffix != proper_suffix:
Contributor Author

@ebsmothers thanks for the discussion here

Contributor

Wanna make sure I understand: destination.name == "" implies we are copying to a directory and so the filename will be inherited from the original config, right?

Contributor Author

Bingo

resume_download=True,
token=args.hf_token,
)
except GatedRepoError:
Contributor Author

Better error handling!

@joecummings changed the title from "[WIP] Add tune run and big refactoring of CLI" to "Add tune run and refactor CLI" on Mar 28, 2024
@joecummings marked this pull request as ready for review on March 28, 2024 at 22:56
@ebsmothers (Contributor) left a comment

> Once again, I fail to limit the scope of a PR.

I legitimately loled at this

"lora_finetune_single_device.py",
"lora_finetune_distributed.py",
"eleuther_eval.py",

Contributor

o god

Comment on lines 12 to 19
@dataclass
class Config:
name: str
file_path: str


@dataclass
class Recipe:
Contributor

Actually I do like these though

Comment on lines 25 to 26
def get_configs(self) -> List[Config]:
return self.configs
Contributor

This I don't like so much. Why is the configs field not sufficient? Or if we do need this, please just define it as a separate function, not on the dataclass.

Contributor Author

This is just good OOP design, no?

Contributor

If configs was a private attribute this getter would make sense, but as it is, just calling Recipe.configs is preferable to Recipe.get_configs() imo

return self.configs


_ALL_RECIPES = [
Contributor

It is a lot of LOC for our root init file, but I actually think the structure makes sense here. Nitpicking a bit I wonder if it's better to have all these in a separate file now though

Contributor

Yeh, __init__.py is the last place I'd look for this. Why not just define this in a recipes.py or something like that and then just import ALL_RECIPES? Also it's a bit confusing that this is a private list in __init__.py.

Contributor

these should all be hidden from the end user, in fact maybe we should make Config and Recipe private since this is primarily for us developers to track built-in recipe attributes and their associated configs... an internal registry if you will (@ebsmothers shudders)

Although, I'm not entirely convinced that we really need this. The primary attribute you need to check is supports_distributed so you can run with torchrun or runpy on single device. Why the distinction? Why can't we just always use torchrun? And do you anticipate many more recipe attributes like this to be added? Because if we remove supports_distributed, we can still do just a simple mapping as we had before which is easier to maintain than the dataclasses

Contributor Author

I agree that these should all be hidden from the user, but I will push back on the second part.

I originally wanted to run everything through torchrun, but adding on a multiprocessing layer even for simple recipes proved troublesome: 1) it clogs up the logs, 2) it was much harder to track down errors even for simple recipes, and 3) testing became much more difficult b/c we can't directly observe output (it must be logged to a file first).

I'm definitely open to revisiting this issue in the future, but it should be noted that it's not a trivial change.

I don't imagine a ton of new recipes being added to the core library, and I would argue that this is actually easier to maintain than a separate list and dictionary for recipes and recipe -> config mappings, which is what we had before.
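For reference, the dispatch being described boils down to something like this (a rough sketch, not the PR's exact code; the helper name and arguments are illustrative):

import runpy
import sys
from typing import List

def _run_cmd(recipe_path: str, supports_distributed: bool, nproc_per_node: int, recipe_args: List[str]) -> None:
    if supports_distributed and nproc_per_node > 1:
        # Hand off to torchrun (torch.distributed.run), which spawns one process
        # per rank and handles rendezvous.
        from torch.distributed.run import get_args_parser, run
        args = get_args_parser().parse_args(
            [f"--nproc_per_node={nproc_per_node}", recipe_path, *recipe_args]
        )
        run(args)
    else:
        # Run the recipe in-process: logs stay clean, errors surface directly,
        # and tests can observe output without going through a log file.
        sys.argv = [recipe_path, *recipe_args]
        runpy.run_path(recipe_path, run_name="__main__")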

Contributor

Ok I see your point about distributed logs. That's fair, I can get on board with that

@@ -168,12 +168,12 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):

cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd_1)
-with pytest.raises(SystemExit):
+with pytest.raises(SystemExit, match=""):
Contributor

Out of curiosity, is the point here just to check that we're not raising some unexpected exception?

Contributor Author

Yeah, so a SystemExit carries a code of 0 or None for a successful exit and a positive int for an errored exit. This is essentially a sanity check that we have a successful exit.
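A minimal standalone sketch of those exit-code semantics (not the PR's actual tests), written with an explicit assertion on SystemExit.code:

import pytest

def exit_cleanly() -> None:
    raise SystemExit(0)   # sys.exit()/argparse use 0 or None on success

def exit_with_error() -> None:
    raise SystemExit(1)   # a non-zero code signals an error

def test_successful_exit():
    with pytest.raises(SystemExit) as excinfo:
        exit_cleanly()
    assert excinfo.value.code in (0, None)

def test_errored_exit():
    with pytest.raises(SystemExit) as excinfo:
        exit_with_error()
    assert excinfo.value.code == 1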

Comment on lines +77 to +84
# TODO (rohan-varma): Add check that nproc_per_node <= cuda device count. Currently,
# we don't do this since we test on CPUs for distributed. Will update once multi GPU CI is supported.
Contributor

Time to update this?

output = capsys.readouterr()
assert "The '--config' argument is required" in output.err

def test_run_succeeds_with_local_recipe_file_and_default_config(
Contributor

I completely missed that you were able to do this in my first pass over run.py. Very sneaky.. and very nice. We should also think about the case of copying a config only (I don't think that will work?). But given CLI overrides are an option I think it shouldn't be as hi-pri anyways

Contributor Author

Not sure what you mean - a user can definitely copy a config, modify it, and run it with a default recipe.

Contributor

Oh yeah sorry, I was looking at this too late at night. I understand now. I like the way you've set it up, but the behavior of _get_recipe and _get_config and their subsequent usage in _run_cmd is a bit non-obvious. Would be nice to document a bit: specifically that the first two methods return None when the recipe/config is not registered with TorchTune, and in that case we assume they refer to a local path.
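Roughly the behavior being described, as a sketch (assumes the Recipe/Config dataclasses and the _ALL_RECIPES registry mentioned above; the real helpers in run.py may be structured differently):

from typing import List, Optional

def _get_recipe(recipe_str: str, all_recipes: List["Recipe"]) -> Optional["Recipe"]:
    # None means "not a registered recipe"; the caller then treats the string
    # as a path to a local .py file.
    return next((r for r in all_recipes if r.name == recipe_str), None)

def _get_config(config_str: str, recipe: Optional["Recipe"]) -> Optional["Config"]:
    # Same contract: None means "assume a local .yaml path".
    known_configs = recipe.configs if recipe is not None else []
    return next((c for c in known_configs if c.name == config_str), None)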

runpy.run_path(TUNE_PATH, run_name="__main__")
local_file_run.assert_called_once()

def test_run_calls_local_file_run_for_local_file_recipe(
Contributor

nit: it's not really calling a separate run for the local file here, right? Like you are just patching _run_single_device and checking that gets called. Unless I am missing the point here

Contributor Author

: ( Stop calling me out

Comment on lines +62 to +78
for action in torchrun_argparser._actions:
if action.dest == "training_script":
action.dest = "recipe"
action.help = """
Name or path to recipe to be launched followed by args.
For a list of all possible recipes, run `tune ls`."""
elif action.dest == "training_script_args":
action.dest = "recipe_args"
action.help = "Args to be passed to the recipe."
elif action.dest == "help":
continue
Contributor

I think I understand this? But still, might be nice to give a bit more detail in the docstring here

Comment on lines 81 to 82
args.__dict__["training_script"] = args.__dict__.pop("recipe")
args.__dict__["training_script_args"] = args.__dict__.pop("recipe_args")
Contributor

Do we actually need to use __dict__ on the LHS of these two lines?

Contributor Author

Nah, fixed.
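(For the record, the fix presumably looks something like this — plain attribute assignment works on an argparse.Namespace; only the pop still needs __dict__:)

import argparse

args = argparse.Namespace(recipe="full_finetune_single_device", recipe_args=["--config", "foo.yaml"])

args.training_script = args.__dict__.pop("recipe")
args.training_script_args = args.__dict__.pop("recipe_args")

print(args)  # Namespace(training_script='full_finetune_single_device', training_script_args=['--config', 'foo.yaml'])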

-tune --nnodes 1 --nproc_per_node 2 \
-    lora_finetune_distributed \
-    --config llama2/13B_lora
+tune run --nproc_per_node 2 full_finetune_distributed --config llama2/7B_full_distributed
Contributor

where did --nnodes go? Are we not supporting it anymore or is it just hidden?

Contributor Author

--nnodes defaults to 1 already and we don't technically support multinode anyway so removing it for now.

@@ -39,5 +31,5 @@ def test_alpaca_generate(self, tmpdir, monkeypatch):
cmd += model_config

monkeypatch.setattr(sys, "argv", cmd)
-with pytest.raises(SystemExit):
+with pytest.raises(SystemExit, match=""):
Contributor

Why is this needed?

Contributor Author

See comment above. This ensures the SystemExit was a successfully run command, rather than an error.

-log_out = caplog.messages[-1]
-assert "'acc,none': 0.3" in log_out
+err_log = caplog.messages[-1]
+assert "'acc,none': 0.346" in err_log
Contributor

Hmm, sorry to be a bit of a buzz kill, but I suspect there will be some variance between runs. Is this level of precision needed? Can we add some slack here? Maybe a +/- 0.5 to the value? Would love to not have to re-run these tests while landing a PR
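One way to add that slack (sketch only; the tolerance and regex are illustrative, not what landed):

import re
import pytest

err_log = "... 'acc,none': 0.346 ..."  # stand-in for caplog.messages[-1]

# Parse the reported accuracy and compare with a tolerance instead of an exact
# substring match, so small run-to-run variance doesn't break the test.
acc = float(re.search(r"'acc,none': ([0-9.]+)", err_log).group(1))
assert acc == pytest.approx(0.346, abs=0.05)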

@@ -83,8 +83,8 @@ def test_eval_recipe_errors_without_lm_eval(self, caplog, monkeypatch, tmpdir):
""".split()

monkeypatch.setattr(sys, "argv", cmd)
-with pytest.raises(SystemExit):
+with pytest.raises(SystemExit, match="1"):
Contributor

I can just look this up, but what are these magic values?

Contributor Author

It denotes a SystemExit that carries an error code (here, 1).

@@ -61,7 +57,7 @@ def test_loss(self, config, tmpdir, monkeypatch):
log_file = gen_log_file_name(tmpdir)

cmd = f"""
-tune full_finetune_single_device
+tune run full_finetune_single_device
Contributor

probably add a '' to conform with what you've done above

help="Copy a built-in recipe or config to a local path.",
description="Copy a built-in recipe or config to a local path.",
epilog=textwrap.dedent(
"""\
Contributor

This indent is killing me

Contributor Author

?

"""Downloads a model from the HuggingFace Hub."""
# Download the tokenizer and PyTorch model files
try:
true_output_dir = snapshot_download(
Contributor

Can we add "ignore_patterns" and set it to by default ignore "*.safetensors"? I need to make this change everytime I run on runpod and it's really annoying. If a model ONLY has safetensors, the user can manually set it to an empty string. Just print out that you're ignoring this for space. It's also really annoying because the saved files are cached by default and I need to dig into how to get rid of something I don't want (unless I'm just being stupid).

Contributor Author

Sure, I was gonna do this in a follow-up, but I can do it here.

In the future, I actually think we might want to support safetensors by default, but that's a discussion for another time.
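A sketch of the requested change — huggingface_hub's snapshot_download does take an ignore_patterns argument (the default value, repo id, and CLI flag below are assumptions for illustration):

from huggingface_hub import snapshot_download

output_dir = snapshot_download(
    "meta-llama/Llama-2-7b",            # illustrative repo_id
    local_dir="./model",
    ignore_patterns=["*.safetensors"],  # skip duplicate safetensors weights by default
    token=None,                         # or an HF token for gated repos
)
print("Ignoring *.safetensors files to save space; pass --ignore-patterns '' to disable.")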

true_output_dir = snapshot_download(
args.repo_id,
local_dir=args.output_dir,
resume_download=True,
Contributor

Sorry, I don't understand what resume_download=True is doing. I also didn't find any info on this in https://huggingface.co/docs/huggingface_hub/en/guides/download

Comment on lines 90 to 91
"You need to provide a HuggingFace API token to download gated models."
"You can find your token by visiting https://huggingface.co/settings/tokens"
Contributor

Is this the only reason that download from a gated repo can fail? What if the user has not signed the consent form or something like that? In that case is this error misleading?

Contributor Author

Yeah, I can update this comment to be more descriptive.
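For example, something along these lines (wording and helper names are mine, not the PR's; GatedRepoError is a real huggingface_hub exception):

from typing import Optional

from huggingface_hub import snapshot_download
from huggingface_hub.utils import GatedRepoError

def download(repo_id: str, output_dir: str, hf_token: Optional[str]) -> str:
    try:
        return snapshot_download(repo_id, local_dir=output_dir, token=hf_token)
    except GatedRepoError:
        raise ValueError(
            "This repo is gated. Make sure you have been granted access on the "
            "Hugging Face Hub (you may need to accept the model's license/consent "
            "form on its page) and that you passed a valid API token. You can find "
            "your token at https://huggingface.co/settings/tokens"
        ) from None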

ROOT = Path(torchtune.__file__).parent.parent


class Run(Subcommand):
Contributor

I'm staring at this and still not sure what tune run is getting us? Better UX than torchrun?

BTW can we add some examples of how I'd run multiple finetune jobs on the same host on different cuda devices? I know this needs some wrangling of the underlying torchrun params. But would be a huge QoL improvement if we can just call this out in the examples.

Contributor Author

It's a more consistent UX (everything runs with subcommands of argparse) and in general, this PR adds more testing for running of recipes.

On the second point sure, will reach out to you directly on this.

Contributor Author

# Launch the first job
tune run --nproc_per_node 2 --rdzv-backend=c10d --rdzv-endpoint=localhost:20000 full_finetune_distributed --config llama2/7B_full

# Then you can launch a second job on a different port
tune run --nproc_per_node 2 --rdzv-backend=c10d --rdzv-endpoint=localhost:<OTHER-PORT> full_finetune_distributed --config llama2/7B_full

@kartikayk
Contributor

@joecummings I didn't follow the comment around validate. What are you removing?

Comment on lines +14 to +16
@classmethod
def create(cls, *args, **kwargs):
return cls(*args, **kwargs)
Contributor

interesting interesting... do you mind commenting what this is doing? IIUC it is adding the parsers to the tune command to enable a subcommand. Maybe you could rename this to something like create_parsers to be more clear?

List.create(subparsers)
Copy.create(subparsers)
Run.create(subparsers)
Validate.create(subparsers)
Contributor

so these are class methods, which will create the subcommand instances for you... but then how does tune have access to the self._parser in each of these if you're not assigning the created instance to anything?

Contributor Author

Yeah, I hate this too. I wish there was a way to add a full subparser to a top-level argparse parser, but that's not how the add_subparsers API is designed. I finally gave up on trying to find a way around this when I saw that accelerate adopted this approach, too. https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/accelerate_cli.py
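For context, the pattern boils down to the following (a runnable sketch, not torchtune's exact code): the instance doesn't need to be kept around because add_parser() attaches the subparser to the top-level parser, and set_defaults(func=...) stores a bound method that parse_args() hands back for dispatch.

import argparse

class Subcommand:
    @classmethod
    def create(cls, *args, **kwargs):
        return cls(*args, **kwargs)

class Run(Subcommand):
    def __init__(self, subparsers):
        self._parser = subparsers.add_parser("run", help="Run a recipe")
        self._parser.add_argument("recipe")
        # The top-level parser keeps a reference to this bound method, which is
        # why nothing else needs to hold onto the Run instance.
        self._parser.set_defaults(func=self._run_cmd)

    def _run_cmd(self, args):
        print(f"would run recipe: {args.recipe}")

parser = argparse.ArgumentParser(prog="tune")
subparsers = parser.add_subparsers(title="subcommands")
Run.create(subparsers)

args = parser.parse_args(["run", "lora_finetune_single_device"])
args.func(args)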

parser.description = "Torch Tune Recipe Launcher"
parser.usage = "tune [options] <recipe> [recipe_args]"
parser.formatter_class = argparse.RawDescriptionHelpFormatter
class TuneCLIParser:
Contributor

wonder if we should rename TuneArgumentParser as well to make a clearer distinction between these two?

Contributor Author

Yep, would you like me to add this change to this PR?

Contributor

TuneRecipeParser?

Contributor

if you can that would be nice, but it's non-blocking

@joecummings
Contributor Author

> @joecummings I didn't follow the comment around validate. What are you removing?

Typo on my part. Meant to call out the removal of convert_checkpoint. I updated the PR description.

@kartikayk
Contributor

> Typo on my part. Meant to call out the removal of convert_checkpoint

Kill with fire!

@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor Author

@ebsmothers and @RdoubleA how's this?

Contributor

Beautiful

@ebsmothers (Contributor) left a comment

Great work

@joecummings merged commit 8fa0bbb into main on Mar 29, 2024
20 checks passed
@joecummings deleted the fixup-cli branch on March 29, 2024 at 20:06
tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request Apr 5, 2024