Skip to content

Commit

Permalink
Fix loading encoder weights trained with BYOL
Browse files Browse the repository at this point in the history
  • Loading branch information
BAHL Gaetan authored and adamjstewart committed Jun 13, 2022
1 parent cce37fb commit 3c8ee4c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchgeo/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]:
state_dict = OrderedDict(
{k.replace("model.", ""): v for k, v in state_dict.items()}
)
elif "encoder" in checkpoint["hyper_parameters"]:
name = checkpoint["hyper_parameters"]["encoder"]
elif "encoder_name" in checkpoint["hyper_parameters"]:
name = checkpoint["hyper_parameters"]["encoder_name"]
state_dict = checkpoint["state_dict"]
state_dict = OrderedDict(
{k: v for k, v in state_dict.items() if "model.encoder.model" in k}
Expand Down

0 comments on commit 3c8ee4c

Please sign in to comment.