diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c37197fdc6..9c512ef571 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -494,7 +494,12 @@ def collate_fn(examples): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + input_ids = tokenizer.pad( + {"input_ids": input_ids}, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids batch = { "input_ids": input_ids,