From c62b3a2e7e9cb1665db1062d96cc13118350ad8d Mon Sep 17 00:00:00 2001 From: "Duong A. Nguyen" <38061659+duongna21@users.noreply.github.com> Date: Fri, 4 Nov 2022 19:49:57 +0700 Subject: [PATCH] [Flax] Fix sample batch size DreamBooth (#1129) fix sample batch size --- examples/dreambooth/train_dreambooth_flax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index d2652606b5..84493b1d94 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -361,7 +361,8 @@ def main(): logger.info(f"Number of class images to sample: {num_new_images}.") sample_dataset = PromptDataset(args.class_prompt, num_new_images) - sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + total_sample_batch_size = args.sample_batch_size * jax.local_device_count() + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size) for example in tqdm( sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0