Compute true loss Flax examples #18458
Unintentional change?
@sanchit-gandhi I noticed that with `drop_last=False` the eval loss becomes `nan` at the beginning of training, but eval accuracy is still on track. It appears to occur with both `run_bart_dlm_flax` and `run_summarization`, so I temporarily turned it off. It would be great if you could take a look and fix it.
Interesting! What immediately jumps out to me is that `num_labels` is 0, causing the 'true loss' to be `nan`. You didn't get this behaviour previously with the `pmap` operation? What eval batch size are you using?
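As an aside (not from the thread), here is a minimal, hypothetical sketch of the failure mode described above, assuming the 'true loss' is computed as a masked loss sum divided by the number of real labels; the cross-device `psum` used in the examples is omitted for brevity, and `masked_loss` is an illustrative helper, not the examples' actual function:

```python
import jax
import jax.numpy as jnp
import optax

def masked_loss(logits, labels, padding_mask):
    """Cross entropy averaged over real (unpadded) positions only."""
    per_token = optax.softmax_cross_entropy(
        logits, jax.nn.one_hot(labels, logits.shape[-1])
    )
    per_token = per_token * padding_mask          # zero out padded positions
    num_labels = padding_mask.sum()
    # If the batch is entirely padding (e.g. a padded-out last eval batch),
    # num_labels == 0 and an unguarded division gives 0 / 0 -> nan.
    return per_token.sum() / jnp.maximum(num_labels, 1)

# A fully padded toy batch: the guard returns 0.0 instead of nan.
logits = jnp.zeros((2, 4, 8))                # (batch, seq_len, vocab)
labels = jnp.zeros((2, 4), dtype=jnp.int32)
mask = jnp.zeros((2, 4))                     # no real labels at all
print(masked_loss(logits, labels, mask))     # 0.0
```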
This bug occurs without dividing `loss` by `num_labels`, actually. I took a quick look at the eval loss computed on every token at each step; the last losses were `nan`.

```bash
python run_summarization_flax.py \
    --output_dir ./bart-base-xsum \
    --model_name_or_path facebook/bart-base \
    --tokenizer_name facebook/bart-base \
    --dataset_name="xsum" \
    --do_train --do_eval --do_predict --predict_with_generate \
    --num_train_epochs 6 \
    --learning_rate 5e-5 --warmup_steps 0 \
    --per_device_train_batch_size 64 \
    --per_device_eval_batch_size 64 \
    --overwrite_output_dir \
    --max_source_length 512 --max_target_length 64 \
    --push_to_hub
```

The printed tensor has a shape of (8, 64, 64), given that I have 8 TPU cores, `per_device_eval_batch_size=64`, and `max_target_length=64`.
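For context (not from the thread), that shape corresponds to `(num_devices, per_device_eval_batch_size, max_target_length)`. A hypothetical way to narrow down where the `nan` values sit in such a per-token loss tensor, using toy data in place of the real eval output:

```python
import jax.numpy as jnp

# Toy per-token eval loss of shape (devices, per_device_batch, target_len),
# with nans planted in the last few examples of the last shard to mimic a
# padded-out final batch.
per_token_loss = jnp.zeros((8, 64, 64)).at[-1, -3:, :].set(jnp.nan)

# Which (device, example) pairs contain any nan?
nan_per_example = jnp.isnan(per_token_loss).any(axis=-1)   # (8, 64)
print(jnp.argwhere(nan_per_example))   # only the last examples of device 7
```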
Oh no! That's interesting to see. Is this an artefact of using the `psum`? As in, were the losses `nan` when we used a `pmap` previously? If so, we'll need to address this!
The losses were `nan` right after `loss = optax.softmax_cross_entropy(logits, soft_labels)`. Great if you could have a look at this issue!
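(Editorial note, not from the thread.) For well-scaled, finite logits and non-negative soft labels, `optax.softmax_cross_entropy` normally stays finite, so a `nan` at that exact line often points at non-finite inputs rather than the loss normalisation. A hypothetical debug helper that could be used inside the eval step to tell the two cases apart:

```python
import jax.numpy as jnp

def nonfinite_report(logits, per_token_loss):
    """Count examples whose logits are already non-finite vs. examples
    whose per-token loss is non-finite."""
    bad_logits = ~jnp.isfinite(logits).all(axis=-1)   # (batch, seq_len)
    bad_loss = ~jnp.isfinite(per_token_loss)          # (batch, seq_len)
    return bad_logits.any(axis=-1).sum(), bad_loss.any(axis=-1).sum()
```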
Anyway, this problem doesn't seem to be related to this PR. The Flax examples are young, so we'll improve them step by step :)
Hi @duongna21, just wondering what the status is on the `nan` issue being discussed here. I'm running into this issue while using the recently added `run_bart_dlm_flax.py` script, but I'm very new to Flax/Jax so haven't been able to really make sense of it yet. Is `run_bart_dlm_flax.py` usable for model pre-training in its current state? Just as a side note, I haven't seen the `nan` issue when running `run_t5_mlm_flax.py` on my training data. Thanks in advance for any clarification!
@tannonk I'm sorry for the late reply. I believe this error is related to the `drop_last=False` option. Training will be fine if you set `drop_last=True`, at the cost that a few examples in the last batch will be skipped. It would be nice if anyone is able to work on this weird bug.
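As a point of reference (not from the thread), a rough sketch of what `drop_last` means for the examples' eval batching, with a hypothetical `data_loader` standing in for the loaders in the scripts:

```python
import numpy as np

def data_loader(dataset_size, batch_size, drop_last=True):
    """Yield index batches. With drop_last=True the trailing partial batch
    is skipped, so up to batch_size - 1 examples are never evaluated; with
    drop_last=False the final, shorter batch is kept (the configuration in
    which the nan eval loss above was observed)."""
    indices = np.arange(dataset_size)
    full_steps = dataset_size // batch_size
    for i in range(full_steps):
        yield indices[i * batch_size : (i + 1) * batch_size]
    if not drop_last and dataset_size % batch_size:
        yield indices[full_steps * batch_size :]

# 1000 eval examples with batch size 64:
print(sum(1 for _ in data_loader(1000, 64, drop_last=True)))   # 15 batches, 40 examples skipped
print(sum(1 for _ in data_loader(1000, 64, drop_last=False)))  # 16 batches, last one holds 40 examples
```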
It's `drop_last=False` which is causing the eval loss to be `nan`?
@sanchit-gandhi Thanks for the review. Rebased.
Unintentional change?
Same here! Is `drop_last` causing issues with the eval loss?
@sanchit-gandhi Yes!