From d572465c6ea598a94518baa900e420616bc3650f Mon Sep 17 00:00:00 2001 From: valhassan Date: Mon, 9 Dec 2024 14:23:15 -0500 Subject: [PATCH] added freeze_encoder param --- geo_deep_learning/tasks_with_models/segmentation_segformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/geo_deep_learning/tasks_with_models/segmentation_segformer.py b/geo_deep_learning/tasks_with_models/segmentation_segformer.py index dc073486..26fe4f5d 100644 --- a/geo_deep_learning/tasks_with_models/segmentation_segformer.py +++ b/geo_deep_learning/tasks_with_models/segmentation_segformer.py @@ -25,6 +25,7 @@ def __init__(self, std: List[float], data_type_max: float, loss: Callable, + freeze_encoder: bool = False, weights: str = None, class_labels: List[str] = None, class_colors: List[str] = None, @@ -38,7 +39,7 @@ def __init__(self, self.data_type_max = data_type_max self.class_colors = class_colors self.num_classes = num_classes - self.model = SegFormer(encoder, in_channels, weights, self.num_classes) + self.model = SegFormer(encoder, in_channels, weights, freeze_encoder, self.num_classes) if weights_from_checkpoint_path: print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}") checkpoint = torch.load(weights_from_checkpoint_path)