
Training crashing for images that used to work fine, with error: latent shape mismatch: torch.Size([16, 88, 136]) != torch.Size([16, 88, 144]) #977

Closed
a-l-e-x-d-s-9 opened this issue Sep 15, 2024 · 28 comments · Fixed by #1076
Labels
1 / 0 magic (not reliably reproducible) · bug (Something isn't working) · help wanted (Extra attention is needed)

Comments

@a-l-e-x-d-s-9

Using the latest version of SimpleTuner produces this error: "latent shape mismatch: torch.Size([16, 88, 136]) != torch.Size([16, 88, 144])"
Examples:
error_crash_latent_mismatch_01.txt
error_crash_latent_mismatch_02.txt

Both images were used in past training without issues.
I have the option in settings: "--delete_problematic_images": true
Tested on two separate servers, with a clean cache for each.
multidatabackend:
s01_multidatabackend.json
Attaching images that crashed so far, zip password is "password":
images_crashed.zip

Interestingly, I just had a crash on step 607, with a checkpoint made on step 600. I resumed training from the last checkpoint to see if it would crash on the same step, but training passed step 607 without crashing.
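For context, the two latent shapes in the error correspond to different pixel resolutions. A minimal sketch of the conversion, assuming the usual VAE spatial downsampling factor of 8 (an assumption; the exact factor depends on the model):

```python
import torch

def latent_to_pixels(latent_shape, vae_scale_factor=8):
    """Map a latent shape (C, H, W) back to approximate pixel dimensions."""
    _, h, w = latent_shape
    return (h * vae_scale_factor, w * vae_scale_factor)

# The two shapes from the error message above:
print(latent_to_pixels(torch.Size([16, 88, 136])))  # (704, 1088)
print(latent_to_pixels(torch.Size([16, 88, 144])))  # (704, 1152)
```

Under that assumption, the cached latent and the batch's expected size disagree by 64 pixels along one axis, i.e. one bucket step.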

@a-l-e-x-d-s-9
Author

config.json:
config.json

@a-l-e-x-d-s-9
Author

I've changed the multidatabackend.json to use pixel_area:
s01_multidatabackend.json
I also deleted all the old cache, and the training re-cached everything.
I'm getting an error on a different image now:
error_crash_latent_mismatch_03.txt
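For reference, a pixel_area-style resolution derives the target size from a total pixel area rather than a fixed edge length. A rough sketch of that idea (illustrative only, not SimpleTuner's exact calculation):

```python
import math

def target_size_from_area(width, height, megapixels=1.0, step=8):
    """Scale an image to roughly `megapixels` total area, preserving aspect
    ratio and snapping each edge to a multiple of `step` (illustrative only)."""
    scale = math.sqrt(megapixels * 1_000_000 / (width * height))
    new_w = round(width * scale / step) * step
    new_h = round(height * scale / step) * step
    return new_w, new_h

print(target_size_from_area(3000, 2000, megapixels=1.0))  # (1224, 816)
```

The snapping to a multiple of 8 here is only to keep the target compatible with the VAE's downsampling.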

@bghira
Owner

bghira commented Sep 16, 2024

you should disable disable-bucket-pruning and see

@a-l-e-x-d-s-9
Author

I assume that the code generating the cache and the code verifying the cache differ in how they round and calculate sizes, which is causing the mismatch. From what I see in the code related to disable_bucket_pruning, it just skips removing images; it doesn't use a different calculation for the image sizes. Maybe not using disable_bucket_pruning can hide the issue, but I don't think it's part of the problem or the solution.
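To illustrate the kind of divergence I mean (a hypothetical sketch, not SimpleTuner's actual code): if one path floors a target edge to a multiple of 8 while the other rounds to the nearest multiple, the same image can end up with two different latent sizes.

```python
def snap_floor(edge_px, step=8):
    # hypothetical rounding convention A: snap down to a multiple of `step`
    return (edge_px // step) * step

def snap_nearest(edge_px, step=8):
    # hypothetical rounding convention B: snap to the nearest multiple of `step`
    return round(edge_px / step) * step

edge = 1148  # a resized long edge that falls between two multiples of 8
print(snap_floor(edge), snap_nearest(edge))  # 1144 vs 1152 -> latent widths 143 vs 144
```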

@bghira
Owner

bghira commented Sep 16, 2024

if it goes away then i'll know where to look, otherwise i have to assume this is a problem with just your setup and can't do anything about it. it's up to you how to proceed

@bghira
Owner

bghira commented Sep 16, 2024

there's no difference in rounding between cache generation and cache loading. the actual size of the cache element is checked against other cache entries in the batch.

if you really want to just keep training with all other settings the same, use a batch size of 1 with gradient accumulation steps > 1 to emulate larger batch sizes with the appropriate slowdown?
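For anyone trying that workaround, the effective batch size is just the product of the two settings; a small sketch (the variable names mirror common config options, but check your own config keys):

```python
# Emulating a larger batch with gradient accumulation: the optimizer still
# "sees" train_batch_size * gradient_accumulation_steps samples per update.
train_batch_size = 1
gradient_accumulation_steps = 4

effective_batch_size = train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 4
```

With a per-device batch size of 1 there is only one latent per collated batch, so there is nothing for the cross-entry shape check to compare against.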

@a-l-e-x-d-s-9
Author

Ok, I'm testing now with disable_bucket_pruning.

@a-l-e-x-d-s-9
Author

I tested the training with disable_bucket_pruning=false, and I used repeats=3 with the dataset.
The training crashed after 227 steps with an error:

 Epoch 1/1, Steps:   1%|▏                                | 227/30000 [45:30<99:29:22, 12.03s/it, lr=1.67e-5, step_loss=0.0805]████████████████████████████████▋                        | 5/6 [01:29<00:17, 17.90s/it]
 (id=all_dataset_512) File /workspace/input/dataset/npa/66546318_003_043d.jpg latent shape mismatch: torch.Size([16, 56, 80]) != torch.Size([16, 48, 80])
 Traceback (most recent call last):
   File "/workspace/SimpleTuner/train.py", line 49, in <module>
     trainer.train()
   File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1612, in train
     batch = iterator_fn(step, *iterator_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1269, in random_dataloader_iterator
     return next(chosen_iter)
            ^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
     data = self._next_data()
            ^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
     return self.collate_fn(data)
            ^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 882, in <lambda>
     collate_fn=lambda examples: collate_fn(examples),
                                 ^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
     latent_batch = check_latent_shapes(
                    ^^^^^^^^^^^^^^^^^^^^
   File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
     raise ValueError(
 ValueError: (id=all_dataset_512) File /workspace/input/dataset/npa/66546318_003_043d.jpg latent shape mismatch: torch.Size([16, 56, 80]) != torch.Size([16, 48, 80])

The exact settings to reproduce this issue: tr_01
Dataset to reproduce the issue.
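For readers following the traceback: the failure is raised at collate time, when the latents of a batch are compared against each other. A minimal sketch of what such a check can look like (an illustration consistent with the error message, not the actual check_latent_shapes implementation):

```python
import torch

def check_latent_shapes(latents, filepaths, data_backend_id):
    """Raise if any latent in the batch differs from the first latent's shape."""
    expected = latents[0].shape
    for latent, path in zip(latents, filepaths):
        if latent.shape != expected:
            raise ValueError(
                f"(id={data_backend_id}) File {path} latent shape mismatch: "
                f"{latent.shape} != {expected}"
            )
    return torch.stack(latents)

# One cached latent written for a different bucket poisons the whole batch:
batch = [torch.zeros(16, 48, 80), torch.zeros(16, 56, 80)]
check_latent_shapes(batch, ["a.jpg", "b.jpg"], "all_dataset_512")  # raises ValueError
```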

@bghira bghira added bug Something isn't working help wanted Extra attention is needed 1 / 0 magic not reliably reproducible labels Sep 19, 2024
@a-l-e-x-d-s-9
Author

The attached dataset and settings reproduce the issue very reliably within a few hundred steps. I don't think the "1 / 0 magic" label is needed.

@bghira
Owner

bghira commented Sep 22, 2024

i still can't reproduce it at all on mac or linux, hence not having a solution yet. i set up your dataset in combination with 12 other datasets containing roughly 3800 to 13000 images in each, plus one with 576,000 images in it and there's no problems locally.

you will probably have to enable SIMPLETUNER_LOG_LEVEL=DEBUG in config.env and reproduce it with the details in debug.log and review the contents to determine why the sizes are going AWOL.

@a-l-e-x-d-s-9
Author

a-l-e-x-d-s-9 commented Sep 22, 2024

Using additional datasets might hide the issue, because you will have more buckets with more images, and a smaller chance of hitting the buckets from my dataset that cause the crash.
The dataset that I provided always crashes - and I have tested it on a lot of different servers multiple times.
I think that if you use my settings for multidatabackend + config + dataset - without extras - you will be able to reproduce the problem after a few hundred steps.
I can run it again with debug enabled, but it might also be related to timing - I had another issue with crashes that simply disappeared when I enabled debugging. I'm not sure how helpful the debug log will actually be for you, but I will do it if it can help. I have wanted to train with this dataset for a while, and it crashes every time, with a different image each time - images that used to work with slightly different multidatabackend configurations and fewer images. And I don't have any idea for a workaround.

@bghira
Owner

bghira commented Sep 22, 2024

it is just a dataset of images. there are not two images with the same name and different sizes. having the one dataset didn't cause the error either. i do in fact train quite frequently with just one dataset, and your configuration here has two datasets.

@a-l-e-x-d-s-9
Author

My configuration trains two resolutions on the same dataset that I attached in the dataset file.
Is there something wrong with my settings that doesn't work for you?

@bghira
Owner

bghira commented Sep 22, 2024

no, just pointing out that when you mention using additional datasets would somehow hide the issue, your config has the two already. not sure what you meant by hiding the issue with more buckets - a dataset is independent from the other datasets. i didn't run the accelerator, so i am able to get through 6000 steps quite quickly. it just validates shapes, and does not train. everything else is the same.

@a-l-e-x-d-s-9
Author

You said "dataset in combination with 12 other datasets".
I assumed you had literally combined my dataset with 12 other datasets - which would lower the chance of my dataset showing the error.
How can you turn off the accelerator and just test the rest?
Is it possible it's related to cropping? I have it off, but maybe...

@a-l-e-x-d-s-9
Author

I started a new training run with a smaller dataset and got this error:

Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/face/479ea11b-721e-4c8a-aad1-33ad3f56e1d2.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 144, 112])

Debug log:
debug.log
Settings:
s01_config_01.json
s01_multidatabackend.json

@a-l-e-x-d-s-9
Author

Here is a debug log - I deleted the cache first:
debug2.log
It crashed after 10 steps this time.

@a-l-e-x-d-s-9
Author

Error:

/dataset/s19/72a10a68-c6ad-45d6-a5d6-4ba1ae60bbb7.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/s19/72a10a68-c6ad-45d6-a5d6-4ba1ae60bbb7.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])

@a-l-e-x-d-s-9
Author

A new run with: "--aspect_bucket_rounding": 2,
Error:

Epoch 1/13, Steps:   0%|                                    | 3/3500 [01:09<19:37:19, 20.20s/it, lr=0.00496, step_loss=0.488](id=all_dataset_768) File /workspace/input/s01_sandwich_master/dataset/s09/6e9279eb-dc17-4a4d-b20f-694479a5445e.jpeg latent shape mismatch: torch.Size([16, 128, 72]) != torch.Size([16, 104, 88])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_768) File /workspace/input/s01_sandwich_master/dataset/s09/6e9279eb-dc17-4a4d-b20f-694479a5445e.jpeg latent shape mismatch: torch.Size([16, 128, 72]) != torch.Size([16, 104, 88])

Log:
debug3.log
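As a side note on what --aspect_bucket_rounding controls: it is the precision to which aspect ratios are rounded before bucket assignment. A hedged sketch (illustrative only, not the project's actual bucketing code) of how two different precisions can place the same image in different buckets:

```python
def bucket_key(width, height, decimals):
    # hypothetical: group images by their rounded aspect ratio
    return round(width / height, decimals)

w, h = 1080, 1616  # an image whose aspect ratio falls between bucket boundaries
print(bucket_key(w, h, 2))  # 0.67
print(bucket_key(w, h, 3))  # 0.668
```

If the cache were written under one precision and the aspect metadata built under another, a bucket's expected target size and the cached latent could disagree.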

@a-l-e-x-d-s-9
Author

I uploaded the dataset.

@a-l-e-x-d-s-9
Author

I added: "--debug_aspect_buckets": true
Error:

Epoch 1/13, Steps:   0%|                                                    | 0/3500 [00:04<?, ?it/s, lr=0, step_loss=0.0107](id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/s09/a33c672e-0a27-4528-a12b-9173f77d55ec.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1863, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1279, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 888, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_1024) File /workspace/input/s01_sandwich_master/dataset/s09/a33c672e-0a27-4528-a12b-9173f77d55ec.jpeg latent shape mismatch: torch.Size([16, 168, 96]) != torch.Size([16, 136, 120])

Log:
debug4.log

@a-l-e-x-d-s-9
Author

My settings:
s01_config_01.json
s01_multidatabackend.json

@a-l-e-x-d-s-9
Author

I tried to run with crop enabled, and it crashed anyway with this error:

sandwiches/French Dipped Sandwiches.jpeg latent shape mismatch: torch.Size([16, 128, 96]) != torch.Size([16, 152, 80])
Traceback (most recent call last):
  File "/workspace/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/workspace/SimpleTuner/helpers/training/trainer.py", line 1885, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 1289, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/data_backend/factory.py", line 898, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 413, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/SimpleTuner/helpers/training/collate.py", line 356, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_896) File /workspace/input/s01_sandwich_master/dataset/sandwiches/French Dipped Sandwiches.jpeg latent shape mismatch: torch.Size([16, 128, 96]) != torch.Size([16, 152, 80])

Here is the s01_multidatabackend.json I used.
Unfortunately, enabling crop does not help with this issue.
This problem reproduces very consistently in the first steps of training with a small dataset.

@a-l-e-x-d-s-9
Author

a-l-e-x-d-s-9 commented Oct 17, 2024

I updated and simplified the settings, converted to SDXL, and tried again with a reduced dataset that has only 41 images. Here is the dataset (v4):
dataset_v4.zip

I'm attaching all the files I used for training, with all the settings, including a full debug log for each run.
Interestingly, I ran it 3 times with the same settings, but it crashed every single time on a different step with a different file, on steps 6, 7, and 14.

Here is the crash log 1 -
all_settings_1.zip

Epoch 1/182, Steps:   0%|                                   | 6/4000 [00:23<4:13:02,  3.80s/it, lr=0.00482, step_loss=0.0243](id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s10_f2922403-2628-46ee-a993-25affcebe234.jpeg latent shape mismatch: torch.Size([4, 88, 48]) != torch.Size([4, 72, 56])
Traceback (most recent call last):
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/trainer.py", line 2134, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 1307, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 908, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 471, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 413, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s10_f2922403-2628-46ee-a993-25affcebe234.jpeg latent shape mismatch: torch.Size([4, 88, 48]) != torch.Size([4, 72, 56])

Epoch 1/182, Steps:   0%|                                   | 6/4000 [00:23<4:22:54,  3.95s/it, lr=0.00482, step_loss=0.0243]

Here is the crash log 2 -
all_settings_2.zip

Epoch 1/182, Steps:   0%|                                   | 7/4000 [00:26<4:05:07,  3.68s/it, lr=0.00476, step_loss=0.0177](id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/s17/33caf417-bf48-4cb5-84fc-ec0c7df9401a.jpeg latent shape mismatch: torch.Size([4, 72, 56]) != torch.Size([4, 88, 48])
Traceback (most recent call last):
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/trainer.py", line 2134, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 1307, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 908, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 471, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 413, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_512) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/s17/33caf417-bf48-4cb5-84fc-ec0c7df9401a.jpeg latent shape mismatch: torch.Size([4, 72, 56]) != torch.Size([4, 88, 48])

Epoch 1/182, Steps:   0%|                                   | 7/4000 [00:26<4:16:20,  3.85s/it, lr=0.00476, step_loss=0.0177]

Here is the crash log 3 -
all_settings_3.zip

Epoch 1/182, Steps:   0%|                                  | 14/4000 [00:53<4:10:38,  3.77s/it, lr=0.00409, step_loss=0.0438](id=all_dataset_768) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s01_89659378-050c-459a-86f5-0bc5d2abb2d5.jpeg latent shape mismatch: torch.Size([4, 128, 72]) != torch.Size([4, 104, 88])
Traceback (most recent call last):
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/train.py", line 49, in <module>
    trainer.train()
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/trainer.py", line 2134, in train
    batch = iterator_fn(step, *iterator_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 1307, in random_dataloader_iterator
    return next(chosen_iter)
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 673, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/data_backend/factory.py", line 908, in <lambda>
    collate_fn=lambda examples: collate_fn(examples),
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 471, in collate_fn
    latent_batch = check_latent_shapes(
                   ^^^^^^^^^^^^^^^^^^^^
  File "/home/alexds9/Documents/stable_diffusion/SimpleTuner/helpers/training/collate.py", line 413, in check_latent_shapes
    raise ValueError(
ValueError: (id=all_dataset_768) File /home/alexds9/Documents/stable_diffusion/SimpleTunerLatentMismatch/Latent_mismatch_crash_02/dataset/full_body/s01_89659378-050c-459a-86f5-0bc5d2abb2d5.jpeg latent shape mismatch: torch.Size([4, 128, 72]) != torch.Size([4, 104, 88])

Epoch 1/182, Steps:   0%|                                  | 14/4000 [00:53<4:12:18,  3.80s/it, lr=0.00409, step_loss=0.0438]

@bghira
Owner

bghira commented Oct 17, 2024

ok. i am in the mood for some pain i guess after dinner. i will proverbially dig in after i literally dig in

@a-l-e-x-d-s-9
Author

Thank you!
I think it should be easy peasy for you to reproduce 😊 - with the small dataset and all the settings, you can just change the paths in the files and it should work.

@bghira
Owner

bghira commented Oct 18, 2024

fixed by #1076 locally here

@a-l-e-x-d-s-9
Author

I tested the PR with the fix on a small and a medium dataset; both finished the first epoch without crashing, so I think the fix is working.
