Skip to content

Commit

Permalink
Fix padding in dreambooth (open-mmlab#1030)
Browse files Browse the repository at this point in the history
  • Loading branch information
shirayu authored Nov 2, 2022
1 parent 5cd29d6 commit 33c4874
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,12 @@ def collate_fn(examples):
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids

batch = {
"input_ids": input_ids,
Expand Down

0 comments on commit 33c4874

Please sign in to comment.