From 4623f095f3ac00582369bb0a13d208c2072b049d Mon Sep 17 00:00:00 2001 From: "Duong A. Nguyen" <38061659+duongna21@users.noreply.github.com> Date: Thu, 27 Oct 2022 19:19:13 +0700 Subject: [PATCH] [DreamBooth] Set train mode for text encoder (#1012) Set train mode for text encoder --- examples/dreambooth/train_dreambooth.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d3720f0cad..9b7e17241d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -574,6 +574,8 @@ def collate_fn(examples): for epoch in range(args.num_train_epochs): unet.train() + if args.train_text_encoder: + text_encoder.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space