AttributeError: module 'megatron.core.parallel_state' has no attribute 'get_amax_reduction_group' #6625

Closed
yen-shi opened this issue May 10, 2023 · 4 comments · Fixed by #6791
Labels
bug Something isn't working

Comments

@yen-shi
Contributor

yen-shi commented May 10, 2023

Describe the bug

When running megatron_gpt_eval.py with an FP8 model, the FP8 code path is taken but fails because get_amax_reduction_group() cannot be found in parallel_state.

The error is triggered at this line:
https://github.com/NVIDIA/NeMo/blob/21048627b3923c9268842990aafdef141bd14bd1/nemo/collections/nlp/modules/common/megatron/transformer.py#L1432
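
For context, the FP8 path is roughly of the following shape (a minimal sketch, not the exact NeMo code; the recipe settings are placeholders, while fp8_autocast and DelayedScaling come from Transformer Engine's public API):

import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from megatron.core import parallel_state

# Placeholder recipe; the real values come from the model config.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# The FP8 path asks parallel_state for the amax reduction group. With
# Megatron-Core (unlike Apex) the attribute does not exist, so this call
# raises the AttributeError reported above.
fp8_group = parallel_state.get_amax_reduction_group()

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group):
    ...  # forward pass through the TE transformer layers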

Steps/Code to reproduce bug

Get a model trained with TE FP8

Run:
python megatron_gpt_eval.py gpt_model_file=models/5b_fp8_tp1.nemo

Expected behavior

The script is expected to finish and generate outputs.

Environment overview (please complete the following information)

  • Environment location:
    Container: nvcr.io/nvidia/pytorch:23.03-py3
  • Method of NeMo install:
    Call ./reinstall.sh on main branch commit c3deeac
  • If method of install is [Docker], provide docker pull & docker run commands used
    docker run --gpus '"device=0"' -it --ipc=host --ulimit memlock=-1 -v /home/scratch.yenshiw_sw/NeMo:/workspace/local-nemo --ulimit stack=67108864 nvcr.io/nvidia/pytorch:23.03-py3

Environment details

Additional context

I cannot find name get_amax_reduction_group in megatron source code (parallel_state):
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py

@yen-shi added the bug label on May 10, 2023
@timmoon10
Collaborator

timmoon10 commented May 10, 2023

Looks like this is a bug from switching from Apex to Megatron-core (#6393). There was some recent MLPerf-related development in Apex that optimized AMAX reductions in Transformer Engine (NVIDIA/apex#1585, NVIDIA/apex#1597).

Perhaps this belongs as its own issue, but I see more recent changes in Apex that haven't made their way to Megatron-core yet.

Pinging @erhoo82 @Aidyn-A @ksivaman

@aklife97
Collaborator

This doesn't look like it'll need any NeMo-side changes.
Megatron-Core does not have the _AMAX_REDUCTION_GROUP that Apex provides, which we need in order to make FP8 work. We'd need to add it to Core, which should directly enable it in NeMo.

That said, this is obviously a regression, since Apex supported it and we don't yet have it in Core, but it should be a fairly straightforward fix as soon as we have it there.
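
For reference, the kind of addition described above could look roughly like this in a parallel_state-style module (a minimal sketch; initialize_amax_reduction_group and its rank-list argument are hypothetical, and the actual Megatron-Core implementation may differ):

import torch.distributed as dist

# Module-level handle, mirroring how other groups are stored in parallel_state.
_AMAX_REDUCTION_GROUP = None

def initialize_amax_reduction_group(amax_rank_groups):
    # amax_rank_groups: hypothetical list of rank lists, each covering the
    # tensor-parallel and data-parallel ranks that share FP8 amax state.
    global _AMAX_REDUCTION_GROUP
    world_rank = dist.get_rank()
    for ranks in amax_rank_groups:
        # new_group must be called by all ranks; only members keep the handle.
        group = dist.new_group(ranks)
        if world_rank in ranks:
            _AMAX_REDUCTION_GROUP = group

def get_amax_reduction_group():
    # Return the group used for FP8 amax reductions.
    assert _AMAX_REDUCTION_GROUP is not None, "amax reduction group is not initialized"
    return _AMAX_REDUCTION_GROUP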

@erhoo82
Collaborator

erhoo82 commented May 25, 2023

_AMAX_REDUCTION_GROUP was added to merge the TP and DP reductions into a single communication call. It is only needed for FP8 training and should be added to Megatron-Core for compatibility with TE.
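
To illustrate the saving (a minimal sketch; tp_group, dp_group, and tp_dp_group are hypothetical group handles, not Megatron-Core names):

import torch
import torch.distributed as dist

def reduce_amax_separately(amax: torch.Tensor, tp_group, dp_group):
    # Two communication calls: one across tensor-parallel ranks,
    # then one across data-parallel ranks.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)

def reduce_amax_merged(amax: torch.Tensor, tp_dp_group):
    # One collective over a group spanning both TP and DP ranks;
    # the result is the same global max with a single communication call.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_dp_group)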

@aklife97
Collaborator

@erhoo82, thanks to @timmoon10, this is already in Core now.
I'm waiting for #6627 to merge, after which we should be able to close this issue.
