Feature/sg 708 time units #1181

Merged on Jun 18, 2023 (69 commits)

Commits
c7861b8
Added torch.compile support
BloodAxe Feb 24, 2023
7f04468
Timer
BloodAxe Feb 25, 2023
b6a70fc
Timer
BloodAxe Feb 25, 2023
ea0ff39
Targets for torch compile
BloodAxe Feb 27, 2023
9665332
Disable DDP
BloodAxe Feb 27, 2023
eedd4ac
train_dataloader_params:
BloodAxe Feb 27, 2023
67b4cd1
Fix Timer callback to log events per global step
BloodAxe Feb 27, 2023
a9af29c
load_backbone: False
BloodAxe Feb 27, 2023
116b0a0
Detection models
BloodAxe Feb 27, 2023
81033d6
Lower LR
BloodAxe Feb 27, 2023
fe86a85
Added notes
BloodAxe Feb 27, 2023
f151b17
Added per-epoch timers
BloodAxe Feb 27, 2023
bd20e04
Merge remote-tracking branch 'origin/feature/SG-686-torch-compile' in…
BloodAxe Feb 27, 2023
1c54daf
Fix wrong nesting of drop_last
BloodAxe Feb 27, 2023
a5b5dd8
Merge branch 'master' into feature/SG-686-torch-compile
BloodAxe Feb 27, 2023
3fe6c78
Fixes to logging
BloodAxe Feb 27, 2023
fc35853
Log values per step/epoch explicitly
BloodAxe Feb 27, 2023
c8096cf
Merge remote-tracking branch 'origin/feature/SG-686-torch-compile' in…
BloodAxe Feb 27, 2023
3334cdd
Fixes to logging
BloodAxe Feb 27, 2023
8a178ce
Merge remote-tracking branch 'origin/feature/SG-686-torch-compile' in…
BloodAxe Feb 27, 2023
75ff37e
Fixes to logging
BloodAxe Feb 27, 2023
e586b1f
Increase num epochs
BloodAxe Feb 27, 2023
7126094
Update numbers
BloodAxe Feb 27, 2023
84455b2
Added epoch_total_time_sec
BloodAxe Mar 2, 2023
6453c4b
cityscapes_stdc_seg50
BloodAxe Mar 2, 2023
527ed10
load_backbone: False
BloodAxe Mar 2, 2023
7f491e3
imagenet_regnetY
BloodAxe Mar 2, 2023
e13cba0
imagenet_regnetY
BloodAxe Mar 2, 2023
440d9eb
cityscapes_stdc_seg50 with different compilation modes
BloodAxe Mar 6, 2023
b5224ab
cityscapes_stdc_seg50
BloodAxe Mar 6, 2023
83f0fe7
cityscapes_ddrnet
BloodAxe Mar 6, 2023
cafd7f3
Update makefile targets
BloodAxe Mar 6, 2023
f8a517e
Ensure we log only on master
BloodAxe Mar 7, 2023
b00f173
Add sync point to ensure we've compiled model on all nodes before goi…
BloodAxe Mar 7, 2023
52a5f29
Reduce bs
BloodAxe Mar 7, 2023
e488ed6
Merge master
BloodAxe Jun 14, 2023
d023671
Adding makefile targets
BloodAxe Jun 14, 2023
2d63cb3
Yolo Nas configs
BloodAxe Jun 14, 2023
348d302
Yolo Nas configs
BloodAxe Jun 14, 2023
f98aa1d
Add timer
BloodAxe Jun 14, 2023
344b39a
Add timer
BloodAxe Jun 14, 2023
56ec3ae
segmentation_compile_tests
BloodAxe Jun 14, 2023
085122f
segmentation_compile_tests
BloodAxe Jun 14, 2023
06040d0
segmentation_compile_tests
BloodAxe Jun 14, 2023
6479910
Call to torch.compile after we set up DDP
BloodAxe Jun 15, 2023
26a974c
cityscapes_ddrnet_test
BloodAxe Jun 15, 2023
8a0e09f
Omit to(device) after converting model to syncbn
BloodAxe Jun 15, 2023
c98eef6
Change default torch_compile_mode to reduce-overhead
BloodAxe Jun 15, 2023
d840675
Update makefile
BloodAxe Jun 15, 2023
800e11c
segmentation_compile_tests
BloodAxe Jun 15, 2023
fd94f80
Update makefile
BloodAxe Jun 15, 2023
3c46eef
Filling table
BloodAxe Jun 15, 2023
f3f24c9
Merge remote-tracking branch 'origin/feature/SG-686-torch-compile' in…
BloodAxe Jun 15, 2023
76f1d7c
Update makefile
BloodAxe Jun 15, 2023
83e33cd
Filling table
BloodAxe Jun 15, 2023
d053c38
Merge remote-tracking branch 'origin/feature/SG-686-torch-compile' in…
BloodAxe Jun 15, 2023
83cea64
Filling table
BloodAxe Jun 15, 2023
fe73bf2
Update makefile
BloodAxe Jun 15, 2023
723d8a0
Update makefile
BloodAxe Jun 15, 2023
6762ff3
Update makefile
BloodAxe Jun 16, 2023
d6666db
Adding time units
BloodAxe Jun 16, 2023
dd6ba51
Yolo NAS numbers
BloodAxe Jun 16, 2023
624b1ae
Merge branch 'master' into feature/SG-686-torch-compile
BloodAxe Jun 16, 2023
7188903
Yolo NAS numbers
BloodAxe Jun 16, 2023
1226e25
Adding TimerCallback and explicit TimeUnits
BloodAxe Jun 16, 2023
eb028ca
Add import of TimerCallback
BloodAxe Jun 16, 2023
99f3a9c
Fixed the potential crash if TimerCallback used for evaluate_from_recipe
BloodAxe Jun 16, 2023
49de27c
Fix missing inheritance for GlobalBatchStepNumber
BloodAxe Jun 16, 2023
10e9d85
Merge branch 'master' into feature/SG-708-time-units
BloodAxe Jun 18, 2023
1 change: 1 addition & 0 deletions src/super_gradients/common/object_names.py
@@ -146,6 +146,7 @@ class Callbacks:
DETECTION_VISUALIZATION_CALLBACK = "DetectionVisualizationCallback"
DEKR_VISUALIZATION = "DEKRVisualizationCallback"
ROBOFLOW_RESULT_CALLBACK = "RoboflowResultCallback"
TIMER = "TimerCallback"


class LRSchedulers:
3 changes: 2 additions & 1 deletion src/super_gradients/common/sg_loggers/__init__.py
@@ -3,5 +3,6 @@
from super_gradients.common.sg_loggers.deci_platform_sg_logger import DeciPlatformSGLogger
from super_gradients.common.sg_loggers.wandb_sg_logger import WandBSGLogger
from super_gradients.common.sg_loggers.dagshub_sg_logger import DagsHubSGLogger
from super_gradients.common.sg_loggers.time_units import TimeUnit, EpochNumber, GlobalBatchStepNumber

__all__ = ["BaseSGLogger", "ClearMLSGLogger", "DeciPlatformSGLogger", "WandBSGLogger", "DagsHubSGLogger"]
__all__ = ["BaseSGLogger", "ClearMLSGLogger", "DeciPlatformSGLogger", "WandBSGLogger", "DagsHubSGLogger", "TimeUnit", "EpochNumber", "GlobalBatchStepNumber"]
src/super_gradients/common/sg_loggers/abstract_sg_logger.py
@@ -5,6 +5,8 @@
from PIL import Image
import torch

from super_gradients.common.sg_loggers.time_units import TimeUnit


class AbstractSGLogger(ABC):
"""
@@ -40,7 +42,7 @@ def add_config(self, tag: str, config: dict):
raise NotImplementedError

@abstractmethod
def add_scalar(self, tag: str, scalar_value: float, global_step: int = None):
def add_scalar(self, tag: str, scalar_value: float, global_step: Union[int, TimeUnit] = None):
"""
Add scalar data to SGLogger.
Typically, this function will add scalar to tensorboard or other experiment management framework.
15 changes: 9 additions & 6 deletions src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -10,17 +10,18 @@
import torch
from PIL import Image

from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
from super_gradients.common.auto_logging.console_logging import ConsoleSink
from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
from super_gradients.common.decorators.code_save_decorator import saved_codes
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.monitoring import SystemMonitor
from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
from super_gradients.common.sg_loggers.time_units import TimeUnit
from super_gradients.training.params import TrainingParams
from super_gradients.training.utils import sg_trainer_utils, get_param
from super_gradients.common.environment.monitoring import SystemMonitor
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
from super_gradients.common.auto_logging.console_logging import ConsoleSink

logger = get_logger(__name__)

@@ -155,7 +156,9 @@ def add_config(self, tag: str, config: dict):
self._write_to_log_file(log_lines)

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: int = None):
def add_scalar(self, tag: str, scalar_value: float, global_step: Union[int, TimeUnit] = None):
if isinstance(global_step, TimeUnit):
global_step = global_step.get_value()
self.tensorboard_writer.add_scalar(tag=tag.lower().replace(" ", "_"), scalar_value=scalar_value, global_step=global_step)

@multi_process_safe
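As an aside, here is a minimal sketch of how callers can use the widened `add_scalar` signature with this change; the tag strings and scalar values are hypothetical, only the `global_step` types and the unwrapping behaviour come from this diff.

```python
from super_gradients.common.sg_loggers import EpochNumber, GlobalBatchStepNumber
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger


def log_timings(sg_logger: AbstractSGLogger, epoch: int, global_batch: int) -> None:
    # Plain int steps keep working exactly as before:
    sg_logger.add_scalar(tag="timer/epoch_total_time_sec", scalar_value=37.2, global_step=epoch)
    # A TimeUnit states what the step axis means; BaseSGLogger unwraps it with
    # get_value() before forwarding the step to the tensorboard writer:
    sg_logger.add_scalar(tag="timer/epoch_total_time_sec", scalar_value=37.2, global_step=EpochNumber(epoch))
    sg_logger.add_scalar(tag="timer/train_batch_total_time_ms", scalar_value=118.4, global_step=GlobalBatchStepNumber(global_batch))
```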
src/super_gradients/common/sg_loggers/clearml_sg_logger.py
@@ -12,6 +12,7 @@
from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.sg_loggers.time_units import TimeUnit

logger = get_logger(__name__)

@@ -114,8 +115,10 @@ def __add_scalar(self, tag: str, scalar_value: float, global_step: int):
self.clearml_logger.report_scalar(title=tag, series=tag, value=scalar_value, iteration=global_step)

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: int = 0):
def add_scalar(self, tag: str, scalar_value: float, global_step: Union[int, TimeUnit] = 0):
super(ClearMLSGLogger, self).add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
if isinstance(global_step, TimeUnit):
global_step = global_step.get_value()
self.__add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)

@multi_process_safe
src/super_gradients/common/sg_loggers/dagshub_sg_logger.py
@@ -9,6 +9,7 @@

from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.sg_loggers.time_units import TimeUnit

logger = get_logger(__name__)

@@ -177,8 +178,10 @@ def add_config(self, tag: str, config: dict):
logger.warning(f"Skip to log {k}: {v}")

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: int = 0):
def add_scalar(self, tag: str, scalar_value: float, global_step: Union[int, TimeUnit] = 0):
super(DagsHubSGLogger, self).add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
if isinstance(global_step, TimeUnit):
global_step = global_step.get_value()
mlflow.log_metric(key=tag, value=scalar_value, step=global_step)

@multi_process_safe
49 changes: 49 additions & 0 deletions src/super_gradients/common/sg_loggers/time_units.py
@@ -0,0 +1,49 @@
import abc
import dataclasses


class TimeUnit(abc.ABC):
"""
Abstract class for time units. This is used to explicitly log the time unit of a metric/loss.
"""

@abc.abstractmethod
def get_value(self):
...

@abc.abstractmethod
def get_name(self):
...


@dataclasses.dataclass
class EpochNumber(TimeUnit):
"""
A time unit for epoch number.
"""

value: float

def get_value(self):
return self.value

def get_name(self):
return "epoch"


@dataclasses.dataclass
class GlobalBatchStepNumber(TimeUnit):
"""
A time unit representing the total number of batches processed so far, including both training and validation batches.
Suppose the training loader has 320 batches and the validation loader has 80 batches.
If the current epoch index is 2 (zero-based), and we are on the validation loader at batch index 50 (zero-based),
then the global batch step is (320 + 80) * 2 + 320 + 50 = 1170.
"""

value: float

def get_value(self):
return self.value

def get_name(self):
return "global_batch_step"
20 changes: 11 additions & 9 deletions src/super_gradients/common/sg_loggers/wandb_sg_logger.py
@@ -1,18 +1,17 @@
import os

from typing import Union, Optional, Any

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.environment.env_variables import env_variables
from super_gradients.common.abstractions.abstract_logger import get_logger

from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.env_variables import env_variables
from super_gradients.common.registry.registry import register_sg_logger
from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
from super_gradients.common.sg_loggers.time_units import TimeUnit

logger = get_logger(__name__)

@@ -169,9 +168,12 @@ def add_config(self, tag: str, config: dict):
wandb.config.update(config, allow_val_change=self.resumed)

@multi_process_safe
def add_scalar(self, tag: str, scalar_value: float, global_step: int = 0):
def add_scalar(self, tag: str, scalar_value: float, global_step: Union[int, TimeUnit] = 0):
super(WandBSGLogger, self).add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
wandb.log(data={tag: scalar_value}, step=global_step)
if isinstance(global_step, TimeUnit):
wandb.log(data={tag: scalar_value, global_step.get_name(): global_step.get_value()})
else:
wandb.log(data={tag: scalar_value}, step=global_step)

@multi_process_safe
def add_scalars(self, tag_scalar_dict: dict, global_step: int = 0):
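A short note on the W&B branch above: when a TimeUnit is passed, its value is logged as a sibling field named after the unit rather than as the W&B step, presumably because W&B expects step values to increase monotonically across all logged metrics. Below is a hedged, offline-mode sketch of the two resulting calls (project name and values are hypothetical):

```python
import wandb

from super_gradients.common.sg_loggers import GlobalBatchStepNumber

run = wandb.init(project="sg-timer-demo", mode="offline")  # offline run, no W&B account needed

unit = GlobalBatchStepNumber(1250)
# Integer step: used directly as the W&B step axis (pre-existing behaviour).
wandb.log(data={"timer/train_batch_total_time_ms": 118.4}, step=1250)
# TimeUnit step: logged as an extra field keyed by the unit name (new behaviour).
wandb.log(data={"timer/train_batch_total_time_ms": 118.4, unit.get_name(): unit.get_value()})
run.finish()
```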
3 changes: 3 additions & 0 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -438,6 +438,7 @@ def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
lr_warmup_epochs=self.training_params.lr_warmup_epochs,
sg_logger=self.sg_logger,
train_loader=self.train_loader,
valid_loader=self.valid_loader,
context_methods=self._get_context_methods(Phase.TRAIN_BATCH_END),
ddp_silent_mode=self.ddp_silent_mode,
)
@@ -1835,6 +1836,8 @@ def evaluate(
device=device_config.device,
lr_warmup_epochs=lr_warmup_epochs,
sg_logger=self.sg_logger,
train_loader=self.train_loader,
valid_loader=self.valid_loader,
context_methods=self._get_context_methods(Phase.VALIDATION_BATCH_END),
)

2 changes: 2 additions & 0 deletions src/super_gradients/training/utils/callbacks/__init__.py
@@ -21,6 +21,7 @@
TrainingStageSwitchCallbackBase,
YoloXTrainingStageSwitchCallback,
TestLRCallback,
TimerCallback,
)
from super_gradients.training.utils.callbacks.ppyoloe_switch_callback import PPYoloETrainingStageSwitchCallback
from super_gradients.common.object_names import Callbacks, LRSchedulers, LRWarmups
@@ -60,4 +61,5 @@
"CallbackHandler",
"TestLRCallback",
"PPYoloETrainingStageSwitchCallback",
"TimerCallback",
]
100 changes: 100 additions & 0 deletions src/super_gradients/training/utils/callbacks/callbacks.py
@@ -17,6 +17,7 @@
from super_gradients.common.plugins.deci_client import DeciClient
from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback
from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks
from super_gradients.common.sg_loggers.time_units import GlobalBatchStepNumber, EpochNumber
from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
@@ -759,3 +760,102 @@ def __init__(self, lr_placeholder):

def __call__(self, context: PhaseContext):
self.lr_placeholder.append(context.optimizer.param_groups[0]["lr"])


@register_callback(Callbacks.TIMER)
class TimerCallback(Callback):
def __init__(self):
self.events = {}

@multi_process_safe
def on_train_loader_start(self, context: PhaseContext) -> None:
self.events["on_train_loader_start"] = cv2.getTickCount()

@multi_process_safe
def on_train_batch_start(self, context: PhaseContext) -> None:
self.events["on_train_batch_start"] = cv2.getTickCount()

@multi_process_safe
def on_train_batch_loss_end(self, context: PhaseContext) -> None:
self.events["on_train_batch_loss_end"] = cv2.getTickCount()
context.sg_logger.add_scalar(
tag="timer/train_batch_forward_with_loss_ms",
scalar_value=self._elapsed_time_between("on_train_batch_start", "on_train_batch_loss_end"),
global_step=GlobalBatchStepNumber(self._infer_global_step(context, is_train_loader=True)),
)

@multi_process_safe
def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
self.events["on_train_batch_gradient_step_start"] = cv2.getTickCount()

@multi_process_safe
def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
self.events["on_train_batch_gradient_step_end"] = cv2.getTickCount()
context.sg_logger.add_scalar(
tag="timer/train_batch_gradient_time",
scalar_value=self._elapsed_time_between("on_train_batch_gradient_step_start", "on_train_batch_gradient_step_end"),
global_step=GlobalBatchStepNumber(self._infer_global_step(context, is_train_loader=True)),
)

@multi_process_safe
def on_train_batch_end(self, context: PhaseContext) -> None:
self.events["on_train_batch_end"] = cv2.getTickCount()
context.sg_logger.add_scalar(
tag="timer/train_batch_total_time_ms",
scalar_value=self._elapsed_time_between("on_train_batch_start", "on_train_batch_end"),
global_step=GlobalBatchStepNumber(self._infer_global_step(context, is_train_loader=True)),
)

@multi_process_safe
def on_train_loader_end(self, context: PhaseContext) -> None:
self.events["on_train_loader_end"] = cv2.getTickCount()
context.sg_logger.add_scalar(
tag="timer/train_loader_total_time_ms",
scalar_value=self._elapsed_time_between("on_train_loader_start", "on_train_loader_end"),
global_step=EpochNumber(context.epoch),
)

@multi_process_safe
def on_validation_loader_start(self, context: PhaseContext) -> None:
self.events["on_validation_loader_start"] = cv2.getTickCount()

@multi_process_safe
def on_validation_batch_start(self, context: PhaseContext) -> None:
self.events["on_validation_batch_start"] = cv2.getTickCount()

@multi_process_safe
def on_validation_batch_end(self, context: PhaseContext) -> None:
self.events["on_validation_batch_end"] = cv2.getTickCount()
context.sg_logger.add_scalar(
tag="timer/validation_batch_total_time_ms",
scalar_value=self._elapsed_time_between("on_validation_batch_start", "on_validation_batch_end"),
global_step=GlobalBatchStepNumber(self._infer_global_step(context, is_train_loader=False)),
)

@multi_process_safe
def on_validation_loader_end(self, context: PhaseContext) -> None:
self.events["on_validation_loader_end"] = cv2.getTickCount()
context.sg_logger.add_scalar(
tag="timer/validation_loader_total_time_ms",
scalar_value=self._elapsed_time_between("on_validation_loader_start", "on_validation_loader_end"),
global_step=EpochNumber(context.epoch),
)

context.sg_logger.add_scalar(
tag="timer/epoch_total_time_sec",
scalar_value=self._elapsed_time_between("on_train_loader_start", "on_validation_loader_end") / 1000.0,
global_step=EpochNumber(context.epoch),
)

def _elapsed_time_between(self, start_event, end_event):
return 1000.0 * (self.events[end_event] - self.events[start_event]) / cv2.getTickFrequency()

def _infer_global_step(self, context: PhaseContext, is_train_loader: bool):
train_loader_length = len(context.train_loader) if context.train_loader is not None else 0
valid_loader_length = len(context.valid_loader) if context.valid_loader is not None else 0
total_steps_in_epoch = train_loader_length + valid_loader_length
total_steps_in_done = context.epoch * total_steps_in_epoch
if is_train_loader:
return total_steps_in_done + context.batch_idx
else:
return total_steps_in_done + train_loader_length + context.batch_idx
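Finally, a hedged sketch of wiring the new TimerCallback into a training run through `phase_callbacks`; the experiment name and the elided hyper-parameters are placeholders, only the callback class and its import path come from this PR.

```python
from super_gradients import Trainer
from super_gradients.training.utils.callbacks import TimerCallback

trainer = Trainer(experiment_name="timer_demo", ckpt_root_dir="./checkpoints")

training_params = {
    # ... usual recipe parameters: max_epochs, loss, optimizer, metrics, ...
    "phase_callbacks": [TimerCallback()],  # emits timer/* scalars with explicit EpochNumber / GlobalBatchStepNumber steps
}

# With a model and dataloaders prepared elsewhere:
# trainer.train(model=model, training_params=training_params,
#               train_loader=train_loader, valid_loader=valid_loader)
```

Because the callback is also registered under Callbacks.TIMER as "TimerCallback", recipe YAML files should be able to reference it by that name as well.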