@@ -354,7 +354,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
354
354
# TODO(Patrick, Suraj) - delete later
355
355
if class_name == "DummyChecker" :
356
356
library_name = "stable_diffusion"
357
- class_name = "StableDiffusionSafetyChecker "
357
+ class_name = "FlaxStableDiffusionSafetyChecker "
358
358
359
359
is_pipeline_module = hasattr (pipelines , library_name )
360
360
loaded_sub_model = None
@@ -421,16 +421,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
421
421
loaded_sub_model = cached_folder
422
422
423
423
if issubclass (class_obj , FlaxModelMixin ):
424
- # TODO(Patrick, Suraj) - Fix this as soon as Safety checker is fixed here
424
+ loaded_sub_model , loaded_params = load_method (loadable_folder , from_pt = from_pt , dtype = dtype )
425
+ params [name ] = loaded_params
426
+ elif is_transformers_available () and issubclass (class_obj , FlaxPreTrainedModel ):
427
+ # make sure we don't initialize the weights to save time
425
428
if name == "safety_checker" :
426
429
loaded_sub_model = DummyChecker ()
427
430
loaded_params = DummyChecker ()
428
- else :
429
- loaded_sub_model , loaded_params = load_method (loadable_folder , from_pt = from_pt , dtype = dtype )
430
- params [name ] = loaded_params
431
- elif is_transformers_available () and issubclass (class_obj , FlaxPreTrainedModel ):
432
- # make sure we don't initialize the weights to save time
433
- if from_pt :
431
+ elif from_pt :
434
432
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
435
433
loaded_sub_model = load_method (loadable_folder , from_pt = from_pt )
436
434
loaded_params = loaded_sub_model .params
0 commit comments