
[NaN] Fix nan print issue when running Megatron-Deepspeed with DeepSpeed #434

Merged

tjruwase merged 1 commit into deepspeedai:main from ys950902:nan_issue on Aug 24, 2024

Conversation

ys950902 commented Aug 5, 2024

When running Megatron-DeepSpeed with DeepSpeed and a NaN issue occurs, the only indication is that no lm loss is printed while the number of nan iterations stays at 0, which is incorrect, as in the log below:
iteration 9/ 10 | consumed samples: 108 | consumed tokens: 442368 | elapsed time per iteration (ms): 1979.2 | learning rate: 4.219E-07 | global batch size: 12 | loss scale: 1.0 | actual seqlen: 4096 | number of skipped iterations: 0 | number of nan iterations: 0 | samples per second: 6.063 | tokens per gpu per second (tgs): 2069.506 | TFLOPs: 127.00 |

This PR fixes the issue: the NaN check should be performed whether or not the iteration was skipped.
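For context, the nan-iteration accounting lives in Megatron's training_log(). The sketch below shows the idea of the change rather than the literal diff (the function and key names here are illustrative): the NaN/inf check runs on every iteration instead of only on iterations already marked as skipped.

```python
import math
import torch

def accumulate_losses(loss_dict, total_loss_dict, skipped_iter,
                      nan_iters_key='nan iterations'):
    # Illustrative sketch of Megatron-style loss accounting in training_log().
    # The point of the fix: check for NaN/inf on every iteration, not only
    # when skipped_iter is set, so the "number of nan iterations" counter is
    # correct even when DeepSpeed never flips skipped_iter.
    got_nan = False
    for key, value in loss_dict.items():
        v = value.float().sum().item()
        is_nan = math.isnan(v) or math.isinf(v)
        got_nan = got_nan or is_nan
        if not skipped_iter and not is_nan:
            total_loss_dict[key] = total_loss_dict.get(key, torch.tensor(0.0)) + value
    total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int(got_nan)
    return total_loss_dict
```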

abhilash1910 left a comment

LGTM!

tjruwase commented Aug 7, 2024

@ys950902, can you please share a bit more details about why skipped_iter is False in this case?

ys950902 (author) commented Aug 7, 2024

> @ys950902, can you please share a bit more details about why skipped_iter is False in this case?

Hi @tjruwase, thanks for your reply. When running Megatron-DeepSpeed with DeepSpeed for 3D parallelism:
https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/training.py#L674
or with ZeRO stage 2/3:
https://github.com/microsoft/Megatron-DeepSpeed/blob/main/megatron/training.py#L762
skipped_iter is set to 0 by default, and DeepSpeed never updates this flag, so it remains false here.
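In other words, on the DeepSpeed paths skipped_iter keeps its default value. A simplified sketch of the current behaviour (illustrative only, not the literal train_step() source; the helper name is made up):

```python
def train_substep(args, model, optimizer, increment):
    # Simplified sketch of how train_step() currently sets skipped_iter
    # (illustrative, not the literal Megatron-DeepSpeed code).
    skipped_iter = 0  # default; the DeepSpeed branch never updates it
    if args.deepspeed:
        # DeepSpeed performs the optimizer step internally and may silently
        # skip it (e.g. on gradient overflow / NaN), but that outcome is not
        # read back here, so skipped_iter stays 0.
        model[0].step(lr_kwargs={'increment': increment})
    else:
        # Megatron path: skipped_iter follows the optimizer's report.
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
        skipped_iter = 0 if update_successful else 1
    return skipped_iter
```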

tjruwase commented Aug 7, 2024

@ys950902, thanks for the explanation. I think the correct solution is to use DeepSpeed's was_step_applied() API. I noticed that for the non-3D-parallelism case it is already used to set update_successful:
https://github.com/microsoft/Megatron-DeepSpeed/blob/53b241f992f9b3dd7917bc36472f60cb118f8303/megatron/training.py#L746

The problem is that update_successful is then not used to set skipped_iter appropriately, unlike in the non-DeepSpeed code path:
https://github.com/microsoft/Megatron-DeepSpeed/blob/53b241f992f9b3dd7917bc36472f60cb118f8303/megatron/training.py#L773-L778

Can you try setting update_successful and skipped_iter for both DeepSpeed cases in a fashion consistent with the Megatron case? Thanks.
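In code, the suggestion amounts to roughly the following (a hedged sketch of the idea, not the merged diff; the helper name is made up):

```python
def train_substep_fixed(args, model, optimizer, increment):
    # Sketch of the suggested fix: derive update_successful from DeepSpeed's
    # was_step_applied() so skipped_iter is set consistently with the
    # non-DeepSpeed (Megatron) path. Illustrative only.
    if args.deepspeed:
        model[0].step(lr_kwargs={'increment': increment})
        # was_step_applied() reports whether the last engine.step() actually
        # applied the parameter update (False when the step was skipped).
        update_successful = model[0].was_step_applied()
    else:
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    skipped_iter = 0 if update_successful else 1
    return skipped_iter
```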

ys950902 (author) commented Aug 7, 2024

> @ys950902, thanks for the explanation. I think the correct solution is to use DeepSpeed's was_step_applied() API. I noticed that for the non-3D-parallelism case it is already used to set update_successful:
> https://github.com/microsoft/Megatron-DeepSpeed/blob/53b241f992f9b3dd7917bc36472f60cb118f8303/megatron/training.py#L746
>
> The problem is that update_successful is then not used to set skipped_iter appropriately, unlike in the non-DeepSpeed code path:
> https://github.com/microsoft/Megatron-DeepSpeed/blob/53b241f992f9b3dd7917bc36472f60cb118f8303/megatron/training.py#L773-L778
>
> Can you try setting update_successful and skipped_iter for both DeepSpeed cases in a fashion consistent with the Megatron case? Thanks.

Got it, I will fix it as you suggested!

ys950902 (author) commented Aug 8, 2024

Hi @tjruwase, could you please take a look at this PR, along with the corresponding DeepSpeed change to support bfloat16, deepspeedai/DeepSpeed#5879?

ys950902 (author) commented

Hi @tjruwase, will you merge this PR?

tjruwase merged commit 4f9f1f6 into deepspeedai:main on Aug 24, 2024
loadams pushed a commit that referenced this pull request on Feb 7, 2025

YJHMITWEB pushed a commit to YJHMITWEB/Megatron-DeepSpeed that referenced this pull request on Aug 9, 2025
tjruwase pushed a commit that referenced this pull request on Aug 14, 2025

…nabled (#479)

* pass batch_dim_idx to deepspeed sequence parallel distributed attention for supporting batch size larger than 1

* add fused_rms_norm support on XPU device (#431)

* [LLaMa] Adding support converting checkpoint from mds to hf (#432)

* add support converting checkpoint from hf to mds

* Fix PP issue

* update

* add device check when import ipex (#436)

* fix TFLOPs calculation (#371)

When GQA is used, we observe correct TFLOPs after this fix. When GQA is not used, the large difference in TFLOPs is resolved with selective recompute. Some other minor differences will also be observed because logits MACs are now included.

* add copyrights

* fix nan issue when running megatron-deepspeed (#434)

* enable empty cache on XPU device (#438)

* [wandb] disable wandb more gracefully (#422)

* [Bug] Fix crash when logging optimizer state to tb (#417)

* add FPDT support; add Ulysses rotary position embedding support

* remove unnecessary files

* set the warmup length to be FPDT chunk size if enabled

* Enable Sequence Parallelism (#429)

* grad_wei can't be NoneType when running with DeepSpeed, for zero3 will divided the gradient (#428)

* fix init issue for rms_norm in squence_parallel (#448)

* enable profiler for specific ranks (#451)

* fix init issue for silently ignoring the deepspeed config (#452)

* fix moe tflops (#445)

* [tool]GQA convert support (#454)

* fix readme

* Fix import error in `deepspeed_to_megatron.py` (#455)

Previously, `deepspeed_to_megatron.py` would raise an import error due to the relative import. This commit fixes the issue by changing the relative import to an absolute import, as in `deepspeed_to_transformers.py`.

* Update references to new GitHub org (deepspeedai) (#462)

* add sequence_parallel in layernorm init to enable 3D parallelism can run successfully with DeepSpeed (#468)

* fix bug when FPDT is disabled but with original Ulysses

---------

Signed-off-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu>
Signed-off-by: Logan Adams <loadams@microsoft.com>
Signed-off-by: yisheng <yi.sheng@intel.com>
Signed-off-by: jinghan yao <yjhmitweb@gmail.com>
Co-authored-by: Jinghan Yao <yjhmitweb@ascend-rw02.ten.osc.edu>
Co-authored-by: YiSheng5 <syhm@mail.ustc.edu.cn>
Co-authored-by: billishyahao <yahao.he@gmail.com>
Co-authored-by: Polisetty V R K Jyothendra Varma <jvarma@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Jinghan Yao <yjhmitweb@ascend-rw01.ten.osc.edu>
Co-authored-by: ranzhejiang <zhejiang.ran@intel.com>
Co-authored-by: Xinyu Lian <lian7@illinois.edu>
Co-authored-by: inkcherry <mingzhi.liu@intel.com>
Co-authored-by: hotsuyuki <hotsuyuki.kawanishi@gmail.com>
Co-authored-by: Jinghan Yao <yjhmitweb@cardinal-rw02.ten.osc.edu>