2 changes: 1 addition & 1 deletion mipcandy/common/optim/loss.py
@@ -40,7 +40,7 @@ def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tens
if d not in (1, 2, 3):
raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions")
labels = convert_ids_to_logits(labels.int(), d, self.num_classes)
- labels = labels.float()
+ labels = labels.to(dtype=masks.dtype)
bce = nn.functional.binary_cross_entropy_with_logits(masks, labels)
masks = masks.sigmoid()
soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth, include_background=self.include_background)
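Note: replacing the hard-coded .float() with .to(dtype=masks.dtype) presumably keeps the one-hot labels in whatever precision the mask logits arrive in (float16/bfloat16 under autocast, float32 otherwise). A minimal sketch with toy tensors, not the module's real inputs:

import torch

masks = torch.randn(2, 3, 8, 8, dtype=torch.float16)   # e.g. logits computed under autocast
labels = torch.randint(0, 2, (2, 3, 8, 8)).float()      # one-hot-style targets, float32 by default
labels = labels.to(dtype=masks.dtype)                   # follow the logits' dtype instead of forcing float32
print(masks.dtype, labels.dtype)                        # torch.float16 torch.float16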
2 changes: 1 addition & 1 deletion mipcandy/presets/segmentation.py
@@ -58,7 +58,7 @@ def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT
str, float]]:
masks = toolbox.model(images)
loss, metrics = toolbox.criterion(masks, labels)
- loss.backward()
+ self._do_backward(loss, toolbox)
return loss.item(), metrics

@override
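Note: the preset no longer calls loss.backward() directly; it routes through the trainer's _do_backward helper (added below in training.py), so loss scaling is applied in a single place when a GradScaler is active. A stand-in sketch of that delegation pattern, not mipcandy code:

import torch

class _ToyTrainer:
    def __init__(self, scaler: torch.amp.GradScaler | None = None):
        self.scaler = scaler

    def _do_backward(self, loss: torch.Tensor) -> None:
        # same branching as the _do_backward helper below, minus the toolbox indirection
        if self.scaler:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
_ToyTrainer()._do_backward(loss)                    # plain backward when no scaler is configured
print(next(model.parameters()).grad is not None)    # True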
33 changes: 28 additions & 5 deletions mipcandy/training.py
@@ -49,6 +49,7 @@ class TrainerToolbox(object):
scheduler: optim.lr_scheduler.LRScheduler
criterion: nn.Module
ema: nn.Module | None = None
+ scaler: torch.amp.GradScaler | None = None


@dataclass
@@ -86,6 +87,8 @@ def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: Trainer
torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth")
torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth")
torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth")
+ if toolbox.scaler:
+     torch.save(toolbox.scaler.state_dict(), f"{self.experiment_folder()}/scaler.pth")
torch.save(tracker, f"{self.experiment_folder()}/tracker.pt")
with open(f"{self.experiment_folder()}/training_arguments.json", "w") as f:
dump(training_arguments, f)
@@ -109,6 +112,10 @@ def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_m
toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth"))
toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth"))
toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth"))
+ scaler_path = f"{self.experiment_folder()}/scaler.pth"
+ if exists(scaler_path):
+     toolbox.scaler = torch.amp.GradScaler()
+     toolbox.scaler.load_state_dict(torch.load(scaler_path))
return toolbox

def recover_from(self, experiment_id: str) -> Self:
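Note: GradScaler carries internal state (the current scale factor and growth tracker), so the recovery path persists and restores it alongside the other components. A minimal round-trip sketch, assuming a recent PyTorch with the torch.amp.GradScaler API and a CUDA machine:

import torch

scaler = torch.amp.GradScaler()                      # defaults to the "cuda" device family
torch.save(scaler.state_dict(), "scaler.pth")        # what save_everything_for_recovery writes
restored = torch.amp.GradScaler()
restored.load_state_dict(torch.load("scaler.pth"))   # what load_toolbox reads back when the file exists
print(restored.get_scale())                          # the scale factor survives recovery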
@@ -353,6 +360,12 @@ def build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, compile_

# Training methods

+ def _do_backward(self, loss: torch.Tensor, toolbox: TrainerToolbox) -> None:
+     if toolbox.scaler:
+         toolbox.scaler.scale(loss).backward()
+     else:
+         loss.backward()

@abstractmethod
def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
str, float]]:
@@ -361,8 +374,14 @@ def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerT
def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
str, float]]:
toolbox.optimizer.zero_grad()
- loss, metrics = self.backward(images, labels, toolbox)
- toolbox.optimizer.step()
+ device_type = self._device.type if isinstance(self._device, torch.device) else str(self._device).split(":")[0]
+ with torch.amp.autocast(device_type=device_type, enabled=toolbox.scaler is not None):
+     loss, metrics = self.backward(images, labels, toolbox)
+ if toolbox.scaler:
+     toolbox.scaler.step(toolbox.optimizer)
+     toolbox.scaler.update()
+ else:
+     toolbox.optimizer.step()
toolbox.scheduler.step()
if toolbox.ema:
toolbox.ema.update_parameters(toolbox.model)
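Note: the new step follows the standard torch.amp recipe: run the forward pass under autocast, backward on the scaled loss, let the scaler unscale the gradients and skip the optimizer step if they overflowed, then adjust the scale factor. A self-contained sketch of that recipe (toy model and optimizer, CUDA assumed), not the trainer's actual code path:

import torch

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.amp.GradScaler()
x, y = torch.randn(8, 4, device="cuda"), torch.randn(8, 2, device="cuda")

optimizer.zero_grad()
with torch.amp.autocast(device_type="cuda"):
    loss = torch.nn.functional.mse_loss(model(x), y)   # forward in reduced precision
scaler.scale(loss).backward()                          # scaled backward, as in _do_backward above
scaler.step(optimizer)                                 # unscales grads, skips the step on inf/NaN
scaler.update()                                        # grows or shrinks the scale for the next batch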
@@ -389,7 +408,7 @@ def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None:
def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, compile_model: bool = True,
ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5,
val_score_prediction: bool = True, val_score_prediction_degree: int = 5, save_preview: bool = True,
- preview_quality: float = .75) -> None:
+ preview_quality: float = .75, amp: bool = False) -> None:
training_arguments = self.filter_train_params(**locals())
self.init_experiment()
if note:
@@ -414,6 +433,9 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
toolbox = (self.load_toolbox if self.recovery() else self.build_toolbox)(
num_epochs, example_shape, compile_model, ema
)
+ if amp:
+     toolbox.scaler = torch.amp.GradScaler()
+     self.log("Mixed precision training enabled")
checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth"
es_tolerance = early_stop_tolerance
self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note,
@@ -486,7 +508,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, co
def filter_train_params(**kwargs) -> dict[str, Setting]:
return {k: v for k, v in kwargs.items() if k in (
"note", "num_checkpoints", "compile_model", "ema", "seed", "early_stop_tolerance", "val_score_prediction",
"val_score_prediction_degree", "save_preview", "preview_quality"
"val_score_prediction_degree", "save_preview", "preview_quality", "amp"
)}
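Note: because "amp" is now on the allow-list, the flag is captured with the other training arguments (and therefore written out by save_everything_for_recovery). A standalone re-implementation of the filter to illustrate, not an import from mipcandy:

def filter_train_params(**kwargs):
    allowed = ("note", "num_checkpoints", "compile_model", "ema", "seed", "early_stop_tolerance",
               "val_score_prediction", "val_score_prediction_degree", "save_preview", "preview_quality", "amp")
    return {k: v for k, v in kwargs.items() if k in allowed}

print(filter_train_params(num_epochs=100, amp=True, note="amp run"))
# {'amp': True, 'note': 'amp run'}   (num_epochs is dropped, as before)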

def train_with_settings(self, num_epochs: int, **kwargs) -> None:
@@ -511,7 +533,8 @@ def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float
worst_score = float("+inf")
metrics = {}
num_cases = len(self._validation_dataloader)
- with torch.no_grad(), Progress(
+ device_type = self._device.type if isinstance(self._device, torch.device) else str(self._device).split(":")[0]
+ with torch.no_grad(), torch.amp.autocast(device_type=device_type, enabled=toolbox.scaler is not None), Progress(
*Progress.get_default_columns(), SpinnerColumn(), console=self._console
) as progress:
val_prog = progress.add_task(f"Validating", total=num_cases)
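Note: validation reuses autocast but has no scaler work to do, since no gradients are produced inside torch.no_grad(). A minimal sketch of that combination (toy model, CUDA assumed):

import torch

model = torch.nn.Linear(4, 2).cuda().eval()
images = torch.randn(8, 4, device="cuda")
with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
    preds = model(images)   # forward runs in reduced precision; no backward, so no GradScaler needed
print(preds.dtype)          # torch.float16 inside autocast on CUDA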