Bug description
Hi folks,
We found a hanging issue in the validation loop that might be related to the dataloader/sharding in FSDP.
Reproduction:
- Command:
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" NGPU=2 ./run_train.sh
- Log:
...
[rank1]:[titan] 2025-08-21 11:19:52,968 - root - INFO - [Validation Debug] Rank 1 batch 87 Loss 12.241142272949219
[rank1]:[titan] 2025-08-21 11:19:52,972 - root - INFO - [Validation Debug] Rank 1 batch 88 start
[rank1]:[titan] 2025-08-21 11:19:52,988 - root - INFO - [Validation Debug] Rank 1 batch 88 Loss 12.209203720092773
[rank1]:[titan] 2025-08-21 11:19:52,993 - root - INFO - [Validation Debug] Rank 1 batch 89 start
[rank0]:[titan] 2025-08-21 11:19:52,905 - root - INFO - [Validation Debug] Rank 0 batch 84 Loss 12.281610488891602
[rank0]:[titan] 2025-08-21 11:19:52,910 - root - INFO - [Validation Debug] Rank 0 batch 85 start
[rank0]:[titan] 2025-08-21 11:19:52,927 - root - INFO - [Validation Debug] Rank 0 batch 85 Loss 12.214127540588379
[rank0]:[titan] 2025-08-21 11:19:52,931 - root - INFO - [Validation Debug] Rank 0 batch 86 start
[rank0]:[titan] 2025-08-21 11:19:52,947 - root - INFO - [Validation Debug] Rank 0 batch 86 Loss 12.226905822753906
[rank0]:[titan] 2025-08-21 11:19:52,951 - root - INFO - [Validation Debug] Rank 0 batch 87 start
[rank0]:[titan] 2025-08-21 11:19:52,968 - root - INFO - [Validation Debug] Rank 0 batch 87 Loss 12.206489562988281
[rank0]:[titan] 2025-08-21 11:19:52,972 - root - INFO - [Validation Debug] Rank 0 batch 88 start
[rank0]:[titan] 2025-08-21 11:19:52,988 - root - INFO - [Validation Debug] Rank 0 batch 88 Loss 12.239264488220215
[rank0]:[titan] 2025-08-21 11:19:52,992 - root - INFO - [Validation Debug] Rank 0 batch 89 start
[rank1]:[titan] 2025-08-21 11:19:53,009 - root - INFO - [Validation Debug] Rank 1 batch 89 Loss 12.284561157226562
[rank1]:[titan] 2025-08-21 11:19:53,013 - root - INFO - [Validation Debug] Rank 1 batch 90 start
[rank1]:[titan] 2025-08-21 11:19:53,030 - root - INFO - [Validation Debug] Rank 1 batch 90 Loss 12.25745964050293
[rank1]:[titan] 2025-08-21 11:19:53,034 - root - INFO - [Validation Debug] Rank 1 batch 91 start
[rank1]:[titan] 2025-08-21 11:19:53,050 - root - INFO - [Validation Debug] Rank 1 batch 91 Loss 12.285140991210938
[rank1]:[titan] 2025-08-21 11:19:53,054 - root - WARNING - Dataset alpaca_validation has run out of data
[rank1]:[titan] 2025-08-21 11:19:53,054 - root - INFO - [Validation Debug] Rank 1 Validation done at batch 91.
[rank0]:[titan] 2025-08-21 11:19:53,009 - root - INFO - [Validation Debug] Rank 0 batch 89 Loss 12.28492546081543
[rank0]:[titan] 2025-08-21 11:19:53,014 - root - INFO - [Validation Debug] Rank 0 batch 90 start
[rank0]:[titan] 2025-08-21 11:19:53,030 - root - INFO - [Validation Debug] Rank 0 batch 90 Loss 12.212723731994629
[rank0]:[titan] 2025-08-21 11:19:53,034 - root - INFO - [Validation Debug] Rank 0 batch 91 start
[rank0]:[titan] 2025-08-21 11:19:53,050 - root - INFO - [Validation Debug] Rank 0 batch 91 Loss 12.228861808776855
[rank0]:[titan] 2025-08-21 11:19:53,055 - root - INFO - [Validation Debug] Rank 0 batch 92 start
Rank 1 finishes validation at batch 91. Rank 0 hangs at batch 92.
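For context on why this would deadlock (our reading, not confirmed): under FSDP every forward pass participates in collectives (parameter all-gathers), so all ranks must run the same number of forward passes. A minimal sketch of the failing pattern, using a hypothetical `validate` helper and placeholder batch field names rather than torchtitan's actual validator:

```python
import torch

def validate(model, val_dataloader, device):
    # Hypothetical per-rank validation loop; each rank iterates over its
    # own shard of the validation set.
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            inputs = batch["input"].to(device)   # placeholder field names
            labels = batch["label"].to(device)
            # Under FSDP this forward call issues parameter all-gathers.
            # If this rank's shard has one more batch than a peer's shard,
            # the peer has already left the loop and never joins the
            # collective, so this call blocks forever (the observed hang).
            loss = model(inputs, labels)
```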
If we instead set the validation steps to a fixed number smaller than the total number of available batches, like this:
[validation]
enabled = true
local_batch_size=1
dataset = "alpaca_validation"
freq = 2
steps = 5
it doesn't hang. So our guess is that the hang happens because the dataloader yields an uneven number of batches across the two ranks.
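One common way to make uneven shards safe (a sketch of the general pattern under that assumption, not a proposed patch for torchtitan) is to have all ranks agree on when to stop, e.g. by all-reducing a "still has data" flag each iteration so every rank exits the loop on the same step:

```python
import torch
import torch.distributed as dist

def everyone_has_data(has_local_batch: bool, device) -> bool:
    # All ranks stop as soon as any single rank runs out of data, so no
    # rank issues a collective (FSDP all-gather, loss all-reduce) that
    # its peers will never join.
    flag = torch.tensor([1 if has_local_batch else 0], device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())

def validation_loop(model, val_iter, device):
    # val_iter is an iterator over this rank's shard of the validation set.
    model.eval()
    with torch.no_grad():
        while True:
            batch = next(val_iter, None)
            if not everyone_has_data(batch is not None, device):
                break  # all ranks break together on the same step
            inputs, labels = batch  # placeholder batch structure
            loss = model(inputs.to(device), labels.to(device))
```

Capping `steps` as above sidesteps the issue for the same reason: every rank runs exactly the same, fixed number of iterations.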
cc @ebsmothers
Versions
Code pointer provided above
torch 2.9.0a0+gite7cc42d