diff --git a/geo_deep_learning/tasks_with_models/segmentation_dofa.py b/geo_deep_learning/tasks_with_models/segmentation_dofa.py index 5209a17a..5735180e 100644 --- a/geo_deep_learning/tasks_with_models/segmentation_dofa.py +++ b/geo_deep_learning/tasks_with_models/segmentation_dofa.py @@ -40,7 +40,22 @@ def __init__(self, self.std = std self.data_type_max = data_type_max self.num_classes = num_classes - self.model = DOFASeg(encoder, pretrained, image_size, wavelengths, self.num_classes) + self.model = DOFASeg(encoder, pretrained, freeze_encoder=False, + image_size=image_size, wavelengths=wavelengths, + num_classes=self.num_classes) + + # param_status = self.model.get_trainable_parameters() + # print(f"Trainable parameters: {param_status['trainable']}") + # print(f"Frozen parameters: {param_status['frozen']}") + + # Count trainable vs total parameters + # total_params = sum(p.numel() for p in self.model.parameters()) + # trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + + # print(f"\nTotal parameters: {total_params:,}") + # print(f"Trainable parameters: {trainable_params:,}") + # print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%") + if weights_from_checkpoint_path: print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}") checkpoint = torch.load(weights_from_checkpoint_path)