
storing & logging gradient norm in trainer #27326

Merged: 2 commits into huggingface:main from shijie-wu:grad_norm, Feb 19, 2024

Conversation

shijie-wu (Contributor)

What does this PR do?

Report gradient norm during training - Fixes #26143
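In rough terms, the change does what the fragment below sketches (an illustration based on the snippets discussed later in this thread, not the literal diff): keep the total norm returned by gradient clipping, convert it to a plain number, and include it in the logged metrics.

# Illustrative sketch only: inside the training loop, keep the value that
# clip_grad_norm_ returns ...
_grad_norm = self.accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
grad_norm = _grad_norm.item() if _grad_norm is not None else None

# ... and, when metrics are logged, report it alongside the loss.
if grad_norm is not None:
    logs["grad_norm"] = grad_norm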

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @pacman100

@amyeroberts (Collaborator)

cc @muellerzr

@huggingface huggingface deleted a comment from github-actions bot Dec 7, 2023
@muellerzr (Contributor) left a comment

Thanks! This looks good to me. Can you rebase from main? That should hopefully deal with the failing tests.


github-actions bot commented Jan 1, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@mjbommar commented Jan 2, 2024

Thank you for the work on this, @shijie-wu!

It may seem like a little PR to some, but this would be a huge step to bring transformers closer to parity with projects like gpt-neox for large-scale training.

@muellerzr (Contributor)

Gentle ping @shijie-wu :)

@jubgjf commented Jan 8, 2024

I found that self.accelerator.clip_grad_norm_ returns None when using DeepSpeed with the Trainer. Under DeepSpeed we should instead use model.get_global_grad_norm() to get the gradient norm:

# clip_grad_norm_ normally returns the total norm, but under DeepSpeed it returns None
_grad_norm = self.accelerator.clip_grad_norm_(
    model.parameters(),
    args.max_grad_norm,
)
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
    # the DeepSpeed engine tracks the global gradient norm itself
    grad_norm = model.get_global_grad_norm()
else:
    grad_norm = _grad_norm.item() if _grad_norm is not None else None

@shijie-wu (Contributor, Author)

Sorry for the delay! PTAL @muellerzr @mjbommar

@shijie-wu (Contributor, Author)

Gentle ping @muellerzr @mjbommar :)

@muellerzr (Contributor) left a comment

Thanks! Sorry for the delay!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr (Contributor)

cc @amyeroberts for final review :)

@amyeroberts (Collaborator) left a comment

Thanks for adding!

@amyeroberts amyeroberts merged commit 4f09d0f into huggingface:main Feb 19, 2024
21 checks passed
@152334H commented Mar 2, 2024

Not sure if this was mentioned anywhere, but this PR breaks saving of training checkpoints, because:

  1. the grad norm is added to TrainerState.log_history as a tensor
  2. TrainerState.save_to_json then attempts to serialize that tensor to JSON, which naturally errors out since tensors can't be serialized to JSON

My fix was to patch save_to_json as follows:

    def save_to_json(self, json_path: str):
        """Save the content of this instance in JSON format inside `json_path`."""
        selfd = dataclasses.asdict(self)
        # grad_norm may have been logged as a 0-dim tensor; convert it to a plain
        # float so that json.dumps below does not choke on it
        for d in selfd["log_history"]:
            if "grad_norm" in d:
                d["grad_norm"] = d["grad_norm"].item()
        json_string = json.dumps(selfd, indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

but this is probably not the best way to do this.

@shijie-wu (Contributor, Author)

@152334H It does convert grad_norm to a number before passing it into _maybe_log_save_evaluate:

if (
    is_accelerate_available()
    and self.accelerator.distributed_type == DistributedType.DEEPSPEED
):
    grad_norm = model.get_global_grad_norm()
else:
    grad_norm = _grad_norm.item() if _grad_norm is not None else None

The same holds on the DeepSpeed side:

https://github.com/microsoft/DeepSpeed/blob/bcc617a0009dd27b4e144de59979bd7770eaf57c/deepspeed/runtime/engine.py#L448-L458

What backend were you using?

@shijie-wu shijie-wu deleted the grad_norm branch March 2, 2024 22:39
@152334H commented Mar 2, 2024

DeepSpeed ZeRO-2.

It seems likely that the type hint is not universally correct: under ZeRO-2, scaled_global_norm returns a scalar tensor, and that value is then assigned to _global_grad_norm without any .item().
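If so, a defensive conversion on the Trainer side would sidestep this. A minimal sketch, reusing the names from the snippet above and assuming get_global_grad_norm() can return either a float or a scalar tensor depending on the ZeRO stage:

# Sketch only: normalize whatever DeepSpeed reports to a plain Python float.
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
    grad_norm = model.get_global_grad_norm()
    # Under ZeRO-2 this can come back as a scalar tensor rather than a float.
    if grad_norm is not None and hasattr(grad_norm, "item"):
        grad_norm = grad_norm.item()
else:
    grad_norm = _grad_norm.item() if _grad_norm is not None else None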

@shubhanjan99

Not sure if this was mentioned anywhere, but this PR breaks saving of training checkpoints, because:

  1. the grad norm is added to TrainerState.log_history as a tensor
  2. TrainerState.save_to_json then attempts to serialize that tensor to JSON, which naturally errors out since tensors can't be serialized to JSON

I'm facing the same issue with DeepSpeed stage 1. Can you please fix this? I need to use v4.38.0 for a different fix.

@muellerzr (Contributor)

Can you all try installing with pip install git+https://github.com/huggingface/transformers@muellerzr-deepspeed-item?

This PR may have fixed this as well: #29444

@muellerzr muellerzr mentioned this pull request Mar 4, 2024
@shubhanjan99

Can you all try installing with pip install git+https://github.com/huggingface/transformers@muellerzr-deepspeed-item?

That fixed it for me! Thanks a lot

@lucasjinreal

Same error here:

11%|████████████████████████▏ | 800/7050 [4:07:59<32:10:41, 18.53s/it]Trainer is attempting to log a value of "2.204314947128296" of type <class 'torch.Tensor'> for key "train/grad_norm" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.

Please tell me how to fix it.
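Until a release with the fix is available, one possible user-side workaround is sketched below. It is not an official API, just an illustrative subclass (the name FloatLoggingTrainer is made up here) that assumes Trainer.log is the method through which metrics reach the TensorBoard callback, and casts any tensor-valued metric to a plain float first:

import torch
from transformers import Trainer


class FloatLoggingTrainer(Trainer):
    # Hypothetical subclass for illustration: convert scalar tensors (e.g. grad_norm)
    # into Python floats so that TensorBoard's writer.add_scalar() accepts them.
    def log(self, logs, *args, **kwargs):
        logs = {
            k: (v.item() if isinstance(v, torch.Tensor) else v)
            for k, v in logs.items()
        }
        super().log(logs, *args, **kwargs)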

itazap pushed a commit that referenced this pull request May 14, 2024
* report grad_norm during training

* support getting grad_norm from deepspeed