diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a504184ea2f2..c505c5a262a3 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import fnmatch import importlib import inspect @@ -811,6 +812,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + expected_types = pipeline_class._get_signature_types() passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -833,6 +835,26 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} + for key in init_dict.keys(): + if key not in passed_class_obj: + continue + if "scheduler" in key: + continue + + class_obj = passed_class_obj[key] + _expected_class_types = [] + for expected_type in expected_types[key]: + if isinstance(expected_type, enum.EnumMeta): + _expected_class_types.extend(expected_type.__members__.keys()) + else: + _expected_class_types.append(expected_type.__name__) + + _is_valid_type = class_obj.__class__.__name__ in _expected_class_types + if not _is_valid_type: + logger.warning( + f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." + ) + # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 43b01c40f5bb..423c82e0602e 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1802,6 +1802,16 @@ def test_pipe_same_device_id_offload(self): sd.maybe_free_model_hooks() assert sd._offload_gpu_id == 5 + def test_wrong_model(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + with self.assertRaises(ValueError) as error_context: + _ = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer + ) + + assert "is of type" in str(error_context.exception) + assert "but should be" in str(error_context.exception) + @slow @require_torch_gpu