Describe the bug
While trying to launch the training script, I get the following error from the torch library:
Traceback (most recent call last):
  File "train_dreambooth.py", line 861, in <module>
    main(args)
  File "train_dreambooth.py", line 701, in main
    if unet.dtype != torch.float32:
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1269, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'dtype'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 82105) of binary: /usr/bin/python3
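A minimal sketch of why this `AttributeError` can occur, assuming `unet` has already been wrapped by a data-parallel container (as `accelerator.prepare` does with `DistributedDataParallel` in a multi-process launch). `TinyModel` is a hypothetical stand-in for the diffusers UNet, which exposes `dtype` as a property; `nn.DataParallel` is used instead of `DistributedDataParallel` only so the snippet runs without initializing a process group. Both wrappers keep the original model in `.module` and do not define `dtype` themselves, so the lookup falls through to `nn.Module.__getattr__`, which only searches parameters, buffers and submodules.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a diffusers model that exposes `dtype`
    as a property rather than as a parameter, buffer or submodule."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def dtype(self) -> torch.dtype:
        # Report the dtype of the first parameter.
        return next(self.parameters()).dtype


model = TinyModel()
# The wrapper stores the original model as `wrapped.module`.
wrapped = nn.DataParallel(model)

try:
    wrapped.dtype  # DataParallel itself has no `dtype` attribute
except AttributeError as err:
    print(err)  # 'DataParallel' object has no attribute 'dtype'

# Reaching through the wrapper gives access to the inner model's property.
print(wrapped.module.dtype)  # torch.float32
```

In an Accelerate-based script, the analogous way to reach the inner model is `accelerator.unwrap_model(unet)`, so a check such as `accelerator.unwrap_model(unet).dtype != torch.float32` would avoid the wrapper; this is one possible workaround, not necessarily the fix adopted upstream.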
Reproduction
I followed the steps in the README to install the packages and ran the script below:
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="/home/ubuntu/character_images/dreambooth_example"
export OUTPUT_DIR="/home/ubuntu/character_images/model"
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=400
Logs
No response
System Info
- `diffusers` version: 0.12.0.dev0
- Platform: Linux-5.15.0-1026-aws-x86_64-with-glibc2.29
- Python version: 3.8.10
- PyTorch version (GPU?): 1.13.1+cu117 (True)
- Huggingface_hub version: 0.11.1
- Transformers version: 0.15.0
- Accelerate version: not installed
- xFormers version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
ubuntu@ip-172-31-1-217:~/diffusers/examples/dreambooth$ python3
Python 3.8.10 (default, Nov 14 2022, 12:59:47)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import diffusers
/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.14) or chardet (3.0.4) doesn't match a supported version!
warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
>>> diffusers.__version__
'0.12.0.dev0'