Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExtremeBatchCaseVisualizationCallback now has access to additional_batch_items #1517

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,9 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
if self.pre_prediction_callback is not None:
inputs, targets = self.pre_prediction_callback(inputs, targets, batch_idx)

context.update_context(batch_idx=batch_idx, inputs=inputs, target=targets, **additional_batch_items)
context.update_context(
batch_idx=batch_idx, inputs=inputs, target=targets, additional_batch_items=additional_batch_items, **additional_batch_items
)
self.phase_callback_handler.on_train_batch_start(context)

# AUTOCAST IS ENABLED ONLY IF self.training_params.mixed_precision - IF enabled=False AUTOCAST HAS NO EFFECT
Expand Down Expand Up @@ -2097,7 +2099,9 @@ def evaluate(
inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)

# TRIGGER PHASE CALLBACKS CORRESPONDING TO THE EVALUATION TYPE
context.update_context(batch_idx=batch_idx, inputs=inputs, target=targets, **additional_batch_items)
context.update_context(
batch_idx=batch_idx, inputs=inputs, target=targets, additional_batch_items=additional_batch_items, **additional_batch_items
)
if evaluation_type == EvaluationType.VALIDATION:
self.phase_callback_handler.on_validation_batch_start(context)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List
from typing import List, Any

from typing import Optional
import torch
Expand Down Expand Up @@ -63,6 +63,7 @@ def __init__(
valid_metrics: Optional[MetricCollection] = None, # noqa: ignore
ema_model: Optional["SgModule"] = None, # noqa: ignore
loss_logging_items_names: Optional[List[str]] = None,
additional_batch_items: Optional[Any] = None,
):
self.epoch = epoch
self.batch_idx = batch_idx
Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
self.valid_metrics = valid_metrics
self.ema_model = ema_model
self.loss_logging_items_names = loss_logging_items_names
self.additional_batch_items = additional_batch_items

def update_context(self, **kwargs):
for attr, attr_val in kwargs.items():
Expand Down
3 changes: 3 additions & 0 deletions src/super_gradients/training/utils/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,7 @@ def __init__(
self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None
self.extreme_additional_batch_items = None

self._first_call = True
self._idx_loss_tuple = None
Expand Down Expand Up @@ -1224,6 +1225,7 @@ def _on_batch_end(self, context: PhaseContext) -> None:
self.extreme_batch = tensor_container_to_device(context.inputs, device="cpu", detach=True, non_blocking=False)
self.extreme_preds = tensor_container_to_device(context.preds, device="cpu", detach=True, non_blocking=False)
self.extreme_targets = tensor_container_to_device(context.target, device="cpu", detach=True, non_blocking=False)
self.extreme_additional_batch_items = tensor_container_to_device(context.additional_batch_items, device="cpu", detach=True, non_blocking=False)

def _init_loss_attributes(self, context: PhaseContext):
if self.loss_to_monitor not in context.loss_logging_items_names:
Expand All @@ -1236,6 +1238,7 @@ def _reset(self):
self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None
self.extreme_additional_batch_items = None
if self.metric is not None:
self.metric.reset()

Expand Down