Skip to content

Commit

Permalink
Merge pull request #29 from kohya-ss/dev
Browse files Browse the repository at this point in the history
rename another position_ids key (supports wd v1.4)
  • Loading branch information
kohya-ss authored Jan 1, 2023
2 parents 885fd9e + f192338 commit bda0e83
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions library/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,16 @@ def convert_key(key):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]

# position_idsの追加
new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
# rename or add position_ids
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
if ANOTHER_POSITION_IDS_KEY in new_sd:
# waifu diffusion v1.4
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
del new_sd[ANOTHER_POSITION_IDS_KEY]
else:
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)

new_sd["text_model.embeddings.position_ids"] = position_ids
return new_sd

# endregion
Expand Down

0 comments on commit bda0e83

Please sign in to comment.