-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use Conv1d for TDNN #25728
Use Conv1d for TDNN #25728
Conversation
# hidden_states = hidden_states.transpose(1, 2) | ||
# hidden_states = self.kernel(hidden_states) | ||
|
||
hidden_states = hidden_states.transpose(1, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very clean, like this solution a lot @gau-nernst! Do you know what kind of speed-ups we get already?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will try to run some benchmarks and let you know.
Since I'm also working on LoRA at my work, I realize doing this will bypass any "magic" to be applied on nn.Linear, since we don't call nn.Linear.forward(). Although LoRA is not likedly to be applied here, since the TDNN layers will most likely be trained from scratch, it may lead to unexpected behavior if the user wants to "hack" nn.Linear.
Another solution is to replace nn.Linear completely with nn.Conv1d, and apply some load_state_dict hook for backward compatability. But I think PyTorch's load state dict hook is only available recently. I will check on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool cool! That's an astute observation regarding LoRA. I think it's probably ok given that the linear TDNN layers are only a fraction of the Transformer encoder ones, and as you say they will be trained from scratch in a round of fine-tuning. Replacing nn.Linear
with nn.Conv1d
might work, but we'll have to propagate this change to Flax and TF to ensure we have cross platform compatibility (this could get quite involved). Overall, the priority is maintaining cross platform compatibility within transformers
, over LoRA from third-party libraries
On RTX 3090, 5s input
The improvement is around 5%. It's small but not insignificant. Another optimization we can do is to make Benchmark script import time
from itertools import product
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForXVector
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = "cuda"
batch_sizes = (1, 8)
dtypes = (torch.float32, torch.float16)
m = Wav2Vec2ForXVector.from_pretrained("facebook/wav2vec2-xls-r-300m").to(device)
for dtype, bsize in product(dtypes, batch_sizes):
m.to(dtype)
x = torch.randn(bsize, 16_000 * 5, device=device, dtype=dtype)
m.eval()
with torch.no_grad():
m(x) # warmup
torch.cuda.synchronize()
N = 100
with torch.no_grad():
time0 = time.perf_counter()
for _ in range(N):
out = m(x)
torch.cuda.synchronize()
print(f"{dtype}, {bsize=}, forward: {N / (time.perf_counter() - time0):.4f} it/s")
m.train()
m(x)[0].sum().backward() # warmup
time0 = time.perf_counter()
for _ in range(N):
out = m(x)
out[0].sum().backward()
torch.cuda.synchronize()
print(f"{dtype}, {bsize=}, forward+backward: {N / (time.perf_counter() - time0):.4f} it/s") |
Very cool! Thanks for the results @gau-nernst and nice benchmark script! I think it's worth pursuing in this case: 5% is the lower end of what we'd expect to get for I would be against adding a breaking change with the channel-last approach, since this is in contradiction to the scope of the PR where we set out to optimise in an entirely non-breaking way |
Cool! Anything else you want me to add to the PR, perhaps apart from removing the commented out old code? Is there a correctness test that we can add? I didn't benchmark with torch.compile() since Wav2Vec2 did not work with it yet (I hope the relevant PR will be merged soon). Would you want to see benchmark results with torch.compile() for TDNNLayer alone? Perhaps it can even optimize the old way (unfold + linear) |
It looks super clean already @gau-nernst - I think all that's left to do is:
No worries about benchmarking with torch compile - the results you provided previously were more than enough to justify this code change. We won't expect all users to be using PT 2.0, so this speed-up will definitely hold in many cases |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @gau-nernst! Really elegant solution 👏 Nice job on bringing the speed-ups without breaking changes
If the slow test mentioned before passes, then this is good to merge for me! Requesting final review from @ArthurZucker
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very elegant!
Regarding this:
Another optimization we can do is to make TDNNLayer accepts channels-first input instead of channels-last (current approach) to avoid transposing data twice every layer. However, this would break backward compatibility if someone relies on TDNNLayer externally (though I think it's unlikely).
TDNN Is not documented so not really part of our Public API, which means we could change it (but mostly only if we really have a significant gain in performance)
It's a breaking change for people who trained a model with LoRA no? That my only concern
I think it's unlikely that someone will apply LoRA on the TDNN layers. Most likely they will apply LoRA on the transformer backbone only, and train the TDNN layers from scratch. |
My question is rather, what if someone has already apply LoRA on the TDNN layers? (Though it might not be recommended, someone who does not specifically know the model could have just quantized all the linear layers no?) 😉 |
I can't answer that, though I doubt that people who know how to use LoRA, and use it on TDNN layers specifically (most LoRA guides only show how to use it with transformers as far as I know), will face problems with this change. You guys from HF should decide whether this outweighs the cost i.e. speed-up vs annoy people who use LoRA on TDNN layers, and what kind of guarantee the library should provide i.e. a model surgery technique (e.g. LoRA) should still work after an update. |
That is true: LORA will now not apply to the linear layers. Just to clarify, is this just concerned with applying lora weights to the full-precision model and doing a new fine-tuning run? Or, would it actually break inference if someone wants to use lora fine-tuned weights with the TDNN model? If it's the former, I think it's fine since TDNN is not documented and the linear layers are a fraction of the transformer ones. If it's the latter, then we maybe need to think of a workaround |
Pretty sure it will break inference, but we can probably have some kind of fix / deprecation cycle for this 😓 |
hidden_states = hidden_states.transpose(1, 2) | ||
hidden_states = self.kernel(hidden_states) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @younesbelkada as well: currently we have un-optimised TDNN code, where we call nn.functional.unfold
to do a conv1d, then apply a linear layer self.kernel
@gau-nernst has proposed is a solution where we still use the weights of the linear self.kernel
layer (rather than switching to a conv1d layer which would change the state dict), but then use nn.functional.conv1d
to do the convolution
=> the benchmarks show there's about a 10% performance gain by doing this
However, we worry that since we don't call self.kernel.forward
, the LoRA layers for this layer will no longer work for inference. Is this hunch correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct, these changes will make LoRA unusable for the kernel weights on PEFT side as we call torch.nn.functional.conv1d
instead of the forward of the module itself.
However, usually we only adapt self attention modules (query, key, value) and in rare cases MLP layers of the transformer block. So I am not sure users really adapt kernel
layers.
I am not sure if we could make that PR BC (through a config variable maybe?)
cc @pacman100 & @BenjaminBossan as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is an edge case that is unlikely to be practically used, I would be fine with breaking backwards compatibility here, given the benefits.
If we are worried about this, however, we could make a check and raise an error/give a warning if a user happens to apply LoRA here. That check could be made here (check if self.kernel
looks like a LoRA layer) or we could add a check in PEFT if we detect that a user tries to apply LoRA to this specific layer of this specific model.
Btw. this also affects other adapters such as IA³.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome - thanks both for your input! It sounds like it's super low likelihood that a user is applying LoRA to these TDNN kernel layers, so I'd be happy proceeding with these changes if it's good for you @gau-nernst @ArthurZucker?
If not, and you'd like to preserve backwards comp for this edge case, I like @BenjaminBossan's idea of checking whether the weights look like a LoRA layer, and triggering a warning if so. There's also the option of having two implementations for the TDNN layer:
- Old, un-optimised variant that is LoRA compatible
- New, optimised variant that is not LoRA compatible
=> we could use 2 by default, but switch to 1 if the user has applied LoRA to the kernel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with just warning the user if there are lora weights !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay so the problem here is that Wav2Vec2 doesn't have input embeddings: it works directly with the raw audio waveform and computes positional embeddings directly from the input. Since there is no notion of input embeddings, we don't implement the method .get_input_embeddings
. This breaks the PEFT LoRA logic, which expects this method to be implemented.
This suggests that Wav2Vec2 is not compatible at all with PEFT LoRA as things currently stand. Would you mind indeed opening an issue on PEFT with the codesnippet you shared, and a copy-and-paste explanation of why it doesn't work?
This means though that we don't need to worry about breaking LoRA for Wav2Vec2ForXVector
, since it doesn't work with it anyway at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error is due to setting gradient_checkpointing=True
in Wav2Vec2, which is the default for facebook/wav2vec2-base
. Other Wav2Vec2 checkpoints, like facebook/wav2vec2-xls-r-300m
, have gradient_checkpointing=False
, thus there is no error, and PEFT LoRA still works with them. Overriding gradient_checkpointing=False
can also make facebook/wav2vec2-base
works with PEFT LoRA.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok thanks for the clarification! Let's stick to our original strategy then of raising a warning if there are LoRA weights
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @gau-nernst - would you be interested in finishing this PR? Now that we've agreed on the logic here and you've done a great job at laying down the foundations, it should be quite fast to tidy it up and complete the feature!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @sanchit-gandhi. I'm happy to finish this. Do let me know which parts you want to be changed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding the LoRA warnings, this PR looks great @gau-nernst 🙌 Requesting final review from @ArthurZucker!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really sorry for the late review!
Can you rebase and I'll merge!
Great work here on finding an elegant solution 🤗 |
@ArthurZucker i have merged master and make some small changes accordingly |
Thanks! Merging 🤗 |
* use conv for tdnn * run make fixup * update TDNN * add PEFT LoRA check * propagate tdnn warnings to others * add missing imports * update TDNN in wav2vec2_bert * add missing imports
…uggingface#29145 (#1) * Add qwen2 (#29145) * add config, modeling, and tokenization * add auto and init * update readme * update readme * update team name * fixup * fixup * update config * update code style * update for fixup * update for fixup * update for fixup * update for testing * update for testing * fix bug for config and tokenization * fix bug for bos token * not doctest * debug tokenizer * not doctest * debug tokenization * debug init for tokenizer * fix style * update init * delete if in token auto * add tokenizer doc * add tokenizer in init * Update dummy_tokenizers_objects.py * update * update * debug * Update tokenization_qwen2.py * debug * Update convert_slow_tokenizer.py * add copies * add copied from and make style * update files map * update test * fix style * fix merge reading and update tests * fix tests * fix tests * fix style * debug a variable in readme * Update src/transformers/models/qwen2/configuration_qwen2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update test and copied from * fix style * update qwen2 tokenization and tests * Update tokenization_qwen2.py * delete the copied from after property * fix style * update tests * update tests * add copied from * fix bugs * update doc * add warning for sliding window attention * update qwen2 tokenization * fix style * Update src/transformers/models/qwen2/modeling_qwen2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix tokenizer fast --------- Co-authored-by: Ren Xuancheng <jklj077@users.noreply.github.com> Co-authored-by: renxuancheng.rxc <renxuancheng.rxc@alibaba-inc.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix SDPA tests (#28552) * skip bf16 test if not supported by device * fix * fix bis * use is_torch_bf16_available_on_device * use is_torch_fp16_available_on_device * fix & use public llama * use 1b model * fix flacky test --------- Co-authored-by: Your Name <you@example.com> * Allow to train dinov2 with different dtypes like bf16 (#28504) I want to train dinov2 with bf16 but I get the following error in https://github.com/huggingface/transformers/blob/bc72b4e2cdcbc80d5f56731f35dbc9c18b4c8de6/src/transformers/models/dinov2/modeling_dinov2.py#L635: ``` RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same ``` Since the input dtype is torch.float32, the parameter dtype has to be torch.float32... @LZHgrla and I checked the code of clip vision encoder and found there is an automatic dtype transformation (https://github.com/huggingface/transformers/blob/bc72b4e2cdcbc80d5f56731f35dbc9c18b4c8de6/src/transformers/models/clip/modeling_clip.py#L181-L182). So I add similar automatic dtype transformation to modeling_dinov2.py. * Fix Switch Transformers When sparse_step = 1 (#28564) Fix sparse_step = 1 I case sparse_step = 1, the current code will not work. * Save `Processor` (#27761) * save processor * Update tests/models/auto/test_processor_auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/test_processing_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Use `weights_only` only if torch >= 1.13 (#28506) * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * [`Core Tokenization`] Support a fix for spm fast models (#26678) * fix * last attempt * current work * fix forward compatibility * save all special tokens * current state * revert additional changes * updates * remove tokenizer.model * add a test and the fix * nit * revert one more break * fix typefield issue * quality * more tests * fix fields for FC * more nits? * new additional changes * how * some updates * the fix * where do we stand * nits * nits * revert unrelated changes * nits nits nits * styling * don't break llama just yet * revert llama changes * safe arg check * fixup * Add a test for T5 * Necessary changes * Tests passing, added tokens need to not be normalized. If the added tokens are normalized, it will the stripping which seems to be unwanted for a normal functioning * Add even more tests, when normalization is set to True (which does not work :sweat: ) * Add even more tests, when normalization is set to True (which does not work :sweat: ) * Update to main * nits * fmt * more and more test * comments * revert change as tests are failing * make the test more readble * nits * refactor the test * nit * updates * simplify * style * style * style convert slow * Update src/transformers/convert_slow_tokenizer.py * chore: Fix multiple typos (#28574) * Add new meta w2v2-conformer BERT-like model (#28165) * first commit * correct default value non causal * update config and modeling code * update converting checkpoint * clean modeling and fix tests * make style * add new config parameters to docstring * fix copied from statements * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * make position_embeddings_type docstrings clearer * clean converting script * remove function not used * clean modeling file * apply suggestion for test file + add convert script to not_doctested * modify tests according to review - cleaner logic and more tests * Apply nit suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add checker of valid position embeddings type * instantiate new layer norm layer with the right eps * fix freeze_feature_encoder since it can be None in some cases * add test same output in convert script * restore wav2vec2conformer and add new model * create processor and FE + clean * add new model code * fix convert script and set default config parameters * correct model id paths * make style * make fix-copies and cleaning files * fix copied from statements * complete .md and fixe copies * clean convert script argument defaults * fix config parameters docstrings * fix config docstring * add copied from and enrich FE tests * fix copied from and repo-consistency * add autotokenizer * make test input length shorter and change docstring code * fix docstrings and copied from * add add_adapter to ASR training example * make testing of adapters more robust * adapt to multi adapter layers * refactor input_values->input_features and remove w2v2-bert feature extractor * remove pretraining model * remove depreciated features and useless lines * add copied from and ignore statements to modeling tests * remove pretraining model #2 * change import in convert script * change default in convert script * update readme and remove useless line * Update tests/models/wav2vec2_bert/test_processor_wav2vec2_bert.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * refactor BERT to Bert for consistency * remove useless ignore copy statement * add persistent to buffer in rotary * add eps in LayerNorm init and remove copied from * add adapter activation parameters and add copied from statements * Fix copied statements and add unitest.skip reasons * add copied statement in test_processor * refactor processor * make style * replace numpy random by torch rand * remove expected output CTC * improve converting script with processor class * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove gumbel class * remove tests related to previously deleted class * Update src/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * correct typos * remove uused parameters * update processor to takes both text and audio * update checkpoints * update expected output and add ctc expected output * add label_attention_mask * replace pt with np in processor tests * fix typo * revert to behaviour with labels_attention_mask --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Use `LoggingLevel` context manager in 3 tests (#28575) * inside with LoggingLevel * remove is_flaky --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Fix the documentation checkpoint for xlm-roberta-xl (#28567) * Fix the documentation checkpoint for xlm-roberta-xl * Improve docstring consistency * [ASR Pipe] Update init to set model type and subsequently call parent init method (#28486) * add image processor arg * super * rm args * [Whisper Tok] Move token ids to CPU when computing offsets (#28485) * move token ids to cpu * check for torch attr * [Whisper] Fix audio classification with weighted layer sum (#28563) * fix * tests * fix test * Making CTC training example more general (#28582) * add w2v2bert compatibility * Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Don't save `processor_config.json` if a processor has no extra attribute (#28584) * not save if empty * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * v4.38.dev.0 * Add w2v2bert to pipeline (#28585) * generalize asr pipeline to fbank models * change w2v2 pipeline output * Update test_pipelines_automatic_speech_recognition.py * feat: Sequential beam search (#26304) * [Whisper] Finalize batched SOTA long-form generation (#27658) * finalize * make fix copies whisper * [Tests] Make sure that we don't run tests mulitple times * Update src/transformers/models/whisper/modeling_whisper.py * [Tests] Make sure that we don't run tests mulitple times * fix more * improve * improve * improve further * improve more * improve * fix more * git commit and git push * fix more * fix more * fix more * New try * Fix more whisper stuff * Improve * correct more * correct more * correct more * Fix some tests * Add more tests * correct more * correct more * correct more * push * correct more * Fix more * Better * without dec mask * correct more * clean * save intermediate * Fix more * Fix VAD for large-v2 * Save new * Correct more * make cleaner * correct tests * correct src * Finish * Fix more * Fix more * finish * Fix edge cases * fix return_dict_in_generate * fix all tests * make style * add docstrings * add docstrings * Fix logit processor * make style * fix pipeline test * fix more style * Apply suggestions from code review * apply feedback Sanchit * correct more * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct more * correct more * correct more * Fix staticmethod * correct more * fix * fix slow tests * make style * fix tokenizer test * fix tokenizer test * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * finish * finish * revert kwargs change --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix wrong xpu device in DistributedType.MULTI_XPU mode (#28386) * remove elif xpu * remove redudant code * [SigLIP] Don't pad by default (#28578) First draft * [`Llava`] Fix convert_llava_weights_to_hf.py script (#28570) * Update convert_llava_weights_to_hf.py Fix call to `tokenizer.add_tokens` * Add special_tokens to tokenizer.add_tokens in convert_vipllava_weights_to_hf.py * Allow add_tokens for ESM (#28535) * Allow non-special tokens to be added * Add test, fix token adding code * Revert changes to id_to_token and token_to_id * Update the ESM tokenizer to be a bit more standardized * Update src/transformers/models/esm/tokenization_esm.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Fix `_speculative_sampling` implementation (#28508) * RWKV: raise informative exception when attempting to manipulate `past_key_values` (#28600) * Fix auxiliary loss related code in transformers (#28406) * [DETA] fix freeze/unfreeze function * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add freeze/unfreeze test case in DETA * fix type * fix typo 2 * fix : enable aux and enc loss in training pipeline * Add unsynced variables from original DETA for training * modification for passing CI test * make style * make fix * manual make fix * change deta_modeling_test of configuration 'two_stage' default to TRUE and minor change of dist checking * remove print * divide configuration in DetaModel and DetaForObjectDetection * image smaller size than 224 will give topk error * pred_boxes and logits should be equivalent to two_stage_num_proposals * add missing part in DetaConfig * Update src/transformers/models/deta/modeling_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * add docstring in configure and prettify TO DO part * change distribute related code to accelerate * Update src/transformers/models/deta/configuration_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/deta/test_modeling_deta.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * protect importing accelerate * change variable name to specific value * wrong import * fix aux_loss in conditional_detr * add test aux_loss * add aux_loss test in deta and table_transformer * fix yolos since it doesn't have auxiliary function * fix maskformer auxiliary_loss related code * make style * change param 'auxiliary_loss' to 'use_auxiliary_loss' * change param 'auxiliary_loss' to 'use_auxiliary_loss' in tests * make style & fix-copies, also revert yolos related parameter * revert variable name 'use_auxiliary_loss' to 'auxiliary_loss' due to DetrConfig * revert variable name in yolos * revert maskformer * add aux_loss test in maskformer * make style * Update src/transformers/models/yolos/configuration_yolos.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * [`GPTNeoX`] Fix BC issue with 4.36 (#28602) * fix dtype issue * add a test * update copied from mentions * nits * fixup * fix copies * Apply suggestions from code review * Fix id2label assignment in run_classification.py (#28590) * Add missing key to TFLayoutLM signature (#28640) Fix missing bbox in LayoutLM signature * Avoid root logger's level being changed (#28638) * avoid root logger's level being changed --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Add config tip to custom model docs (#28601) Add tip to custom model docs * Fix lr_scheduler in no_trainer training scripts (#27872) * Fix lr_scheduler * Fix lr scheduler * [`Llava`] Update convert_llava_weights_to_hf.py script (#28617) * Update convert_llava_weights_to_hf.py script * Remove config update of adding padding to `vocab_size` and `text_config.vocab_size` which causes `ValueError` exception. * Remove keys that ends with `inv_freq` from the state dict. * Add examples and instructions for creating `model_state_dict.bin` that can be used by the script. * Update convert_llava_weights_to_hf.py * Update convert_vipllava_weights_to_hf.py * [`GPTNeoX`] Fix GPTNeoX + Flash Attention 2 issue (#28645) Update modeling_gpt_neox.py * Update image_processing_deformable_detr.py (#28561) * Update image_processing_deformable_detr.py * Changes after running make fix-copies * [`SigLIP`] Only import tokenizer if sentencepiece available (#28636) Only import class if sp available * Fix phi model doc checkpoint (#28581) Co-authored-by: Pashmina Cameron <11311835+pashminacameron@users.noreply.github.com> * get default device through `PartialState().default_device` as it has been officially released (#27256) get default device through `PartialState().default_device` as it has been officially released * integrations: fix DVCLiveCallback model logging (#28653) * Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (#27599) * Initial commit * Requirements & tests * Tests * Tests * Rogue import * Rogue torch import * Cleanup * Apply suggestions from code review Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * bfloat16 management * Sanchit's comments * Import shield * apply suggestions from code review * correct bf16 * rebase --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> * Enable instantiating model with pretrained backbone weights (#28214) * Enable instantiating model with pretrained backbone weights * Update tests so backbone checkpoint isn't passed in * Remove doc updates until changes made in modeling code * Clarify pretrained import * Update configs - docs and validation check * Update src/transformers/utils/backbone_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Clarify exception message * Update config init in tests * Add test for when use_timm_backbone=True * Small test updates --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * `tensor_size` - fix copy/paste error msg typo (#28660) Fix copy/paste error msg typo * Fix windows err with checkpoint race conditions (#28637) Fix windows err * add dataloader prefetch factor in training args and trainer (#28498) * add dataloader prefetch factor in training args and trainer * remove trailing spaces * prevent dataloader_num_workers == 0 and dataloader_prefetch_factor != None dataloader_prefetch_factor works only when data is loaded in a different process as the main one. This commit adds the necessary checks to avoid having prefetch_factor set when there is no such process. * Remove whitespaces in empty line * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/training_args.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Support single token decode for `CodeGenTokenizer` (#28628) convert token id to list in .decode() * Remove deprecated eager_serving fn (#28665) * Remove deprecated eager_serving fn * Fix the input_signature docstring while I'm here * fix a hidden bug of `GenerationConfig`, now the `generation_config.json` can be loaded successfully (#28604) * fix a hidden bug of GenerationConfig * keep `sort_keys=True` to maintain visibility * Update src/transformers/generation/configuration_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update configuration_utils.py in case `obj` is a list, check the items in the list --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update README_es.md (#28612) Fixing grammatical errors in the text * Exclude the load balancing loss of padding tokens in Mixtral-8x7B (#28517) * fix the function load_balancing_loss_func in Mixtral_Moe to include attention_mask * format code using black and ruff * skip computing mask if attention_mask=None * add tests for load balancing loss Mixtral-Moe * fix assert loss is different in mixtral_test * fix pad_leng * use assertNotAlmostEqual and print to debug * remove print for debug * minor updates * reduce rtol and atol * Use save_safetensor to disable safe serialization for XLA (#28669) * Use save_safetensor to disable safe serialization for XLA https://github.com/huggingface/transformers/issues/28438 * Style fixup * Add back in generation types (#28681) * [docs] DeepSpeed (#28542) * config * optim * pre deploy * deploy * save weights, memory, troubleshoot, non-Trainer * done * Improved type hinting for all attention parameters (#28479) * Changed type hinting for all attention inputs to 'Optional[Tuple[torch.FloatTensor,...]] = None' * Fixed the ruff formatting issue * fixed type hinting for all hidden_states to 'Optional[Tuple[torch.FloatTensor, ...]] = None' * Changed type hinting in these 12 scripts modeling_dpr.py,modeling_nat.py,idefics/vision.py,modeling_tf_dpr.py,modeling_luke.py,modeling_swin.py,modeling_tf_swin.py,modeling_blip.py,modeling_tf_blip.py,modeling_donut_swin.py,modeling_dinat.py,modeling_swinv2.py * test fail update * fixed type hinting for these 15 scripts modeling_xlnet.py,modeling_tf_xlnet.py,modeling_led.py,modeling_tf_led.py,modleing_rwkv.py,modeling_dpt.py,modeling_tf_cvt.py,modeling_clip.py,modeling_flax_clip.py,modeling_tf_clip.py,modeling_longformer.py,modeling_tf_longformer.py,modeling_siglip.py,modeling_clap.py,modeling_git.py * Changed type hinting in these 12 scripts modeling_dpr.py,modeling_nat.py,idefics/vision.py,modeling_tf_dpr.py,modeling_luke.py,modeling_swin.py,modeling_tf_swin.py,modeling_blip.py,modeling_tf_blip.py,modeling_donut_swin.py,modeling_dinat.py,modeling_swinv2.py * test fail update * Removed the myvenv file * Fixed type hinting for these 8 scripts modeling_tvlt.py,modeling_sam.py,modeling_tf_sam.py,modeling_tvp.py,modeling_rag.py,modeling_tf_rag.py,modeling_tf_xlm.py,modeling_xlm.py * improve efficient training on CPU documentation (#28646) * update doc * revert * typo fix * refine * add dtypes * Update docs/source/en/perf_train_cpu.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/perf_train_cpu.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/perf_train_cpu.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * no comma * use avx512-vnni --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * [docs] Fix doc format (#28684) * fix hfoptions * revert changes to other files * fix * Add Depth Anything (#28654) * First draft * More improvements * More improvements * More improvements * More improvements * Add docs * Remove file * Add copied from * Address comments * Address comments * Address comments * Fix style * Update docs * Convert all checkpoints, add integration test * Rename checkpoints * Add pretrained backbone attributes * Fix default config * Address comment * Add figure to docs * Fix bug thanks to @xenova * Update conversion script * Fix integration test * [`chore`] Add missing space in warning (#28695) Add missing space in warning * Improve Backbone API docs (#28666) Update backbones.md * Update question_answering.md (#28694) fix typo: from: "model = TFAutoModelForQuestionAnswering("distilbert-base-uncased")" to: model = TFAutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased") * [`Vilt`] align input and model dtype in the ViltPatchEmbeddings forward pass (#28633) align dtype * [`docs`] Improve visualization for vertical parallelism (#28583) The documentation says "We refer to this Model parallelism as “Vertical” because of how models are typically visualized.", but then visualizes the model horizontally. This change visualizes the model indeed vertically. * Don't fail when `LocalEntryNotFoundError` during `processor_config.json` loading (#28709) * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Fix duplicate & unnecessary flash attention warnings (#28557) * fix duplicate & unnecessary flash warnings * trigger ci * warning_once * if/else order --------- Co-authored-by: Your Name <you@example.com> * support PeftMixedModel signature inspect (#28321) * support PeftMixedModel signature inspect * import PeftMixedModel only peft>=0.7.0 * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fix styling * Update src/transformers/trainer.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/trainer.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style fixup * fix note --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix: corrected misleading log message in save_pretrained function (#28699) * [`docs`] Update preprocessing.md (#28719) * Update preprocessing.md adjust ImageProcessor link to working target (same as in lower section of file) * Update preprocessing.md * Initialize _tqdm_active with hf_hub_utils.are_progress_bars_disabled(… (#28717) Initialize _tqdm_active with hf_hub_utils.are_progress_bars_disabled() to respect HF_HUB_DISABLE_PROGRESS_BARS It seems like enable_progress_bar() and disable_progress_bar() sync up with huggingface_hub, but the initial value is always True. This changes will make sure the user's preference is respected implicity on initialization. * Fix `weights_only` (#28725) fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Stop confusing the TF compiler with ModelOutput objects (#28712) * Stop confusing the TF compiler with ModelOutput objects * Stop confusing the TF compiler with ModelOutput objects * fix: suppress `GatedRepoError` to use cache file (fix #28558). (#28566) * fix: suppress `GatedRepoError` to use cache file (fix #28558). * move condition_to_return parameter back to outside. * Unpin pydantic (#28728) * try pydantic v2 * try pydantic v2 --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * [docs] Fix datasets in guides (#28715) * change datasets * fix * [Flax] Update no init test for Flax v0.7.1 (#28735) * Falcon: removed unused function (#28605) * Generate: deprecate old src imports (#28607) * [`Siglip`] protect from imports if sentencepiece not installed (#28737) [Siglip] protect from imports if sentencepiece not installed * Add serialization logic to pytree types (#27871) * Add serialized type name to pytrees * Modify context * add serde test * Fix `DepthEstimationPipeline`'s docstring (#28733) * fix * fix * Fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Fix input data file extension in examples (#28741) * [Docs] Fix Typo in English & Japanese CLIP Model Documentation (TMBD -> TMDB) (#28751) * [Docs] Fix Typo in English CLIP model_doc * [Docs] Fix Typo in Japanese CLIP model_doc * PatchtTST and PatchTSMixer fixes (#28083) * :bug: fix .max bug * remove prediction_length from regression output dimensions * fix parameter names, fix output names, update tests * ensure shape for PatchTST * ensure output shape for PatchTSMixer * update model, batch, and expected for regression distribution test * update test expected Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com> * Update tests/models/patchtst/test_modeling_patchtst.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/patchtst/test_modeling_patchtst.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/patchtst/test_modeling_patchtst.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * standardize on patch_length Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com> * Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Make arguments more explicit Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com> * adjust prepared inputs Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com> --------- Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com> Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Enable Gradient Checkpointing in Deformable DETR (#28686) * Enabled gradient checkpointing in Deformable DETR * Enabled gradient checkpointing in Deformable DETR encoder * Removed # Copied from headers in modeling_deta.py to break dependence on Deformable DETR code * small doc update for CamemBERT (#28644) * Pin pytest version <8.0.0 (#28758) * Pin pytest version <8.0.0 * Update setup.py * make deps_table_update * Mark test_constrained_beam_search_generate as flaky (#28757) * Make test_constrained_beam_search_generate as flaky * Update tests/generation/test_utils.py * Fix typo of `Block`. (#28727) * [Whisper] Make tokenizer normalization public (#28136) * [Whisper] Make tokenizer normalization public * add to docs * Support saving only PEFT adapter in checkpoints when using PEFT + FSDP (#28297) * Update trainer.py * Revert "Update trainer.py" This reverts commit 0557e2cc9effa3a41304322032239a3874b948a7. * Make trainer.py use adapter_only=True when using FSDP + PEFT * Support load_best_model with adapter_only=True * Ruff format * Inspect function args for save_ load_ fsdp utility functions and only pass adapter_only=True if they support it * Add French translation: french README.md (#28696) * doc: french README Signed-off-by: ThibaultLengagne <thibaultl@padok.fr> * doc: Add Depth Anything Signed-off-by: ThibaultLengagne <thibaultl@padok.fr> * doc: Add french link in other docs Signed-off-by: ThibaultLengagne <thibaultl@padok.fr> * doc: Add missing links in fr docs * doc: fix several mistakes in translation Signed-off-by: ThibaultLengagne <thibaultl@padok.fr> --------- Signed-off-by: ThibaultLengagne <thibaultl@padok.fr> Co-authored-by: Sarapuce <alexandreh@padok.fr> * Don't allow passing `load_in_8bit` and `load_in_4bit` at the same time (#28266) * Update quantization_config.py * Style * Protect from setting directly * add tests * Update tests/quantization/bnb/test_4bit.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Move CLIP _no_split_modules to CLIPPreTrainedModel (#27841) Add _no_split_modules to CLIPModel * `HfQuantizer` class for quantization-related stuff in `modeling_utils.py` (#26610) * squashed earlier commits for easier rebase * rm rebase leftovers * 4bit save enabled @quantizers * TMP gptq test use exllama * fix AwqConfigTest::test_wrong_backend for A100 * quantizers AWQ fixes * _load_pretrained_model low_cpu_mem_usage branch * quantizers style * remove require_low_cpu_mem_usage attr * rm dtype arg from process_model_before_weight_loading * rm config_origin from Q-config * rm inspect from q_config * fixed docstrings in QuantizationConfigParser * logger.warning fix * mv is_loaded_in_4(8)bit to BnbHFQuantizer * is_accelerate_available error msg fix in quantizer * split is_model_trainable in bnb quantizer class * rm llm_int8_skip_modules as separate var in Q * Q rm todo * fwd ref to HFQuantizer in type hint * rm note re optimum.gptq.GPTQQuantizer * quantization_config in __init__ simplified * replaced NonImplemented with create_quantized_param * rm load_in_4/8_bit deprecation warning * QuantizationConfigParser refactoring * awq-related minor changes * awq-related changes * awq config.modules_to_not_convert * raise error if no q-method in q-config in args * minor cleanup * awq quantizer docstring * combine common parts in bnb process_model_before_weight_loading * revert test_gptq * .process_model_ cleanup * restore dict config warning * removed typevars in quantizers.py * cleanup post-rebase 16 jan * QuantizationConfigParser classmethod refactor * rework of handling of unexpected aux elements of bnb weights * moved q-related stuff from save_pretrained to quantizers * refactor v1 * more changes * fix some tests * remove it from main init * ooops * Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fix awq issues * fix * fix * fix * fix * fix * fix * add docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/hf_quantizer.md * address comments * fix * fixup * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * address final comment * update * Update src/transformers/quantizers/base.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/quantizers/auto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix * add kwargs update * fixup * add `optimum_quantizer` attribute * oops * rm unneeded file * fix doctests --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * [`HfQuantizer`] Move it to "Developper guides" (#28768) Update _toctree.yml * Use Conv1d for TDNN (#25728) * use conv for tdnn * run make fixup * update TDNN * add PEFT LoRA check * propagate tdnn warnings to others * add missing imports * update TDNN in wav2vec2_bert * add missing imports * Fix transformers.utils.fx compatibility with torch<2.0 (#28774) guard sdpa on torch>=2.0 * Further pin pytest version (in a temporary way) (#28780) fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * [`Backbone`] Use `load_backbone` instead of `AutoBackbone.from_config` (#28661) * Enable instantiating model with pretrained backbone weights * Remove doc updates until changes made in modeling code * Use load_backbone instead * Add use_timm_backbone to the model configs * Add missing imports and arguments * Update docstrings * Make sure test is properly configured * Include recent DPT updates * Task-specific pipeline init args (#28439) * Abstract out pipeline init args * Address PR comments * Reword * BC PIPELINE_INIT_ARGS * Remove old arguments * Small fix * Add tf_keras imports to prepare for Keras 3 (#28588) * Port core files + ESM (because ESM code is odd) * Search-replace in modelling code * Fix up transfo_xl as well * Fix other core files + tests (still need to add correct import to tests) * Fix cookiecutter * make fixup, fix imports in some more core files * Auto-add imports to tests * Cleanup, add imports to sagemaker tests * Use correct exception for importing tf_keras * Fixes in modeling_tf_utils * make fixup * Correct version parsing code * Ensure the pipeline tests correctly revert to float32 after each test * Ensure the pipeline tests correctly revert to float32 after each test * More tf.keras -> keras * Add dtype cast * Better imports of tf_keras * Add a cast for tf.assign, just in case * Fix callback imports * Pin Torch to <2.2.0 (#28785) * Pin torch to <2.2.0 * Pin torchvision and torchaudio as well * Playing around with versions to see if this helps * twiddle something to restart the CI * twiddle it back * Try changing the natten version * make fixup * Revert "Try changing the natten version" This reverts commit de0d6592c35dc39ae8b5a616c27285db28262d06. * make fixup * fix fix fix * fix fix fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * [`bnb`] Fix bnb slow tests (#28788) fix bnb slow tests * Prevent MLflow exception from disrupting training (#28779) Modified MLflow logging metrics from synchronous to asynchronous Co-authored-by: codiceSpaghetti <alessio.ser@hotmail.it> * don't initialize the output embeddings if we're going to tie them to input embeddings (#28192) * test that tied output embeddings aren't initialized on load * don't initialize the output embeddings if we're going to tie them to the input embeddings * [`HFQuantizer`] Remove `check_packages_compatibility` logic (#28789) remove `check_packages_compatibility` logic * [Whisper] Refactor forced_decoder_ids & prompt ids (#28687) * up * Fix more * Correct more * Fix more tests * fix fast tests * Fix more * fix more * push all files * finish all * make style * Fix timestamp wrap * make style * make style * up * up * up * Fix lang detection behavior * Fix lang detection behavior * Add lang detection test * Fix lang detection behavior * make style * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * better error message * make style tests * add warning --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Resolve DeepSpeed cannot resume training with PeftModel (#28746) * fix: resolve deepspeed resume peft model issues * chore: update something * chore: update model instance pass into is peft model checks * chore: remove hard code value to tests * fix: format code * canonical repos moves (#28795) * canonical repos moves * Style --------- Co-authored-by: Lysandre <lysandre@huggingface.co> * Wrap Keras methods to support BatchEncoding (#28734) * Shim the Keras methods to support BatchEncoding * Extract everything to a convert_batch_encoding function * Convert BatchFeature too (thanks Amy) * tf.keras -> keras * Flax mistral (#26943) * direct copy from llama work * mistral modules forward pass working * flax mistral forward pass with sliding window * added tests * added layer collection approach * Revert "added layer collection approach" This reverts commit 0e2905bf2236ec323163fc1a9f0c016b21aa8b8f. * Revert "Revert "added layer collection approach"" This reverts commit fb17b6187ac5d16da7c461e1130514dc3d137a43. * fixed attention outputs * added mistral to init and auto * fixed import name * fixed layernorm weight dtype * freeze initialized weights * make sure conversion consideres bfloat16 * added backend * added docstrings * added cache * fixed sliding window causal mask * passes cache tests * passed all tests * applied make style * removed commented out code * applied fix-copies ignored other model changes * applied make fix-copies * removed unused functions * passed generation integration test * slow tests pass * fixed slow tests * changed default dtype from jax.numpy.float32 to float32 for docstring check * skip cache test for FlaxMistralForSequenceClassification since if pad_token_id in input_ids it doesn't score previous input_ids * updated checkpoint since from_pt not included * applied black style * removed unused args * Applied styling and fixup * changed checkpoint for doc back * fixed rf after adding it to hf hub * Add dummy ckpt * applied styling * added tokenizer to new ckpt * fixed slice format * fix init and slice * changed ref for placeholder TODO * added copies from Llama * applied styling * applied fix-copies * fixed docs * update weight dtype reconversion for sharded weights * removed Nullable input ids * Removed unnecessary output attentions in Module * added embedding weight initialziation * removed unused past_key_values * fixed deterministic * Fixed RMS Norm and added copied from * removed input_embeds * applied make style * removed nullable input ids from sequence classification model * added copied from GPTJ * added copied from Llama on FlaxMistralDecoderLayer * added copied from to FlaxMistralPreTrainedModel methods * fix test deprecation warning * freeze gpt neox random_params and fix copies * applied make style * fixed doc issue * skipped docstring test to allign # copied from * applied make style * removed FlaxMistralForSequenceClassification * removed unused padding_idx * removed more sequence classification * removed sequence classification * applied styling and consistency * added copied from in tests * removed sequence classification test logic * applied styling * applied make style * removed freeze and fixed copies * undo test change * changed repeat_kv to tile * fixed to key value groups * updated copyright year * split casual_mask * empty to rerun failed pt_flax_equivalence test FlaxWav2Vec2ModelTest * went back to 2023 for tests_pr_documentation_tests * went back to 2024 * changed tile to repeat * applied make style * empty for retry on Wav2Vec2 * DeepSpeed: hardcode `torch.arange` dtype on `float` usage to avoid incorrect initialization (#28760) * Add artifact name in job step to maintain job / artifact correspondence (#28682) * avoid using job name * apply to other files --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Split daily CI using 2 level matrix (#28773) * update / add new workflow files * Add comment * Use env.NUM_SLICES * use scripts * use scripts * use scripts * Fix * using one script * Fix * remove unused file * update * fail-fast: false * remove unused file * fix * fix * use matrix * inputs * style * update * fix * fix * no model name * add doc * allow args * style * pass argument --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * [docs] Correct the statement in the docstirng of compute_transition_scores in generation/utils.py (#28786) * Adding [T5/MT5/UMT5]ForTokenClassification (#28443) * Adding [T5/MT5/UMT5]ForTokenClassification * Add auto mappings for T5ForTokenClassification and variants * Adding ForTokenClassification to the list of models * Adding attention_mask param to the T5ForTokenClassification test * Remove outdated comment in test * Adding EncoderOnly and Token Classification tests for MT5 and UMT5 * Fix typo in umt5 string * Add tests for all the existing MT5 models * Fix wrong comment in dependency_versions_table * Reverting change to common test for _keys_to_ignore_on_load_missing The test is correctly picking up redundant keys in _keys_to_ignore_on_load_missing. * Removing _keys_to_ignore_on_missing from MT5 since the key is not used in the model * Add fix-copies to MT5ModelTest * Make `is_torch_bf16_available_on_device` more strict (#28796) fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Fix symbolic_trace with kv cache (#28724) * fix symbolic_trace with kv cache * comment & better test * Add tip on setting tokenizer attributes (#28764) * Add tip on setting tokenizer attributes * Grammar * Remove the bit that was causing doc builds to fail * enable graident checkpointing in DetaObjectDetection and add tests in Swin/Donut_Swin (#28615) * enable graident checkpointing in DetaObjectDetection * fix missing part in original DETA * make style * make fix-copies * Revert "make fix-copies" This reverts commit 4041c86c29248f1673e8173b677c20b5a4511358. * remove fix-copies of DetaDecoder * enable swin gradient checkpointing * fix gradient checkpointing in donut_swin * add tests for deta/swin/donut * Revert "fix gradient checkpointing in donut_swin" This reverts commit 1cf345e34d3cc0e09eb800d9895805b1dd9b474d. * change supports_gradient_checkpointing pipeline to PreTrainedModel * Revert "add tests for deta/swin/donut" This reverts commit 6056ffbb1eddc3cb3a99e4ebb231ae3edf295f5b. * Revert "Revert "fix gradient checkpointing in donut_swin"" This reverts commit 24e25d0a14891241de58a0d86f817d0b5d2a341f. * Simple revert * enable deformable detr gradient checkpointing * add gradient in encoder * [docs] fix some bugs about parameter description (#28806) Co-authored-by: p_spozzhang <p_spozzhang@tencent.com> * Add models from deit (#28302) * Add modelss * Add 2 more models * add models to tocrree * Add modles * Update docs/source/ja/model_doc/detr.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/ja/model_doc/deit.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/ja/model_doc/deplot.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix bugs --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * [docs] Backbone (#28739) * backbones * fix path * fix paths * fix code snippet * fix links * [docs] HfQuantizer (#28820) * tidy * fix path * [Docs] Fix spelling and grammar mistakes (#28825) * Fix typos and grammar mistakes in docs and examples * Fix typos in docstrings and comments * Fix spelling of `tokenizer` in model tests * Remove erroneous spaces in decorators * Remove extra spaces in Markdown link texts * Explicitly check if token ID's are None in TFBertTokenizer constructor (#28824) Add an explicit none-check, since token ids can be 0 * Add missing None check for hf_quantizer (#28804) * Add missing None check for hf_quantizer * Add test, fix logic. * make style * Switch test model to Mistral * Comment * Update tests/test_modeling_utils.py --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Fix issues caused by natten (#28834) try Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * fix / skip (for now) some tests before switch to torch 2.2 (#28838) * fix / skip some tests before we can switch to torch 2.2 * style --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Use `-v` for `pytest` on CircleCI (#28840) use -v in pytest Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Reduce GPU memory usage when using FSDP+PEFT (#28830) support FSDP+PEFT * Mark `test_encoder_decoder_model_generate` for `vision_encoder_deocder` as flaky (#28842) Mark test as flaky * Bump dash from 2.3.0 to 2.15.0 in /examples/research_projects/decision_transformer (#28845) Bump dash in /examples/research_projects/decision_transformer Bumps [dash](https://github.com/plotly/dash) from 2.3.0 to 2.15.0. - [Release notes](https://github.com/plotly/dash/releases) - [Changelog](https://github.com/plotly/dash/blob/dev/CHANGELOG.md) - [Commits](https://github.com/plotly/dash/compare/v2.3.0...v2.15.0) --- updated-dependencies: - dependency-name: dash dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Support custom scheduler in deepspeed training (#26831) Reuse trainer.create_scheduler to create scheduler for deepspeed * [Docs] Fix bad doc: replace save with logging (#28855) Fix bad doc: replace save with logging * Ability to override clean_code_for_run (#28783) * Add clean_code_for_run function * Call clean_code_for_run from agent method * [WIP] Hard error when ignoring tensors. (#27484) * [WIP] Hard error when ignoring tensors. * Better selection/error when saving a checkpoint. - Find all names we should normally drop (those are in the transformers config) - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving) - Clone those disjoint tensors getting rid of the issue - Find all identical names (those should be declared in the config but we try to find them all anyway.) - For all identical names: - If they are in the config, just ignore them everything is fine - If they are not, warn about them. - For all remainder tensors which are shared yet neither identical NOR disjoint. raise a hard error. * Adding a failing test on `main` that passes here. * We don't need to keep the subfolder logic in this test. * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * [`Doc`] update contribution guidelines (#28858) update guidelines * Correct wav2vec2-bert inputs_to_logits_ratio (#28821) * Correct wav2vec2-bert inputs_to_logits_ratio * correct ratio * correct ratio, clean asr pipeline * refactor on one line * Image Feature Extraction pipeline (#28216) * Draft pipeline * Fixup * Fix docstrings * Update doctest * Update pipeline_model_mapping * Update docstring * Update tests * Update src/transformers/pipelines/image_feature_extraction.py Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> * Fix docstrings - review comments * Remove pipeline mapping for composite vision models * Add to pipeline tests * Remove for flava (multimodal) * safe pil import * Add requirements for pipeline run * Account for super slow efficientnet * Review comments * Fix tests * Swap order of kwargs * Use build_pipeline_init_args * Add back FE pipeline for Vilt * Include image_processor_kwargs in docstring * Mark test as flaky * Update TODO * Update tests/pipelines/test_pipelines_image_feature_extraction.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add license header --------- Co-authored-by: Omar Sanseviero <osanseviero@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * ClearMLCallback enhancements: support multiple runs and handle logging better (#28559) * add clearml tracker * support multiple train runs * remove bad code * add UI entries for config/hparams overrides * handle models in different tasks * run ruff format * tidy code based on code review --------- Co-authored-by: Eugen Ajechiloae <eugenajechiloae@gmail.com> * Do not use mtime for checkpoint rotation. (#28862) Resolve https://github.com/huggingface/transformers/issues/26961 * Adds LlamaForQuestionAnswering class in modeling_llama.py along with AutoModel Support (#28777) * This is a test commit * testing commit * final commit with some changes * Removed copy statement * Fixed formatting issues * Fixed error added past_key_values in the forward method * Fixed a trailing whitespace. Damn the formatting rules are strict * Added the copy statement * Bump cryptography from 41.0.2 to 42.0.0 in /examples/research_projects/decision_transformer (#28879) Bump cryptography in /examples/research_projects/decision_transformer Bumps [cryptography](https://github.com/pyca/cryptography) from 41.0.2 to 42.0.0. - [Changelog](https://github.com/pyca/cryptography/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pyca/cryptography/compare/41.0.2...42.0.0) --- updated-dependencies: - dependency-name: cryptography dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [Docs] Update project names and links in awesome-transformers (#28878) Update project names and repository links in awesome-transformers * Fix LongT5ForConditionalGeneration initialization of lm_head (#28873) * Raise error when using `save_only_model` with `load_best_model_at_end` for DeepSpeed/FSDP (#28866) * Raise error when using `save_only_model` with `load_best_model_at_end` for DeepSpeed/FSDP * Update trainer.py * Fix `FastSpeech2ConformerModelTest` and skip it on CPU (#28888) * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Revert "[WIP] Hard error when ignoring tensors." (#28898) Revert "[WIP] Hard error when ignoring tensors. (#27484)" This reverts commit 2da28c4b41bba23969a8afe97c3dfdcbc47a57dc. * unpin torch (#28892) * unpin torch * check * check * check --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Explicit server error on gated model (#28894) * [Docs] Fix backticks in inline code and documentation links (#28875) Fix backticks in code blocks and documentation links * Hotfix - make `torchaudio` get the correct version in `torch_and_flax_job` (#28899) * check * check * check --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * [Docs] Add missing language options and fix broken links (#28852) * Add missing entries to the language selector * Add links to the Colab and AWS Studio notebooks for ONNX * Use anchor links in CONTRIBUTING.md * Fix broken hyperlinks due to spaces * Fix links to OpenAI research articles * Remove confusing footnote symbols from author names, as they are also considered invalid markup * fix: Fixed the documentation for `logging_first_step` by removing "evaluate" (#28884) Fixed the documentation for logging_first_step by removing evaluate. * fix Starcoder FA2 implementation (#28891) * Fix Keras scheduler import so it works for older versions of Keras (#28895) Fix our schedule import so it works for older versions of Keras * ⚠️ Raise `Exception` when trying to generate 0 tokens ⚠️ (#28621) * change warning to exception * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * validate `max_new_tokens` > 0 in `GenerationConfig` * fix truncation test parameterization in `TextGenerationPipelineTests` --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update the cache number (#28905) * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> * Add npu device for pipeline (#28885) add npu device for pipeline Co-authored-by: unit_test <test@unit.com> * [Docs] Fix placement of tilde character (#28913) Fix placement of tilde character * [Docs] Revert translation of '@slow' decorator (#28912) * Fix utf-8 yaml load for marian conversion to pytorch in Windows (#28618) Fix utf-8 yaml in marian conversion * [`Core generation`] Adds support for static KV cache (#27931) Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Remove dead TF loading code (#28926) Remove dead code * fix: torch.int32 instead of torch.torch.int32 (#28883) * pass kwargs in stopping criteria list (#28927) * Support batched input for decoder start ids (#28887) * support batched input for decoder start ids * Fix typos Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * minor changes * fix: decoder_start_id as list * empty commit * empty commit * empty commit * empty commit * empty commit * empty commit * empty commit * empty commit * empty commit --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * [Docs] Fix broken links and syntax issues (#28918) * Fix model documentation links in attention.md * Fix external link syntax * Fix target anchor names of section links * Fix copyright statement comments * Fix documentation headings * Fix max_position_embeddings default value for llama2 to 4096 #28241 (#28754) * Changed max_position_embeddings default value from 2048 to 4096 * force push * Fixed formatting issues. Fixed missing argument in write_model. * Reverted to the default value 2048 in the Llama config. Added comments for the llama_version argument. * Fixed issue with default value value of max_position_embeddings in docstring * Updated help message for llama versions Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fix a wrong link to CONTRIBUTING.md section in PR template (#28941) * Fix type annotations on neftune_noise_alpha and fsdp_config TrainingArguments parameters (#28942) * [i18n-de] Translate README.md to German (#28933) * Translate README.md to German * Add links to README_de.md * Remove invisible characters in README * Change to a formal tone and fix punctuation marks * [Nougat] Fix pipeline (#28242) * Fix pipeline * Remove print statements * Address comments * Address issue * Remove unused imports * [Docs] Update README and default pipelines (#28864) * Update README and docs * Update README * Update README * Convert `torch_dtype` as `str` to actual torch data type (i.e. "float16" …to `torch.float16`) (#28208) * Convert torch_dtype as str to actual torch data type (i.e. "float16" to torch.float16) * Check if passed torch_dtype is an attribute in torch * Update src/transformers/pipelines/__init__.py Check type via isinstance Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * [`pipelines`] updated docstring with vqa alias (#28951) updated docstring with vqa alias * Tests: tag `test_save_load_fast_init_from_base` as flaky (#28930) * Updated requirements for image-classification samples: datasets>=2.14.0 (#28974) Updated datasets requirements. Need a package version >= 2.14.0 * Always initialize tied output_embeddings if it has a bias term (#28947) Continue to initialize tied output_embeddings if it has a bias term The bias term is not tied, and so will need to be initialized accordingly. * Clean up staging tmp checkpoint directory (#28848) clean up remaining tmp checkpoint dir Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> * [Docs] Add language identifiers to fenced code blocks (#28955) Add language identifiers to code blocks * [Docs] Add video section (#28958) Add video section * [i18n-de] Translate CONTRIBUTING.md to German (#28954) * Translate contributing.md to German * Fix formatting issues in contributing.md * Address review comments * Fix capitalization * [`NllbTokenizer`] refactor with added tokens decoder (#27717) * refactor with addedtokens decoder * style * get rid of lang code to id * style * keep some things for BC * update tests * add the mask token at the end of the vocab * nits * nits * fix final tests * style * nits * Update src/transformers/models/nllb/tokenization_nllb_fast.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * nits * style? * Update src/transformers/convert_slow_tokenizer.py * make it a tad bit more custom * ruff please stop Co-Authored by avidale <dale.david@mail.ru> * Update Co-authored-by: avidale <dale.david@mail.ru> * Update Co-authored-by: avidale <dale.david@mail.ru> * oupts * ouft * nites * test * fix the remaining failing tests * style * fix failing test * ficx other test * temp dir + test the raw init * update test * style --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Add sudachi_projection option to BertJapaneseTokenizer (#28503) * add sudachi_projection option * Upgrade sudachipy>=0.6.8 * add a test case for sudachi_projection * Compatible with older versions of SudachiPy * make fixup * make style * error message for unidic download * revert jumanpp test cases * format options for sudachi_projection Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * format options for sudachi_split_mode and sudachi_dict_type * comment * add tests for full_tokenizer kwargs * pass projection arg directly * require_sudachi_projection * make style * revert upgrade sudachipy * check is_sudachi_projection_available() * revert dependency_version_table and bugfix * style format * simply raise ImportError Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * simply raise ImportError --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Static Cache: load models with MQA or GQA (#28975) * Update configuration_llama.py: fixed broken link (#28946) * Update co…
* use conv for tdnn * run make fixup * update TDNN * add PEFT LoRA check * propagate tdnn warnings to others * add missing imports * update TDNN in wav2vec2_bert * add missing imports
What does this PR do?
Partially fixes #25476
Fixes # (issue)
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.