Skip to content

Commit

Permalink
fix dit weights convert to ppdiffusers (PaddlePaddle#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
nemonameless authored Oct 18, 2024
1 parent 10692bb commit 13fb61b
Showing 1 changed file with 4 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
def main(args):
num_layers, hidden_size, patch_size, num_heads = arch_settings[args.model_name]

state_dict = paddle.load(args.model_weights)
state_dict_prefix = paddle.load(args.model_weights)
state_dict = {k.replace("transformer.", ""): v for k, v in state_dict_prefix.items()}
del state_dict_prefix

state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
Expand Down Expand Up @@ -158,7 +160,7 @@ def main(args):
pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)

if args.save:
pipeline.save_pretrained(args.checkpoint_path)
pipeline.save_pretrained(args.checkpoint_path, safe_serialization=False)


if __name__ == "__main__":
Expand Down

0 comments on commit 13fb61b

Please sign in to comment.