diff --git a/torchgeo/trainers/utils.py b/torchgeo/trainers/utils.py index 8e04815c937..7cde2b01d6f 100644 --- a/torchgeo/trainers/utils.py +++ b/torchgeo/trainers/utils.py @@ -42,8 +42,8 @@ def extract_encoder(path: str) -> Tuple[str, "OrderedDict[str, Tensor]"]: state_dict = OrderedDict( {k.replace("model.", ""): v for k, v in state_dict.items()} ) - elif "encoder" in checkpoint["hyper_parameters"]: - name = checkpoint["hyper_parameters"]["encoder"] + elif "encoder_name" in checkpoint["hyper_parameters"]: + name = checkpoint["hyper_parameters"]["encoder_name"] state_dict = checkpoint["state_dict"] state_dict = OrderedDict( {k: v for k, v in state_dict.items() if "model.encoder.model" in k}