Skip to content

Commit

Permalink
use weights_file_name from loaded model
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 3, 2024
1 parent 6fdd7ad commit 2416138
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/pytorch_ie/core/hf_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,14 +416,21 @@ def _from_pretrained(
config: Optional[dict] = None,
**model_kwargs,
) -> TModel:

config = (config or {}).copy()
config.update(model_kwargs)
if cls.config_type_key is not None:
config.pop(cls.config_type_key)
model = cls(**config)

"""Load Pytorch pretrained weights and return the loaded model."""
if os.path.isdir(model_id):
logger.info("Loading weights from local directory")
model_file = os.path.join(model_id, cls.weights_file_name)
model_file = os.path.join(model_id, model.weights_file_name)
else:
model_file = hf_hub_download(
repo_id=model_id,
filename=cls.weights_file_name,
filename=model.weights_file_name,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
Expand All @@ -433,12 +440,6 @@ def _from_pretrained(
local_files_only=local_files_only,
)

config = (config or {}).copy()
config.update(model_kwargs)
if cls.config_type_key is not None:
config.pop(cls.config_type_key)
model = cls(**config)

model.load_weights(model_file, map_location=map_location, strict=strict)

return model
Expand Down

0 comments on commit 2416138

Please sign in to comment.