diff --git a/library/model_util.py b/library/model_util.py index f34530252..398b6404b 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -624,8 +624,16 @@ def convert_key(key): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # position_idsの追加 - new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids return new_sd # endregion