Compute true loss Flax examples #18458

Closed

Changes from 6 commits

Commits (541)
6a9726e
Fix `DocumentQuestionAnsweringPipelineTests` (#19023)
ydshieh Sep 14, 2022
f5f430e
Add support for Japanese GPT-NeoX-based model by ABEJA, Inc. (#18814)
SO0529 Sep 14, 2022
4eb36f2
Mark right save_load test as slow (#19031)
sgugger Sep 14, 2022
693ba2c
Fix GPT-NeoX doc examples (#19033)
lewtun Sep 14, 2022
31be02f
TF: tf.debugging assertions without tf.running_eagerly() protection (…
gante Sep 14, 2022
0e24548
Add safeguards for CUDA kernel load in Deformable DETR (#19037)
sgugger Sep 14, 2022
0efbb6e
fix GPT2 token's `special_tokens_mask` when used with `add_bos_token=…
SaulLu Sep 14, 2022
3774010
Automate check for new pipelines and metadata update (#19029)
sgugger Sep 14, 2022
16913b3
Dev version
LysandreJik Sep 14, 2022
abca174
Fix a broken link for deepspeed ZeRO inference in the docs (#19001)
nijkah Sep 14, 2022
8edf196
[doc] debug: fix import (#19042)
stas00 Sep 14, 2022
7743cac
[bnb] Small improvements on utils (#18646)
younesbelkada Sep 15, 2022
30a28f5
Update image segmentation pipeline test (#18731)
amyeroberts Sep 15, 2022
0a42b61
Fix `test_save_load` for `TFViTMAEModelTest` (#19040)
ydshieh Sep 15, 2022
9b80a0b
Pin minimum PyTorch version for BLOOM ONNX export (#19046)
lewtun Sep 15, 2022
2322eb8
Update serving signatures and make sure we actually use them (#19034)
Rocketknight1 Sep 15, 2022
2700ba6
Move cache: expand error message (#19051)
sgugger Sep 15, 2022
578e18e
🚨🚨🚨 Optimize Top P Sampler and fix edge case (#18984)
ekagra-ranjan Sep 15, 2022
68bb33d
Fixing OPT fast tokenizer option. (#18753)
Narsil Sep 15, 2022
f7ce4f1
Fix custom tokenizers test (#19052)
sgugger Sep 15, 2022
16242e1
Run `torchdynamo` tests (#19056)
ydshieh Sep 15, 2022
f3d3863
fix arg name in BLOOM testing and remove unused arg document (#18843)
shijie-wu Sep 15, 2022
0b5c7e4
Adds package and requirement spec output to version check exception (…
colindean Sep 15, 2022
c8e40d6
fix `use_cache` (#19060)
younesbelkada Sep 16, 2022
c603c80
FX support for ConvNext, Wav2Vec2 and ResNet (#19053)
michaelbenayoun Sep 16, 2022
532ca05
[doc] Fix link in PreTrainedModel documentation (#19065)
tomaarsen Sep 16, 2022
d63bdf7
Add FP32 cast in ConvNext LayerNorm to prevent rounding errors with F…
jimypbr Sep 16, 2022
820cb97
Organize test jobs (#19058)
sgugger Sep 16, 2022
bc5d0b1
Automatically tag CLIP repos as zero-shot-image-classification (#19064)
osanseviero Sep 16, 2022
70ba10e
Fix `LeViT` checkpoint (#19069)
ydshieh Sep 16, 2022
658010c
TF: tests for (de)serializable models with resized tokens (#19013)
gante Sep 16, 2022
5e636ee
Add type hints for PyTorch UniSpeech, MPNet and Nystromformer (#19039)
daspartho Sep 16, 2022
773314a
replace logger.warn by logger.warning (#19068)
fxmarty Sep 16, 2022
9017ba4
Fix tokenizer load from one file (#19073)
sgugger Sep 16, 2022
56c548f
Note about developer mode (#19075)
LysandreJik Sep 16, 2022
7d0486c
Bump mako in /examples/research_projects/decision_transformer (#19077)
dependabot[bot] Sep 16, 2022
ae21953
german autoclass (#19049)
flozi00 Sep 16, 2022
ca485e5
Add tests for legacy load by url and fix bugs (#19078)
sgugger Sep 16, 2022
ba7f217
Add runner availability check (#19054)
ydshieh Sep 19, 2022
22264f9
fix working dir (#19101)
ydshieh Sep 19, 2022
fbe8464
Added type hints for TFConvBertModel (#19088)
kishore-s-15 Sep 19, 2022
1bbad7a
Added Type hints for VIT MAE (#19085)
kishore-s-15 Sep 19, 2022
fe5e7ce
Add type hints for TF MPNet models (#19089)
kishore-s-15 Sep 19, 2022
6f25d10
Added type hints to ResNetForImageClassification (#19084)
kishore-s-15 Sep 19, 2022
0d1ba2d
added type hints (#19076)
daspartho Sep 19, 2022
e7206ce
Improve vision models docs (#19103)
NielsRogge Sep 19, 2022
6be338f
correct spelling in README (#19092)
flozi00 Sep 19, 2022
3b0cecb
Don't warn of move if cache is empty (#19109)
sgugger Sep 19, 2022
6227078
HPO: keep the original logic if there's only one process, pass the tr…
sywangyi Sep 19, 2022
801ebd0
Add documentation of Trainer.create_model_card (#19110)
sgugger Sep 19, 2022
261301d
Added type hints for YolosForObjectDetection (#19086)
kishore-s-15 Sep 19, 2022
c81ebd1
Beit postprocessing (#19099)
alaradirik Sep 20, 2022
cc567e0
Fix the wrong schedule (#19117)
ydshieh Sep 20, 2022
6740341
Change document question answering pipeline to always return an array…
ankrgyl Sep 20, 2022
de26241
german processing (#19121)
flozi00 Sep 20, 2022
36e356c
Fix: update ltp word segmentation call in mlm_wwm (#19047)
xyh1756 Sep 20, 2022
36b9a99
Fix BeitFeatureExtractor postprocessing (#19119)
alaradirik Sep 20, 2022
06f341d
Add a missing space in a script arg documentation (#19113)
bryant1410 Sep 20, 2022
18643ff
Skip `test_export_to_onnx` for `LongT5` if `torch` < 1.11 (#19122)
ydshieh Sep 20, 2022
ef6741f
Fix GLUE MNLI when using `max_eval_samples` (#18722)
lvwerra Sep 21, 2022
9e95706
Add post_process_semantic_segmentation method to SegFormer (#19072)
alaradirik Sep 21, 2022
da6a1b6
[BugFix] Fix fsdp option on shard_grad_op. (#19131)
ZHUI Sep 21, 2022
e7fdfc7
Add post_process_semantic_segmentation method to DPTFeatureExtractor …
alaradirik Sep 21, 2022
486134e
Fix FlaxPretTrainedModel pt weights check (#19133)
mishig25 Sep 21, 2022
114295c
Refuse Datasets 2.5.0 while waiting for a patch
sgugger Sep 21, 2022
66154a6
suppoer deps from github (#19141)
lhoestq Sep 21, 2022
451df72
Fix dummy creation for multi-frameworks objects (#19144)
sgugger Sep 21, 2022
d5848a5
Allowing users to use the latest `tokenizers` release ! (#19139)
Narsil Sep 21, 2022
3c7b965
Add some tests for check_dummies (#19146)
sgugger Sep 21, 2022
c7fd289
Fixed typo in generation_utils.py (#19145)
nbalepur Sep 21, 2022
126a739
Add support for conditional detr (#18948)
DeppMeng Sep 22, 2022
9393f96
[fix] Add DeformableDetrFeatureExtractor (#19140)
NielsRogge Sep 22, 2022
4d0f8c0
Add `accelerate` support for ViLT (#18683)
younesbelkada Sep 22, 2022
2d9853b
MSN (Masked Siamese Networks) for ViT (#18815)
sayakpaul Sep 22, 2022
cf6308e
Improve conditional detr docs (#19154)
NielsRogge Sep 22, 2022
1b5ab39
TF: check embeddings range (#19102)
gante Sep 22, 2022
83dc637
Reduce LR for TF MLM example test (#19156)
Rocketknight1 Sep 22, 2022
e5b7cff
update perf_train_cpu_many doc (#19151)
sywangyi Sep 22, 2022
74a3ea4
Bump oauthlib in /examples/research_projects/decision_transformer (#1…
dependabot[bot] Sep 22, 2022
3a396c5
fix: ckpt paths. (#19159)
sayakpaul Sep 22, 2022
8d59385
Fix TrainingArguments documentation (#19162)
sgugger Sep 22, 2022
49629e7
fix HPO DDP GPU problem (#19168)
sywangyi Sep 23, 2022
905635f
[WIP] Trainer supporting evaluation on multiple datasets (#19158)
timbmg Sep 23, 2022
7e84723
Add semantic segmentation post-processing method to MobileViT (#19105)
alaradirik Sep 23, 2022
fe01ec3
Detr preprocessor fix (#19007)
alaradirik Sep 23, 2022
49bf569
Add doctests to Perceiver examples (#19129)
stevenmanton Sep 23, 2022
0cea8d5
Add offline runners info in the Slack report (#19169)
ydshieh Sep 23, 2022
ece7624
Fix incorrect comments about atten mask for pytorch backend (#18728)
lygztq Sep 23, 2022
6395d12
Fixed type hint for pipelines/check_task (#19150)
Fei-Wang Sep 23, 2022
5da6afd
Update run_clip.py (#19130)
enze5088 Sep 23, 2022
fa4eeb4
german training, accelerate and model sharing (#19171)
flozi00 Sep 23, 2022
71fc331
Separate Push CI images from Scheduled CI (#19170)
ydshieh Sep 26, 2022
408b5e3
Remove pos arg from Perceiver's Pre/Postprocessors (#18602)
aielawady Sep 26, 2022
98af4f9
Bump protobuf in /examples/research_projects/decision_transformer (#1…
dependabot[bot] Sep 26, 2022
ea75e9f
Use `assertAlmostEqual` in `BloomEmbeddingTest.test_logits` (#19200)
ydshieh Sep 26, 2022
216b2f9
Move the model type check (#19027)
ankrgyl Sep 26, 2022
c20b2c7
Use repo_type instead of deprecated datasets repo IDs (#19202)
sgugger Sep 26, 2022
be4f269
Updated hf_argparser.py (#19188)
IMvision12 Sep 26, 2022
ca08863
Add warning for torchaudio <= 0.10 in MCTCTFeatureExtractor (#19203)
ydshieh Sep 26, 2022
a32f97c
Fix cached_file in offline mode for cached non-existing files (#19206)
sgugger Sep 26, 2022
7132d55
Remove unused `cur_len` in generation_utils.py (#18874)
ekagra-ranjan Sep 27, 2022
ea540a5
add wav2vec2_alignment (#16782)
arijitx Sep 27, 2022
88f597b
add doc for hyperparameter search (#19192)
sywangyi Sep 27, 2022
226b0e4
Add a use_parallel_residual argument to control the residual computin…
NinedayWang Sep 27, 2022
e3a30e2
translated add_new_pipeline (#19215)
nickprock Sep 27, 2022
34be08e
More tests for regression in cached non existence (#19216)
sgugger Sep 27, 2022
2d95695
Use `math.pi` instead of `torch.pi` in `MaskFormer` (#19201)
ydshieh Sep 27, 2022
2df6028
Added tests for yaml and json parser (#19219)
IMvision12 Sep 27, 2022
942fa8c
Fix small use_cache typo in the docs (#19191)
ankrgyl Sep 28, 2022
a357ed5
Generate: add warning when left padding should be used (#19067)
gante Sep 28, 2022
22d37a9
Fix deprecation warning for return_all_scores (#19217)
ogabrielluiz Sep 28, 2022
de359c4
Fix doctest for `TFDeiTForImageClassification` (#19173)
ydshieh Sep 28, 2022
9c6aeba
Document and validate typical_p in generation (#19128)
mapmeld Sep 28, 2022
4a0b958
Fix trainer seq2seq qa.py evaluate log and ft script (#19208)
iamtatsuki05 Sep 28, 2022
64998a5
Fix cache names in CircleCI jobs (#19223)
ydshieh Sep 28, 2022
0fc68a7
Fix seq2seq QA example
sgugger Sep 28, 2022
990936a
Move AutoClasses under Main Classes (#19163)
stevhliu Sep 29, 2022
6957350
Focus doc around preprocessing classes (#18768)
stevhliu Sep 29, 2022
99c3249
Fix confusing working directory in Push CI (#19234)
ydshieh Sep 29, 2022
9d732fd
XGLM - Fix Softmax NaNs when using FP16 (#18057)
gsarti Sep 29, 2022
bb6fa06
Add a getattr method, which replaces _module_getattr in torch.fx.Trac…
michaelbenayoun Sep 29, 2022
0dc7b3a
[TensorFlow] Adding GroupViT (#18020)
ariG23498 Sep 29, 2022
ba9e336
Fix `m2m_100.mdx` doc example missing `labels` (#19149)
Mustapha-AJEGHRIR Sep 29, 2022
3a27ba3
Fix opt softmax small nit (#19243)
younesbelkada Sep 29, 2022
902d30b
Use `hf_raise_for_status` instead of deprecated `_raise_for_status` (…
Wauplin Sep 29, 2022
b79028f
Fix TrainingArgs argument serialization (#19239)
atturaioe Sep 29, 2022
655f72a
Fix test fetching for examples (#19237)
sgugger Sep 29, 2022
01eb34a
Improve DETR post-processing methods (#19205)
alaradirik Sep 29, 2022
cca6e6f
Cast TF generate() inputs (#19232)
Rocketknight1 Sep 29, 2022
f16bbf1
Skip pipeline tests (#19248)
sgugger Sep 29, 2022
163cd15
Add job names in Past CI artifacts (#19235)
ydshieh Sep 29, 2022
1a1893e
Update Past CI report script (#19228)
ydshieh Sep 29, 2022
49d62b0
[Wav2Vec2] Fix None loss in doc examples (#19218)
rbsteinm Sep 29, 2022
f3d2f7a
Add MarkupLM (#19198)
NielsRogge Sep 30, 2022
4fd32a1
Catch `HFValidationError` in `TrainingSummary` (#19252)
ydshieh Sep 30, 2022
368b649
Rebase ESM PR and update all file formats (#19055)
Rocketknight1 Sep 30, 2022
582d085
Add expected output to the sample code for `ViTMSNForImageClassificat…
sayakpaul Sep 30, 2022
e396358
Add stop sequence to text generation pipeline (#18444)
KMFODA Sep 30, 2022
dad578e
Add notebooks (#19259)
JingyaHuang Sep 30, 2022
3e2dd7f
Poc to use safetensors (#19175)
sgugger Sep 30, 2022
2fba98e
Add `beautifulsoup4` to the dependency list (#19253)
ydshieh Sep 30, 2022
f33858d
Fix Encoder-Decoder testing issue about repo. names (#19250)
ydshieh Sep 30, 2022
6a08162
Fix cached lookup filepath on windows for hub (#19178)
kjerk Sep 30, 2022
cfb777f
Docs - Guide to add a new TensorFlow model (#19256)
gante Sep 30, 2022
5cd16f0
time series forecasting model (#17965)
kashif Sep 30, 2022
36f52e9
Restructure DETR post-processing, return prediction scores (#19262)
alaradirik Oct 3, 2022
c28d04e
Update no_trainer script for summarization (#19277)
divyanshugit Oct 3, 2022
18c0620
Don't automatically add bug label (#19302)
sgugger Oct 3, 2022
68f50f3
Breakup export guide (#19271)
stevhliu Oct 3, 2022
008531c
Update Protobuf dependency version to fix known vulnerability (#19247)
qthequartermasterman Oct 3, 2022
ca26277
Bump joblib from 0.16.0 to 1.2.0 in /examples/research_projects/lxmer…
dependabot[bot] Oct 3, 2022
c7ec0af
Bump joblib in /examples/research_projects/decision_transformer (#19270)
dependabot[bot] Oct 3, 2022
4c962d5
Bump joblib in /examples/research_projects/visual_bert (#19269)
dependabot[bot] Oct 3, 2022
534cd8f
Update README.md (#19309)
ShubhamJagtap2000 Oct 4, 2022
fe10796
[Docs] Fix link (#19313)
patrickvonplaten Oct 4, 2022
3a1a56a
Fix for sequence regression fit() in TF (#19316)
Rocketknight1 Oct 4, 2022
ac5ea74
Added Type hints for LED TF (#19315)
IMvision12 Oct 4, 2022
9b63016
Added type hints for TF: rag model (#19284)
debjit-bw Oct 4, 2022
cc263e9
alter retrived to retrieved (#18863)
gouqi666 Oct 4, 2022
ca3ebc4
ci(stale.yml): upgrade actions/setup-python to v4 (#19281)
oscard0m Oct 4, 2022
cd024da
ci(workflows): update actions/checkout to v3 (#19280)
oscard0m Oct 4, 2022
f134d38
wrap forward passes with torch.no_grad() (#19279)
daspartho Oct 4, 2022
2403dbd
wrap forward passes with torch.no_grad() (#19278)
daspartho Oct 4, 2022
d6e9204
wrap forward passes with torch.no_grad() (#19274)
daspartho Oct 4, 2022
a978288
wrap forward passes with torch.no_grad() (#19273)
daspartho Oct 4, 2022
6fd254a
Removing BertConfig inheritance from LayoutLMConfig (#19307)
arnaudstiegler Oct 4, 2022
6dce9e0
docker-build: Update actions/checkout to v3 (#19288)
Sushrut1101 Oct 4, 2022
587d84b
Add `BloomForQuestionAnswering` (#19310)
younesbelkada Oct 4, 2022
971da2e
Clamping hidden state values to allow FP16 (#19229)
SSamDav Oct 4, 2022
bf7eb0c
Remove interdependency from OpenAI tokenizer (#19327)
E-Aho Oct 4, 2022
6268694
removing XLMConfig inheritance from FlaubertConfig (#19326)
D3xter1922 Oct 4, 2022
07e94bf
Maskformer post-processing fixes and improvements (#19172)
alaradirik Oct 5, 2022
512fa41
Removed interdependency of BERT's Tokenizer in tokenization of prophe…
divyanshugit Oct 5, 2022
e12bbe3
Remove bert interdependency from clip tokenizer (#19332)
shyamsn97 Oct 5, 2022
c54bb1a
[WIP]remove XLMTokenizer inheritance from FlaubertTokenizer (#19330)
D3xter1922 Oct 5, 2022
60db81f
Making camembert independent from roberta, clean (#19337)
Mustapha-AJEGHRIR Oct 5, 2022
2f53ab5
Add sudachi and jumanpp tokenizers for bert_japanese (#19043)
r-terada Oct 5, 2022
e794ca5
Frees LongformerTokenizer of the Roberta dependency (#19346)
srhrshr Oct 5, 2022
4cbc797
Change `BloomConfig` docstring (#19336)
younesbelkada Oct 5, 2022
c875a96
Test failing test while we resolve the issue. (#19355)
sgugger Oct 5, 2022
071df6e
Call _set_save_spec() when creating TF models (#19321)
Rocketknight1 Oct 5, 2022
226b8ef
correct typos in README (#19304)
paulaxisabel Oct 5, 2022
d9101b7
Removes Roberta and Bert config dependencies from Longformer (#19343)
srhrshr Oct 5, 2022
ad98642
Fix gather for metrics (#19360)
muellerzr Oct 5, 2022
7598791
Fix MaskFormer failing postprocess tests (#19354)
alaradirik Oct 5, 2022
45e1403
Add WhisperModel to transformers (#19166)
ArthurZucker Oct 5, 2022
bad353c
Fix DETR segmentation postprocessing output (#19363)
alaradirik Oct 5, 2022
7e7f62b
Fix pipeline tests for Roberta-like tokenizers (#19365)
sgugger Oct 5, 2022
f0b4901
🚨 🚨 🚨 Fix ViT parameter initialization (#19341)
alaradirik Oct 6, 2022
ce26201
Change link of repojacking vulnerable link (#19393)
Ilaygoldman Oct 6, 2022
ae3e3bc
fix docs example, add object_detection to DETR docs (#19377)
alaradirik Oct 6, 2022
7e348aa
Making `ConvBert Tokenizer` independent from `bert Tokenizer` (#19347)
IMvision12 Oct 7, 2022
46fd04b
Fix gather for metrics (#19389)
muellerzr Oct 7, 2022
969534a
Added Type hints for XLM TF (#19333)
IMvision12 Oct 7, 2022
e162ceb
add ONNX support for swin transformer (#19390)
bibhabasumohapatra Oct 7, 2022
b29ebdf
removes prophet config dependencies from xlm-prophet (#19400)
srhrshr Oct 7, 2022
41ec5d0
Added type hints for TF: TransfoXL (#19380)
thliang01 Oct 7, 2022
56af8df
HF <-> megatron checkpoint reshaping and conversion for GPT (#19317)
pacman100 Oct 7, 2022
331ea01
Remove unneded words from audio-related feature extractors (#19405)
osanseviero Oct 7, 2022
e9a49ba
[WIP] Add ZeroShotObjectDetectionPipeline (#18445) (#18930)
sahamrit Oct 7, 2022
fa4bcd5
edit: cast attention_mask to long in DataCollatorCTCWithPadding (#19369)
ddobokki Oct 7, 2022
5fef17f
Copy BertTokenizer dependency into retribert tokenizer (#19371)
Davidy22 Oct 7, 2022
a26d71d
Export TensorFlow models to ONNX with dynamic input shapes (#19255)
dwyatte Oct 7, 2022
994b7a4
update attention mask handling (#19385)
ArthurZucker Oct 7, 2022
e6fc201
Remove dependency of Bert from Squeezebert tokenizer (#19403)
rchan26 Oct 7, 2022
c2b83d5
Removed Bert and XML Dependency from Herbert (#19410)
harry7337 Oct 7, 2022
06514b3
Clip device map (#19409)
patrickvonplaten Oct 7, 2022
6ef16f2
Remove Dependency between Bart and LED (slow/fast) (#19408)
Infrared1029 Oct 7, 2022
7418a48
Removed `Bert` interdependency in `tokenization_electra.py` (#19356)
OtherHorizon Oct 7, 2022
34e0cc6
Make `Camembert` TF version independent from `Roberta` (#19364)
Mustapha-AJEGHRIR Oct 7, 2022
de4d71e
Removed Bert dependency from BertGeneration code base. (#19370)
Threepointone4 Oct 7, 2022
983451a
Improve and fix ImageSegmentationPipeline (#19367)
alaradirik Oct 7, 2022
9ac586b
Rework pipeline tests (#19366)
sgugger Oct 7, 2022
d92e22d
Remove ref to is_pipeline_test
sgugger Oct 8, 2022
8b6bba5
Fix `ViTMSNForImageClassification` doctest (#19275)
ydshieh Oct 10, 2022
cbb8a37
Skip `BloomEmbeddingTest.test_embeddings` for PyTorch < 1.10 (#19261)
ydshieh Oct 10, 2022
4107445
Fix repo names for ESM tests (#19451)
Rocketknight1 Oct 10, 2022
1241a49
remove RobertaConfig inheritance from MarkupLMConfig (#19404)
D3xter1922 Oct 10, 2022
83dc49b
Backtick fixed (paragraph 68) (#19440)
kant Oct 10, 2022
3410705
Fixed duplicated line (paragraph #83) Documentation: @sgugger (#19436)
kant Oct 10, 2022
c523a86
fix marianMT convertion to onnx (#19287)
kventinel Oct 10, 2022
7d5ce68
Fix typo in image-classification/README.md (#19424)
zhawe01 Oct 10, 2022
298f6a9
Stop relying on huggingface_hub's private methods (#19392)
LysandreJik Oct 10, 2022
3080bb4
Add onnx support for VisionEncoderDecoder (#19254)
mht-sharma Oct 10, 2022
4824741
Remove dependency of Roberta in Blenderbot (#19411)
rchan26 Oct 10, 2022
ba71bf4
fix: renamed variable name (#18850)
ariG23498 Oct 10, 2022
af69360
Add `OPTForQuestionAnswering` (#19402)
clementapa Oct 10, 2022
e3f028f
Add TF whisper (#19378)
amyeroberts Oct 10, 2022
e150c4e
Fix the error message in run_t5_mlm_flax.py (#19282)
Oct 10, 2022
b0b962c
Add Italian translation for `add_new_model.mdx` (#18713)
Steboss89 Oct 10, 2022
4dd784c
Fix momentum and epsilon values (#19454)
amyeroberts Oct 10, 2022
d866b48
Generate: corrected exponential_decay_length_penalty type hint (#19376)
ShivangMishra Oct 10, 2022
9df953a
Fix misspelled word in docstring (#19415)
Bearnardd Oct 10, 2022
25cfd91
Fixed a non-working hyperlink in the README.md file (#19434)
MikailINTech Oct 10, 2022
a7bc422
fix (#19469)
ydshieh Oct 10, 2022
692c5be
wrap forward passes with torch.no_grad() (#19439)
daspartho Oct 10, 2022
870a954
wrap forward passes with torch.no_grad() (#19438)
daspartho Oct 10, 2022
d739a70
wrap forward passes with torch.no_grad() (#19416)
daspartho Oct 10, 2022
c6a928c
wrap forward passes with torch.no_grad() (#19414)
daspartho Oct 10, 2022
5f5e264
wrap forward passes with torch.no_grad() (#19413)
daspartho Oct 10, 2022
df2f281
wrap forward passes with torch.no_grad() (#19412)
daspartho Oct 10, 2022
1010097
Dev version
LysandreJik Oct 10, 2022
d7d71c8
Compute true loss
duongna21 Aug 3, 2022
b5ccda0
final
duongna21 Aug 3, 2022
135cb98
fixup
duongna21 Aug 3, 2022
d94d04f
final
duongna21 Aug 3, 2022
d0ccf00
final
duongna21 Aug 3, 2022
418f6c4
Update examples/flax/language-modeling/run_bart_dlm_flax.py
duongna21 Aug 9, 2022
b90b5ae
jax.tree_map => jax.tree_util.tree_map
duongna21 Aug 9, 2022
6c0ae1c
final
duongna21 Oct 11, 2022
45 changes: 22 additions & 23 deletions examples/flax/image-captioning/run_image_captioning_flax.py
@@ -335,7 +335,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
batch_idx = np.arange(len(dataset))

for idx in range(steps):

start_idx = batch_size * idx
end_idx = batch_size * (idx + 1)

@@ -347,7 +346,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf


def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="train"):

if train_time:
summary_writer.scalar("train_time", train_time, step)

@@ -782,11 +780,9 @@ def blockwise_data_loader(
num_splits = steps // steps_per_block + int(steps % steps_per_block > 0)

for idx in range(num_splits):

if not block_size:
_ds = ds
else:

start_idx = block_size * idx
end_idx = block_size * (idx + 1)

@@ -926,8 +922,9 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):

# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum() / padding_mask.sum()
return loss
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels

# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
@@ -936,29 +933,38 @@ def train_step(state, batch, label_smoothing_factor=0.0):
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels

grad_fn = jax.value_and_grad(compute_loss)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")

# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_map(lambda x: x / num_labels, loss)

# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
metrics = jax.lax.pmean(metrics, axis_name="batch")

return new_state, metrics

# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

# summarize metrics
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
num_labels = jax.lax.psum(num_labels, "batch")

# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_map(lambda x: x / num_labels, loss)

metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics

# Define generation function
@@ -1024,7 +1030,6 @@ def evaluation_loop(
ckpt_dir: str = "",
is_prediction=False,
):

logger.info(f"*** {'Predict' if is_prediction else 'Evaluate'} ***")

metrics = []
@@ -1103,12 +1108,10 @@ def evaluation_loop(
logger.info(desc)

if jax.process_index() == 0:

if not os.path.isdir(os.path.join(training_args.output_dir, ckpt_dir)):
os.makedirs(os.path.join(training_args.output_dir, ckpt_dir), exist_ok=True)

if metrics:

# Save metrics (only for the evaluation/prediction being done along with training)
if has_tensorboard and training_args.do_train:
write_metric(
@@ -1143,7 +1146,6 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):
input_rng = None

if training_args.do_train:

cur_step = 0
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
@@ -1166,7 +1168,6 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):

# train
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):

cur_step += 1
batch = next(train_batches)
batch_start = time.time()
@@ -1177,7 +1178,6 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):

# log and save info
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:

_train_metric = unreplicate(train_metric)
desc = (
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
@@ -1217,7 +1217,6 @@ def predict(rng: jax.random.PRNGKey, dataset: Dataset):

# log and save info
if training_args.logging_steps <= 0:

logger.info(desc)

with open(os.path.join(training_args.output_dir, "log"), "a", encoding="UTF-8") as fp:
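For readers skimming the diff above: the train and eval steps no longer average the loss per device and then `lax.pmean` the result; they return the summed loss together with the number of real labels, `psum` both across devices, and divide once. The following is a minimal, self-contained sketch of that aggregation only; the `true_loss` helper and the toy arrays are invented for illustration and are not taken from the PR.

```python
import jax
import jax.numpy as jnp


def true_loss(per_token_loss, padding_mask):
    # Per device: sum the masked loss and count the real (non-padding) tokens.
    loss = (per_token_loss * padding_mask).sum()
    num_labels = padding_mask.sum()
    # Across devices: total both, then divide once, so every real token
    # contributes equally regardless of how padding is spread over devices.
    loss = jax.lax.psum(loss, "batch")
    num_labels = jax.lax.psum(num_labels, "batch")
    return loss / num_labels


n = jax.local_device_count()
per_token_loss = jnp.ones((n, 4))                  # toy per-token losses
padding_mask = jnp.ones((n, 4)).at[:, 2:].set(0.0)  # pretend half the tokens are padding
print(jax.pmap(true_loss, axis_name="batch")(per_token_loss, padding_mask))
```

Under this scheme a device whose batch is mostly padding no longer pulls the average with the same weight as a fully packed one.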
2 changes: 1 addition & 1 deletion examples/flax/language-modeling/README.md
@@ -351,7 +351,7 @@ The example script uses the 🤗 Datasets library. You can easily customize them
To setup all relevant files for training, let's create a directory.

```bash
mkdir ./norwegian-roberta-base
mkdir ./norwegian-bart-base
```

### Train tokenizer
28 changes: 17 additions & 11 deletions examples/flax/language-modeling/run_bart_dlm_flax.py
@@ -799,19 +799,25 @@ def loss_fn(params):
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

# take average
loss = loss.sum() / label_mask.sum()
loss = loss.sum()
num_labels = label_mask.sum()

return loss
return loss, num_labels

grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")

metrics = jax.lax.pmean(
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
)
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_map(lambda x: x / num_labels, loss)

# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad)

metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics, new_dropout_rng

# Create parallel version of the train step
@@ -888,7 +894,7 @@ def eval_step(params, batch):
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
Comment on lines -891 to +897

Contributor: Unintentional change?

duongna21 (author): @sanchit-gandhi I noticed that with drop_last=False the eval loss will become nan at the beginning of training, but eval accuracy is still on track. It appears to occur with both run_bart_dlm_flax and run_summarization, so I temporarily turned it off. It would be great if you could take a look and fix it.

Contributor: Interesting! What immediately jumps out to me is that num_labels is 0, causing the 'true loss' to be nan. You didn't get this behaviour previously with the pmap operation? What eval batch size are you using?

duongna21 (author): This bug occurs without dividing loss by num_labels, actually. I took a quick look at the eval loss computed on every token at each step; the last losses were nan.

python run_summarization_flax.py \
	--output_dir ./bart-base-xsum \
	--model_name_or_path facebook/bart-base \
	--tokenizer_name facebook/bart-base \
	--dataset_name="xsum" \
	--do_train --do_eval --do_predict --predict_with_generate \
	--num_train_epochs 6 \
	--learning_rate 5e-5 --warmup_steps 0 \
	--per_device_train_batch_size 64 \
	--per_device_eval_batch_size 64 \
	--overwrite_output_dir \
	--max_source_length 512 --max_target_length 64 \
	--push_to_hub

The printed tensor has a shape of (8, 64, 64), given that I have 8 TPU cores, per_device_eval_batch_size=64 and max_target_length=64.
[screenshot of the printed per-token eval losses]

Contributor: Oh no! That's interesting to see. Is this an artefact of using the psum? As in, were the losses nan when we used a pmap previously? If so, we'll need to address this!

duongna21 (author): The losses were nan right after loss = optax.softmax_cross_entropy(logits, soft_labels). It would be great if you could have a look at this issue!

duongna21 (author): Anyway, this problem doesn't seem to be related to this PR. The Flax examples are young, so we'll improve them step by step :)

tannonk: Hi @duongna21, just wondering what the status is on the nan issue being discussed here. I'm running into this issue while using the recently added run_bart_dlm_flax.py script, but am very new to Flax/Jax so haven't been able to really make sense of it yet. Is run_bart_dlm_flax.py usable for model pre-training in its current state? As a side note, I haven't seen the nan issue when running run_t5_mlm_flax.py on my training data. Thanks in advance for any clarification!

duongna21 (author, Sep 21, 2022): @tannonk I'm sorry for the late reply. I believe this error is related to the drop_last=False option. The training will be fine if you set drop_last=True, at the cost that a few examples in the last batch will be skipped. It would be nice if anyone is able to work on this weird bug.
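For context, here is a tiny sketch of the failure mode hypothesized earlier in this thread (num_labels ending up 0 on a remainder batch that is all padding). The thread later suggests the nan actually appears earlier, at the cross-entropy itself, so this illustrates the hypothesis only; all values below are invented.

```python
import jax.numpy as jnp

per_token_loss = jnp.zeros(4)   # a remainder batch that is entirely padding
padding_mask = jnp.zeros(4)     # no real labels left in this batch

loss = (per_token_loss * padding_mask).sum()   # 0.0
num_labels = padding_mask.sum()                # 0.0
print(loss / num_labels)                       # nan (0/0), which any later psum/pmean propagates
```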

Comment on lines -891 to +897

Contributor: It's drop_last=False which is causing the eval loss to be nan?

duongna21 (author, Oct 11, 2022): @sanchit-gandhi Thanks for the review. Rebased.


eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
@@ -928,7 +934,7 @@ def eval_step(params, batch):
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
Comment on lines -931 to +937

Contributor: Unintentional change?

Contributor: Same here! Is drop_last causing issues with the eval loss?


eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
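The scripts changed in this PR all rely on the `jax.value_and_grad(..., has_aux=True)` mechanic: the loss function returns a `(loss, num_labels)` pair, the gradient is taken with respect to the first element only, and the auxiliary value rides along. A minimal invented example of that pattern, separate from the scripts themselves:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0.5, 0.0])
label_mask = jnp.array([1.0, 0.0])  # second position is padding


def loss_fn(params):
    per_token = (params - labels) ** 2 * label_mask
    return per_token.sum(), label_mask.sum()  # (loss, num_labels)


grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, num_labels), grad = grad_fn(jnp.array([1.0, 2.0]))
print(loss, num_labels, grad)  # grad is w.r.t. the loss only; num_labels is passed through
```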
23 changes: 15 additions & 8 deletions examples/flax/language-modeling/run_mlm_flax.py
@@ -714,18 +714,25 @@ def loss_fn(params):
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

# take average
loss = loss.sum() / label_mask.sum()
loss = loss.sum()
num_labels = label_mask.sum()

return loss
return loss, num_labels

grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")

# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_map(lambda x: x / num_labels, loss)

# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad)

metrics = jax.lax.pmean(
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}

return new_state, metrics, new_dropout_rng

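For readers unfamiliar with the `loss_fn` above: the per-token cross-entropy is computed against one-hot labels and multiplied by a label mask so that only masked positions count, and the resulting sum plus label count are what get aggregated across devices. A rough sketch with toy shapes follows; `jax.nn.one_hot` stands in for the script's own `onehot` helper, and treating label 0 as an unmasked position is a convention of this toy only.

```python
import jax
import jax.numpy as jnp
import optax

vocab_size = 8
logits = jnp.zeros((2, 3, vocab_size))         # (batch, seq_len, vocab), toy values
labels = jnp.array([[1, 2, 0], [3, 0, 0]])     # 0 = not a masked position in this toy
label_mask = (labels > 0).astype(jnp.float32)  # score only the masked tokens

per_token = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, vocab_size))
loss = (per_token * label_mask).sum()
num_labels = label_mask.sum()
print(loss / num_labels)  # per-label loss on one device, before the cross-device psum
```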
2 changes: 0 additions & 2 deletions examples/flax/language-modeling/run_t5_mlm_flax.py
@@ -328,7 +328,6 @@ class FlaxDataCollatorForT5MLM:
decoder_start_token_id: int

def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:

# convert list to dict and tensorize input
batch = BatchEncoding(
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
@@ -397,7 +396,6 @@ def filter_input_ids(self, input_ids, sentinel_ids):
return input_ids

def random_spans_noise_mask(self, length):

"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

Noise mask consisting of random spans of noise tokens.
34 changes: 22 additions & 12 deletions examples/flax/summarization/run_summarization_flax.py
@@ -775,8 +775,9 @@ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):

# ignore padded tokens from loss
loss = loss * padding_mask
loss = loss.sum() / padding_mask.sum()
return loss
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels

# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
@@ -785,29 +786,38 @@ def train_step(state, batch, label_smoothing_factor=0.0):
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
return loss, num_labels

grad_fn = jax.value_and_grad(compute_loss)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")

# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_map(lambda x: x / num_labels, loss)

# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
metrics = jax.lax.pmean(metrics, axis_name="batch")

return new_state, metrics

# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

# summarize metrics
loss, num_labels = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
num_labels = jax.lax.psum(num_labels, "batch")

# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_map(lambda x: x / num_labels, loss)

metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics

# Define generation function
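The summarization changes mirror the image-captioning ones above. As a closing illustration of why the aggregation order matters, here are invented toy numbers comparing the old mean-of-per-device-means with the total-loss-over-total-labels value these scripts now report:

```python
# Device A: 4 real tokens, summed loss 4.0 (per-device mean 1.0)
# Device B: 1 real token,  summed loss 3.0 (per-device mean 3.0)
old_pmean_of_means = (4.0 / 4 + 3.0 / 1) / 2  # 2.0: both devices weighted equally
true_loss = (4.0 + 3.0) / (4 + 1)             # 1.4: every real token weighted equally
print(old_pmean_of_means, true_loss)
```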