diff --git a/src/pytorch_ie/core/hf_hub_mixin.py b/src/pytorch_ie/core/hf_hub_mixin.py index e485d85a..f066f7aa 100644 --- a/src/pytorch_ie/core/hf_hub_mixin.py +++ b/src/pytorch_ie/core/hf_hub_mixin.py @@ -416,14 +416,21 @@ def _from_pretrained( config: Optional[dict] = None, **model_kwargs, ) -> TModel: + + config = (config or {}).copy() + config.update(model_kwargs) + if cls.config_type_key is not None: + config.pop(cls.config_type_key) + model = cls(**config) + """Load Pytorch pretrained weights and return the loaded model.""" if os.path.isdir(model_id): logger.info("Loading weights from local directory") - model_file = os.path.join(model_id, cls.weights_file_name) + model_file = os.path.join(model_id, model.weights_file_name) else: model_file = hf_hub_download( repo_id=model_id, - filename=cls.weights_file_name, + filename=model.weights_file_name, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -433,12 +440,6 @@ def _from_pretrained( local_files_only=local_files_only, ) - config = (config or {}).copy() - config.update(model_kwargs) - if cls.config_type_key is not None: - config.pop(cls.config_type_key) - model = cls(**config) - model.load_weights(model_file, map_location=map_location, strict=strict) return model