Improve Wandb experience #660
Conversation
Thanks for working on this! I left a bunch of comments; please let me know if any of them are unclear. A couple of other general things:
)
memory_stats = utils.memory_stats_log(device=self._device)
log.info(f"Memory Stats:\n{memory_stats}")
log.info(f"Model trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:,.2f}M")
Can we remove this line?
)
)
memory_stats = utils.memory_stats_log(device=self._device)
log.info(f"Memory Stats:\n{memory_stats}")
nit: change back to "Memory Stats after model init:" just to be explicit
self._metric_logger.log_dict(memory_stats, step=self.total_training_steps)
log.info(f"Memory Stats:\n{memory_stats}")
So right now we log to stdout only when WandB is not enabled, but log to both WandB and stdout when it is enabled? Don't all our metric loggers support log_dict? If so, can we just call log_dict only here?
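For illustration, something along these lines (a sketch only, assuming, as the question above implies, that every metric logger implements log_dict; the _log_memory_stats helper name is hypothetical and the attribute names mirror the snippet above):

```python
def _log_memory_stats(self) -> None:
    # Hypothetical helper: route peak-memory stats through the metric logger
    # only, instead of also duplicating them on stdout via log.info.
    if self._device == torch.device("cuda"):
        memory_stats = utils.memory_stats_log(device=self._device)
        # If all metric loggers support log_dict, this single call covers the
        # stdout/disk loggers as well as WandB.
        self._metric_logger.log_dict(memory_stats, step=self.total_training_steps)
```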
torchtune/utils/metric_logging.py (outdated)
entity: Optional[str] = None,
group: Optional[str] = None,
log_strategy: Optional[str] = "main",
nit: could type as a literal, like we do e.g. here
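For example (sketch only; Literal pins the accepted values so typos get flagged by static type checkers):

```python
from typing import Literal, Optional


class WandBLogger:
    def __init__(
        self,
        entity: Optional[str] = None,
        group: Optional[str] = None,
        # Literal documents and enforces the three accepted strategies.
        log_strategy: Literal["main", "node", "all"] = "main",
    ) -> None:
        self.log_strategy = log_strategy
```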
torchtune/utils/metric_logging.py (outdated)
log_strategy (Optional[str]): Strategy to use for logging. Options are "main", "node", "all".
    Default: "main"
Would add more detail here explaining what each of these means.
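Something along these lines, for example (the wording is just a suggestion; the semantics come from the rank check further down in this diff):

```python
class WandBLogger:
    """Sketch of the expanded docstring wording only.

    Args:
        log_strategy (Optional[str]): Which ranks log to W&B. Options:
            "main": only the global rank-0 process logs.
            "node": the rank-0 process on each node logs (local_rank == 0).
            "all": every rank logs.
            Default: "main"
    """
```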
# Training env
device: mps
We have not thoroughly tested our recipes with mps yet... it makes sense that a lot of folks would run this config on their MacBook, but for now I would keep this as cuda or cpu (if it fits) and override it from the command line for your personal testing on Mac.
# Memory management
enable_activation_checkpointing: True

# Reduced precision
nit: this is full and not reduced precision
project: torchtune
log_every_n_steps: 1

# # Logging
should we remove this?
torchtune/utils/metric_logging.py (outdated)
if (
    (self.log_strategy == "main" and self.rank == 0)
    or (self.log_strategy == "node" and self.local_rank == 0)
    or self.log_strategy == "all"
I'm inclined to make this a quick private method (e.g., _should_log()) since the logic here is more complex; that way you can test it in an isolated way.
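For example, something like this (a sketch that just mirrors the condition above so it can be unit-tested on its own):

```python
def _should_log(self) -> bool:
    """Return True if this rank should write to W&B under the configured strategy."""
    if self.log_strategy == "main":
        return self.rank == 0
    if self.log_strategy == "node":
        return self.local_rank == 0
    return self.log_strategy == "all"
```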
I think you'll need to run pre-commit run --all-files to fix some of the spacing and indent issues.
"Memory Stats after model init:", device=self._device | ||
) | ||
) | ||
if self._device == torch.device("cuda"): |
This is for the CPU recipe tests?
Yes, this actually won't throw; the log just doesn't print anything.
.. code-block:: bash

    pip install wandb
Can you add a "tip" to run wandb login before running?
good catch
recipes/lora_finetune_distributed.py (outdated)
@@ -541,9 +541,12 @@ def train(self) -> None:
if (
    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
    and self._is_rank_zero
    and self._device == torch.device("cuda")
Do we really need this one too? The distributed tests should only run on GPU anyway.
mm yeah I can remove this check for distributed recipes
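Concretely, dropping that check would leave the distributed condition as (based on the diff above; the body is elided):

```python
if (
    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
    and self._is_rank_zero
):
    ...  # log peak memory stats as before
```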
Couple questions on the device == cuda checks (especially for distributed recipes). Otherwise looks good though
The idea is to create a better W&B experience for the end user. We save the config file to checkpoint_dir. This could also be added to other metric_loggers. We also keep track of the config in the Overview tab by updating the underlying wandb.config, and we add the run id to the config filename so runs don't automatically overwrite each other; the naming is f"torchtune_config_{self._wandb.run.id}.yaml" at this time. We also changed the memory_stats_log function to return a dict so we can log that to W&B as well, and we now only log memory stats when training on GPU (some recipe tests use CPU).
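As a rough illustration of the config handling described above (a sketch, assuming a WandBLogger-style class that holds the wandb module as self._wandb; the log_config method name and the config.checkpointer.checkpoint_dir field are assumptions for this example):

```python
import os

import wandb
from omegaconf import DictConfig, OmegaConf


class WandBLoggerSketch:
    """Illustrative only; mirrors the behavior described in the summary above."""

    def __init__(self, project: str = "torchtune") -> None:
        self._wandb = wandb
        self._wandb.init(project=project)

    def log_config(self, config: DictConfig) -> None:
        # Mirror the full config into the run's Overview tab.
        self._wandb.config.update(OmegaConf.to_container(config, resolve=True))

        # Save the config next to the checkpoints; the run id in the filename
        # keeps successive runs from overwriting each other.
        output_path = os.path.join(
            config.checkpointer.checkpoint_dir,  # assumed config field
            f"torchtune_config_{self._wandb.run.id}.yaml",
        )
        OmegaConf.save(config, output_path)
```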