Skip to content

Commit

Permalink
Refactor SegmentationDOFA to enable encoder freezing and improve trai…
Browse files Browse the repository at this point in the history
…nable parameter tracking.
  • Loading branch information
valhassan committed Dec 9, 2024
1 parent ae958fd commit 22a21fe
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion geo_deep_learning/tasks_with_models/segmentation_dofa.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,22 @@ def __init__(self,
self.std = std
self.data_type_max = data_type_max
self.num_classes = num_classes
self.model = DOFASeg(encoder, pretrained, image_size, wavelengths, self.num_classes)
self.model = DOFASeg(encoder, pretrained, freeze_encoder=False,
image_size=image_size, wavelengths=wavelengths,
num_classes=self.num_classes)

# param_status = self.model.get_trainable_parameters()
# print(f"Trainable parameters: {param_status['trainable']}")
# print(f"Frozen parameters: {param_status['frozen']}")

# Count trainable vs total parameters
# total_params = sum(p.numel() for p in self.model.parameters())
# trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

# print(f"\nTotal parameters: {total_params:,}")
# print(f"Trainable parameters: {trainable_params:,}")
# print(f"Percentage trainable: {100 * trainable_params / total_params:.2f}%")

if weights_from_checkpoint_path:
print(f"Loading weights from checkpoint: {weights_from_checkpoint_path}")
checkpoint = torch.load(weights_from_checkpoint_path)
Expand Down

0 comments on commit 22a21fe

Please sign in to comment.