[Wan] ModuleNotFoundError: No module named 'dependencies.peft' #312

Open

dorpxam opened this issue Mar 10, 2025 · 1 comment · May be fixed by #316

Comments
dorpxam commented Mar 10, 2025

The latest update works well and gets past the previous error. But the seconds/it now count in minutes (around 20 minutes per iteration), and it is still the same problem with my 16 GB GPU: total memory use is around 34 GB (including shared RAM), which is why it is so incredibly slow.

To try to recover from this, I added --layerwise_upcasting_modules transformer and changed the optimizer to adamw-bnb-8bit, just as a test. The optimizer works, but layerwise_upcasting_modules crashes in the patches/__init__.py file.

This is just a typo: adding a . before dependencies works!

from .dependencies.peft import patch

Hope that helps!
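For reference, the failing import is at line 21 of finetrainers/patches/__init__.py (see the traceback below). A minimal sketch of the one-character fix; the absolute form looks for a top-level dependencies package that does not exist, while the relative form resolves inside finetrainers.patches:

# finetrainers/patches/__init__.py, inside perform_patches_for_training()
# before (absolute import, fails with ModuleNotFoundError):
#     from dependencies.peft import patch
# after (relative import):
from .dependencies.peft import patch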

+ export WANDB_MODE=online
+ WANDB_MODE=online
+ export NCCL_P2P_DISABLE=1
+ NCCL_P2P_DISABLE=1
+ export TORCH_NCCL_ENABLE_MONITORING=0
+ TORCH_NCCL_ENABLE_MONITORING=0
+ export FINETRAINERS_LOG_LEVEL=DEBUG
+ FINETRAINERS_LOG_LEVEL=DEBUG
+ BACKEND=ptd
+ NUM_GPUS=1
+ CUDA_VISIBLE_DEVICES=0
+ TRAINING_DATASET_CONFIG=scripts/wan/elizabeth/training.json
+ VALIDATION_DATASET_FILE=scripts/wan/elizabeth/validation.json
+ DDP_1='--parallel_backend ptd --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1'
+ DDP_2='--parallel_backend ptd --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1'
+ DDP_4='--parallel_backend ptd --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1'
+ FSDP_2='--parallel_backend ptd --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1'
+ FSDP_4='--parallel_backend ptd --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1'
+ HSDP_2_2='--parallel_backend ptd --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1'
+ parallel_cmd=($DDP_1)
+ model_cmd=(--model_name "wan" --pretrained_model_name_or_path "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" --layerwise_upcasting_modules transformer)
+ dataset_cmd=(--dataset_config $TRAINING_DATASET_CONFIG --dataset_shuffle_buffer_size 24 --precomputation_items 24 --precomputation_once --enable_precomputation)
+ dataloader_cmd=(--dataloader_num_workers 1)
+ diffusion_cmd=(--flow_weighting_scheme "logit_normal")
+ training_cmd=(--training_type "lora" --seed 42 --batch_size 1 --train_steps 2400 --rank 32 --lora_alpha 32 --target_modules "blocks.*(to_q|to_k|to_v|to_out.0)" --gradient_accumulation_steps 1 --gradient_checkpointing --checkpointing_steps 96 --checkpointing_limit 1000 --resume_from_checkpoint latest --enable_slicing --enable_tiling)
+ optimizer_cmd=(--optimizer "adamw-bnb-8bit" --lr 5e-5 --lr_scheduler "constant_with_warmup" --lr_warmup_steps 240 --lr_num_cycles 1 --beta1 0.9 --beta2 0.99 --weight_decay 1e-4 --epsilon 1e-8 --max_grad_norm 1.0)
+ validation_cmd=()
+ miscellaneous_cmd=(--tracker_name "finetrainers-wan" --output_dir "/mnt/f/training/wan/elizabeth" --init_timeout 600 --nccl_timeout 600 --report_to "wandb")
+ '[' ptd == accelerate ']'
+ '[' ptd == ptd ']'
+ export CUDA_VISIBLE_DEVICES=0
+ CUDA_VISIBLE_DEVICES=0
+ torchrun --standalone --nnodes=1 --nproc_per_node=1 --rdzv_backend c10d --rdzv_endpoint=localhost:0 train.py --parallel_backend ptd --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1 --model_name wan --pretrained_model_name_or_path Wan-AI/Wan2.1-T2V-1.3B-Diffusers --layerwise_upcasting_modules transformer --dataset_config scripts/wan/elizabeth/training.json --dataset_shuffle_buffer_size 24 --precomputation_items 24 --precomputation_once --enable_precomputation --dataloader_num_workers 1 --flow_weighting_scheme logit_normal --training_type lora --seed 42 --batch_size 1 --train_steps 2400 --rank 32 --lora_alpha 32 --target_modules 'blocks.*(to_q|to_k|to_v|to_out.0)' --gradient_accumulation_steps 1 --gradient_checkpointing --checkpointing_steps 96 --checkpointing_limit 1000 --resume_from_checkpoint latest --enable_slicing --enable_tiling --optimizer adamw-bnb-8bit --lr 5e-5 --lr_scheduler constant_with_warmup --lr_warmup_steps 240 --lr_num_cycles 1 --beta1 0.9 --beta2 0.99 --weight_decay 1e-4 --epsilon 1e-8 --max_grad_norm 1.0 --tracker_name finetrainers-wan --output_dir /mnt/f/training/wan/elizabeth --init_timeout 600 --nccl_timeout 600 --report_to wandb
2025-03-10 09:07:19,718 - finetrainers - DEBUG - Successfully imported bitsandbytes version 0.45.3
DEBUG:finetrainers:Successfully imported bitsandbytes version 0.45.3
2025-03-10 09:07:19,721 - finetrainers - DEBUG - Remaining unparsed arguments: []
DEBUG:finetrainers:Remaining unparsed arguments: []
2025-03-10 09:07:20,386 - finetrainers - INFO - Initialized parallel state with:
  - World size: 1
  - Pipeline parallel degree: 1
  - Data parallel degree: 1
  - Context parallel degree: 1
  - Tensor parallel degree: 1
  - Data parallel shards: 1

INFO:finetrainers:Initialized parallel state with:
  - World size: 1
  - Pipeline parallel degree: 1
  - Data parallel degree: 1
  - Context parallel degree: 1
  - Tensor parallel degree: 1
  - Data parallel shards: 1

2025-03-10 09:07:20,392 - finetrainers - DEBUG - Device mesh: DeviceMesh('cuda', 0)
DEBUG:finetrainers:Device mesh: DeviceMesh('cuda', 0)
2025-03-10 09:07:20,393 - finetrainers - DEBUG - Enabling determinism: {'global_rank': 0, 'seed': 42}
DEBUG:finetrainers:Enabling determinism: {'global_rank': 0, 'seed': 42}
2025-03-10 09:07:20,402 - finetrainers - ERROR - An error occurred during training: No module named 'dependencies.peft'
ERROR:finetrainers:An error occurred during training: No module named 'dependencies.peft'
2025-03-10 09:07:20,402 - finetrainers - ERROR - Traceback (most recent call last):
  File "/home/dorpxam/ai/finetrainers/train.py", line 66, in main
    trainer = SFTTrainer(args, model_specification)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 80, in __init__
    patches.perform_patches_for_training(self.args, self.state.parallel_backend)
  File "/home/dorpxam/ai/finetrainers/finetrainers/patches/__init__.py", line 21, in perform_patches_for_training
    from dependencies.peft import patch
ModuleNotFoundError: No module named 'dependencies.peft'

ERROR:finetrainers:Traceback (most recent call last):
  File "/home/dorpxam/ai/finetrainers/train.py", line 66, in main
    trainer = SFTTrainer(args, model_specification)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 80, in __init__
    patches.perform_patches_for_training(self.args, self.state.parallel_backend)
  File "/home/dorpxam/ai/finetrainers/finetrainers/patches/__init__.py", line 21, in perform_patches_for_training
    from dependencies.peft import patch
ModuleNotFoundError: No module named 'dependencies.peft'

[rank0]:[W310 09:07:20.906514174 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ echo -ne '-------------------- Finished executing script --------------------\n\n'
-------------------- Finished executing script --------------------

dorpxam commented Mar 10, 2025

In addition to the previous message, precomputation now seems very fast: 1 minute 34 seconds versus more than 25 minutes before.

But right after, I got this error message:

ERROR:finetrainers:Error during training: mat1 and mat2 must have the same dtype, but got Float8_e4m3fn and BFloat16

It seems the optimization trick does not work directly without adapting the dtype. I will see if I can find a solution here.
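For what it's worth, the mismatch is easy to reproduce outside the trainer. A minimal PyTorch sketch (not finetrainers code), assuming fp8-stored weights and bf16 activations as with layerwise upcasting:

import torch
import torch.nn.functional as F

weight = torch.randn(64, 64).to(torch.float8_e4m3fn)  # weight stored in float8_e4m3fn
x = torch.randn(2, 64, dtype=torch.bfloat16)           # bfloat16 activations

# F.linear(x, weight)  # raises the same dtype-mismatch RuntimeError
out = F.linear(x, weight.to(torch.bfloat16))            # upcasting the weight before the matmul works

So storing layers in fp8 only works if something (like the layerwise-casting hook) upcasts the weights back to the compute dtype right before the forward call.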

EDIT:
Here is the full log of the error:

ERROR:finetrainers:Traceback (most recent call last):
  File "/home/dorpxam/ai/finetrainers/train.py", line 70, in main
    trainer.run()
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 97, in run
    raise e
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 92, in run
    self._train()
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 470, in _train
    pred, target, sigmas = self.model_specification.forward(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/ai/finetrainers/finetrainers/models/wan/base_specification.py", line 316, in forward
    pred = transformer(
           ^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py", line 423, in forward
    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
                                                                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py", line 156, in forward
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/models/embeddings.py", line 1305, in forward
    sample = self.linear_1(sample)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/hooks/hooks.py", line 148, in new_forward
    output = function_reference.forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got Float8_e4m3fn and BFloat16

EDIT2:
With --transformer_dtype float8_e4m3fn added to the bash script, the RuntimeError is slightly different but it still crashes while calling the Wan transformer pipeline.

Here is the new log:

ERROR:finetrainers:Traceback (most recent call last):
  File "/home/dorpxam/ai/finetrainers/train.py", line 70, in main
    trainer.run()
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 97, in run
    raise e
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 92, in run
    self._train()
  File "/home/dorpxam/ai/finetrainers/finetrainers/trainer/sft_trainer/trainer.py", line 470, in _train
    pred, target, sigmas = self.model_specification.forward(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/ai/finetrainers/finetrainers/models/wan/base_specification.py", line 316, in forward
    pred = transformer(
           ^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/diffusers/models/transformers/transformer_wan.py", line 420, in forward
    hidden_states = self.patch_embedding(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 725, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dorpxam/anaconda3/envs/finetrainers/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 720, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
RuntimeError: Input type (c10::BFloat16) and bias type (c10::Float8_e4m3fn) should be the same
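The same kind of mismatch can be reproduced with a standalone Conv3d whose parameters are stored in fp8; a minimal sketch, not finetrainers code:

import torch

conv = torch.nn.Conv3d(4, 8, kernel_size=1).to(torch.float8_e4m3fn)  # parameters stored in float8_e4m3fn
x = torch.randn(1, 4, 2, 8, 8, dtype=torch.bfloat16)                  # bfloat16 input

# conv(x)  # raises a dtype-mismatch RuntimeError like the one above
out = conv.to(torch.float32)(x.float())  # upcasting both sides to one dtype avoids it

So passing --transformer_dtype float8_e4m3fn on its own does not seem to be enough; the input and the parameters still have to end up in the same dtype at call time.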

EDIT3:
Without precomputation, I get the same RuntimeError as in EDIT2.

a-r-r-o-w linked a pull request Mar 10, 2025 that will close this issue