diff --git a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py index 21e9204e49..cba8a6d6f8 100644 --- a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py +++ b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py @@ -39,8 +39,8 @@ def get_modules_names( for child_name, child_module in module.named_modules(): if len(child_name) == 0: continue - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_linear = isinstance(child_module, nn.Linear) + is_conv2d = isinstance(child_module, nn.Conv2d) if (is_linear and module.__class__.__name__ in target_replace_modules_linear) or ( is_conv2d and module.__class__.__name__ in target_replace_modules_conv2d