Commit 98d904f
[VLM] Add token-imbalance loss (#1803)
In VLM interleaved training with native resolution and aspect ratio,
the number of tokens participating in the loss computation differs per rank.
Naive FSDP gradient averaging across data ranks causes tokens on
ranks with fewer valid tokens to contribute more to the loss than tokens
on other ranks.
This PR addresses this via loss balancing, which incurs an additional
communication in the loss computation.
In practice, I haven't noticed any impact from this communication.
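To see the imbalance concretely, here is a toy illustration in plain NumPy (not the PR's code): two ranks with 1 and 3 valid tokens, every token carrying an identical per-token gradient. Under the naive scheme, each rank divides its sum loss by its own token count before FSDP averages across ranks, so the effective weight of one token on rank $i$ is $1/(N_i R)$, and the lone token on rank 0 counts three times as much as each token on rank 1.

```python
import numpy as np

# Two data ranks with unequal numbers of valid tokens.
N = [1, 3]
R = len(N)
# Give every token an identical per-token gradient of 1.0.
grads = [np.ones(n) for n in N]

# Naive cross-entropy: each rank divides its sum loss by its OWN
# token count; FSDP then averages the per-rank gradients.
g_naive = sum(g.sum() / len(g) for g in grads) / R

# Effective weight of a single token on rank i is 1 / (N_i * R).
weights = [1.0 / (n * R) for n in N]
# weights[0] is 3x weights[1]: rank-0's lone token dominates.
```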
#### Quick sanity check
Let $L_i = \sum_{j=1}^{N_i}\ell_{ij}$ be the sum of the loss over the
$N_i$ tokens on rank $i$, with gradient $g_i =
\sum_{j=1}^{N_i}\nabla\ell_{ij}$.
If we multiply the *loss* on each rank by a constant factor **c** (the
same for all ranks), then after `backward()`:
$$
\tilde g_i = c \cdot g_i .
$$
FSDP will *average* these gradients across ranks:
$$
g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i
=\frac{c}{R}\sum_{i=1}^{R} g_i .
$$
We want this to equal the **global‑sample average**:
$$
g_{\text{true}}
=\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla
\ell_{ij}
=\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i .
$$
Thus for FSDP gradient to be correct, we need
$$
\frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad
c=\frac{R}{N_{\text{total}}}.
$$
So the *right* scaling factor is $R/N_{\text{total}}$, which means dividing
the per-rank sum loss by $N_{\text{total}}/R$, the **average
number of tokens per rank**.
Intuitively, this is the same as the default cross-entropy loss, but instead
of dividing the sum loss on a rank by the number of tokens **on that rank**,
we divide it by the **average number of tokens across all ranks**.
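The derivation above can be checked numerically. The sketch below (plain NumPy, simulating the ranks in one process rather than using `torch.distributed`) scales each rank's sum gradient by $R/N_{\text{total}}$, averages across ranks as FSDP would, and confirms the result equals the global per-token average gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
R = 4
# Uneven token counts per rank, as with native-resolution VLM batches.
N = [2, 5, 3, 10]
N_total = sum(N)

# Per-token "gradients" (scalars, for simplicity).
grads = [rng.normal(size=n) for n in N]

# Ground truth: average gradient over all tokens globally.
g_true = np.concatenate(grads).sum() / N_total

# Balanced scaling: divide each rank's SUM loss by the average
# tokens per rank, N_total / R, then average across ranks (FSDP).
g_ranks = [g.sum() / (N_total / R) for g in grads]
g_fsdp = sum(g_ranks) / R

# g_fsdp matches the global-sample average gradient.
```

In a real distributed run, computing $N_{\text{total}}$ requires one `all_reduce` of the per-rank token counts, which is the extra communication mentioned above.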
P.S.: sorry, this PR is based on #1802 but I couldn't choose that as the
base branch. It may be easier to review once that PR is merged.
File tree (4 files changed, +118, -3 lines):
- torchtitan/components
- torchtitan/experiments/vlm
- torchtitan/infra