Skip to content

Commit

Permalink
[Flax] Fix sample batch size DreamBooth (open-mmlab#1129)
Browse files Browse the repository at this point in the history
fix sample batch size
  • Loading branch information
duongna21 authored Nov 4, 2022
1 parent bde4880 commit c62b3a2
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ def main():
logger.info(f"Number of class images to sample: {num_new_images}.")

sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)

for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
Expand Down

0 comments on commit c62b3a2

Please sign in to comment.