Fix DDP checkpoint #1415

Merged
9 commits merged on Aug 30, 2023
@@ -5,6 +5,7 @@
from datetime import datetime

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import execute_and_distribute_from_master


try:
@@ -16,6 +17,7 @@
logger = get_logger(__name__)


@execute_and_distribute_from_master
def generate_run_id() -> str:
"""Generate a unique run ID based on the current timestamp.

@@ -35,11 +37,12 @@ def is_run_dir(dirname: str) -> bool:

def get_latest_run_id(experiment_name: str, checkpoints_root_dir: Optional[str] = None) -> Optional[str]:
"""
:param experiment_name: Name of the experiment.
:param checkpoints_root_dir: Path to the directory where all the experiments are organised, each sub-folder representing a specific experiment.
:param experiment_name: Name of the experiment.
:param checkpoints_root_dir: Path to the directory where all the experiments are organised, each sub-folder representing a specific experiment.
If None, SG will first check if a package named 'checkpoints' exists.
If not, SG will look for the root of the project that includes the script that was launched.
If not found, raise an error.
:return: Latest valid run ID, in the format "RUN_<year>"
"""
experiment_dir = get_experiment_dir_path(checkpoints_root_dir=checkpoints_root_dir, experiment_name=experiment_name)

@@ -51,7 +54,7 @@ def get_latest_run_id(experiment_name: str, checkpoints_root_dir: Optional[str]
f"Trying to load the n-1 most recent run..."
)
else:
return run_dir
return os.path.basename(run_dir)


def validate_run_id(run_id: str, experiment_name: str, ckpt_root_dir: Optional[str] = None):
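
Note: with the change above, get_latest_run_id returns just the run directory name instead of a full path, so callers resolve the path themselves. A minimal usage sketch of the intended call pattern; the experiment name and checkpoint root below are placeholders:

# Hypothetical sketch, assuming an experiment named "my_experiment" under /data/checkpoints.
import os

run_id = get_latest_run_id(experiment_name="my_experiment", checkpoints_root_dir="/data/checkpoints")
if run_id is not None:
    # run_id is now a bare directory name (e.g. "RUN_<timestamp>"), not an absolute path,
    # so it can be joined back onto any experiment directory.
    experiment_dir = get_experiment_dir_path(checkpoints_root_dir="/data/checkpoints", experiment_name="my_experiment")
    latest_run_dir = os.path.join(experiment_dir, run_id)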
94 changes: 93 additions & 1 deletion src/super_gradients/common/environment/ddp_utils.py
@@ -1,6 +1,10 @@
import os
import socket
from functools import wraps
import os
from typing import Any, List, Callable

import torch
import torch.distributed as dist

from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
@@ -77,3 +81,91 @@ def find_free_port() -> int:
sock.bind(("", 0))
_ip, port = sock.getsockname()
return port


def get_local_rank():
"""
Returns the local rank if running in DDP, and 0 otherwise
:return: local rank
"""
return dist.get_rank() if dist.is_initialized() else 0


def require_ddp_setup() -> bool:
from super_gradients.common import MultiGPUMode

return device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and device_config.assigned_rank != get_local_rank()


def is_ddp_subprocess():
return torch.distributed.get_rank() > 0 if dist.is_initialized() else False


def get_world_size() -> int:
"""
Returns the world size if running in DDP, and 1 otherwise
:return: world size
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()


def get_device_ids() -> List[int]:
return list(range(get_world_size()))


def count_used_devices() -> int:
return len(get_device_ids())


def execute_and_distribute_from_master(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator to execute a function on the master process and distribute the result to all other processes.
Useful in parallel computing scenarios where a computational task needs to be performed only on the master
node (e.g., a computational-heavy calculation), and the result must be shared with other nodes without
redundant computation.

Example usage:
>>> @execute_and_distribute_from_master
>>> def some_code_to_run(param1, param2):
>>> return param1 + param2

The wrapped function will only be executed on the master node, and the result will be propagated to all
other nodes.

:param func: The function to be executed on the master process and whose result is to be distributed.
:return: A wrapper function that encapsulates the execute-and-distribute logic.
"""

@wraps(func)
def wrapper(*args, **kwargs):
# Run the function only if it's the master process
if device_config.assigned_rank <= 0:
result = func(*args, **kwargs)
else:
result = None

# Broadcast the result from the master process to all nodes
return broadcast_from_master(result)

return wrapper


def broadcast_from_master(data: Any) -> Any:
"""
Broadcast data from master node to all other nodes. This may be required when you
want to compute something only on master node (e.g computational-heavy metric) and
don't want to waste CPU of other nodes doing the same work simultaneously.

:param data: Data to be broadcasted from master node (rank 0)
:return: Data from rank 0 node
"""
world_size = get_world_size()
if world_size == 1:
return data
broadcast_list = [data] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(broadcast_list, src=0)
return broadcast_list[0]
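
Taken together, execute_and_distribute_from_master and broadcast_from_master let a value be computed once on rank 0 and shared with every process. A minimal sketch, assuming a DDP process group is already initialized; the function and values below are made up for illustration:

# Compute once on the master process, receive the same result on every rank.
@execute_and_distribute_from_master
def pick_run_seed(low: int, high: int) -> int:
    import random
    return random.randint(low, high)  # executed only where assigned_rank <= 0

seed = pick_run_seed(0, 10_000)  # identical value on all ranks

# broadcast_from_master can also be used directly; inputs on non-master ranks are ignored.
stats = broadcast_from_master({"mAP": 0.42} if get_local_rank() == 0 else None)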
3 changes: 2 additions & 1 deletion src/super_gradients/common/environment/omegaconf_utils.py
@@ -4,7 +4,6 @@

from omegaconf import OmegaConf, DictConfig

from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
from hydra.experimental.callback import Callback


@@ -72,6 +71,8 @@ def get_cls(cls_path: str):


def hydra_output_dir_resolver(ckpt_root_dir: str, experiment_name: str) -> str:
from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path

return get_checkpoints_dir_path(experiment_name=experiment_name, ckpt_root_dir=ckpt_root_dir)


3 changes: 2 additions & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -133,7 +133,8 @@ def _setup_dir(self):
# Only if it exists, i.e. if hydra was used.
if os.path.exists(source_hydra_path):
destination_hydra_path = os.path.join(self._local_dir, ".hydra")
shutil.copytree(source_hydra_path, destination_hydra_path, dirs_exist_ok=True)
if not os.path.exists(destination_hydra_path):
shutil.copytree(source_hydra_path, destination_hydra_path)

@multi_process_safe
def _init_log_file(self):
2 changes: 1 addition & 1 deletion src/super_gradients/training/dataloaders/dataloaders.py
@@ -37,8 +37,8 @@
from super_gradients.training.utils import get_param
from super_gradients.training.utils.distributed_training_utils import (
wait_for_the_master,
get_local_rank,
)
from super_gradients.common.environment.ddp_utils import get_local_rank
from super_gradients.training.utils.utils import override_default_params_without_nones
from super_gradients.common.environment.cfg_utils import load_dataset_params
import torch.distributed as dist
2 changes: 1 addition & 1 deletion src/super_gradients/training/datasets/datasets_utils.py
@@ -23,7 +23,7 @@
from super_gradients.common.registry.registry import register_collate_function, register_callback, register_transform
from super_gradients.training.datasets.auto_augment import rand_augment_transform
from super_gradients.training.utils.detection_utils import DetectionVisualization, Anchors
from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
from super_gradients.common.environment.ddp_utils import get_local_rank, get_world_size
from super_gradients.training.utils.utils import AverageMeter


4 changes: 1 addition & 3 deletions src/super_gradients/training/losses/ppyolo_loss.py
@@ -10,9 +10,7 @@
from super_gradients.common.registry.registry import register_loss
from super_gradients.training.datasets.data_formats.bbox_formats.cxcywh import cxcywh_to_xyxy
from super_gradients.training.utils.bbox_utils import batch_distance2bbox
from super_gradients.training.utils.distributed_training_utils import (
get_world_size,
)
from super_gradients.common.environment.ddp_utils import get_world_size


def batch_iou_similarity(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-9) -> float:
3 changes: 2 additions & 1 deletion src/super_gradients/training/metrics/detection_metrics.py
@@ -6,6 +6,7 @@
from torchmetrics import Metric

import super_gradients
import super_gradients.common.environment.ddp_utils
from super_gradients.common.object_names import Metrics
from super_gradients.common.registry.registry import register_metric
from super_gradients.training.utils import tensor_container_to_device
@@ -222,7 +223,7 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None):
:return:
"""
if self.world_size is None:
self.world_size = torch.distributed.get_world_size() if self.is_distributed else -1
self.world_size = super_gradients.common.environment.ddp_utils.get_world_size() if self.is_distributed else -1
if self.rank is None:
self.rank = torch.distributed.get_rank() if self.is_distributed else -1

@@ -6,6 +6,7 @@
from torch import Tensor
from torchmetrics import Metric

import super_gradients.common.environment.ddp_utils
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import is_distributed
from super_gradients.common.object_names import Metrics
@@ -258,7 +259,7 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None):
:return:
"""
if self.world_size is None:
self.world_size = torch.distributed.get_world_size() if self.is_distributed else -1
self.world_size = super_gradients.common.environment.ddp_utils.get_world_size() if self.is_distributed else -1
if self.rank is None:
self.rank = torch.distributed.get_rank() if self.is_distributed else -1

@@ -13,7 +13,8 @@

__all__ = ["CSPResNetBackbone", "CSPResNetBasicBlock"]

from super_gradients.training.utils.distributed_training_utils import wait_for_the_master, get_local_rank
from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
from super_gradients.common.environment.ddp_utils import get_local_rank


class CSPResNetBasicBlock(nn.Module):
6 changes: 1 addition & 5 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -61,14 +61,10 @@
compute_precise_bn_stats,
setup_device,
get_gpu_mem_utilization,
get_world_size,
get_local_rank,
require_ddp_setup,
get_device_ids,
is_ddp_subprocess,
wait_for_the_master,
DDPNotSetupException,
)
from super_gradients.common.environment.ddp_utils import get_local_rank, require_ddp_setup, is_ddp_subprocess, get_world_size, get_device_ids
from super_gradients.training.utils.ema import ModelEMA
from super_gradients.training.utils.optimizer_utils import build_optimizer
from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
3 changes: 2 additions & 1 deletion src/super_gradients/training/utils/checkpoint_utils.py
@@ -13,7 +13,8 @@
from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.pretrained_models import MODEL_URLS
from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
from super_gradients.common.environment.ddp_utils import get_local_rank
from super_gradients.training.utils.utils import unwrap_model

try:
37 changes: 24 additions & 13 deletions src/super_gradients/training/utils/distributed_training_utils.py
@@ -14,6 +14,7 @@
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from super_gradients.common.deprecate import deprecated
from super_gradients.common.environment.ddp_utils import init_trainer
from super_gradients.common.data_types.enum import MultiGPUMode
from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
@@ -27,6 +28,14 @@
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.type_factory import TypeFactory

from super_gradients.common.environment.ddp_utils import get_local_rank as _get_local_rank
from super_gradients.common.environment.ddp_utils import is_ddp_subprocess as _is_ddp_subprocess
from super_gradients.common.environment.ddp_utils import get_world_size as _get_world_size
from super_gradients.common.environment.ddp_utils import get_device_ids as _get_device_ids
from super_gradients.common.environment.ddp_utils import count_used_devices as _count_used_devices
from super_gradients.common.environment.ddp_utils import require_ddp_setup as _require_ddp_setup


logger = get_logger(__name__)


@@ -145,40 +154,42 @@ def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoad
bn.momentum = momentums[i]


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_get_local_rank)
def get_local_rank():
"""
Returns the local rank if running in DDP, and 0 otherwise
:return: local rank
"""
return dist.get_rank() if dist.is_initialized() else 0


def require_ddp_setup() -> bool:
return device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and device_config.assigned_rank != get_local_rank()
return _get_local_rank()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_is_ddp_subprocess)
def is_ddp_subprocess():
return torch.distributed.get_rank() > 0 if dist.is_initialized() else False
return _is_ddp_subprocess()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_get_world_size)
def get_world_size() -> int:
"""
Returns the world size if running in DDP, and 1 otherwise
:return: world size
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
return _get_world_size()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_get_device_ids)
def get_device_ids() -> List[int]:
return list(range(get_world_size()))
return _get_device_ids()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_count_used_devices)
def count_used_devices() -> int:
return len(get_device_ids())
return _count_used_devices()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_require_ddp_setup)
def require_ddp_setup() -> bool:
return _require_ddp_setup()


@contextmanager
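
For reference, the deprecated wrappers above keep the old import path working while pointing callers at the new one. A short sketch of what users see, assuming @deprecated simply warns and forwards to its target:

# Old location (still importable, but deprecated since 3.2.1 and slated for removal in 3.5.0):
from super_gradients.training.utils.distributed_training_utils import get_world_size as old_get_world_size
# New canonical location after this PR:
from super_gradients.common.environment.ddp_utils import get_world_size

assert old_get_world_size() == get_world_size()  # the old helper just delegates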
@@ -12,7 +12,7 @@
from tqdm import tqdm

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
from super_gradients.common.environment.ddp_utils import get_local_rank, get_world_size
from torch.distributed import all_gather

from super_gradients.training.utils.utils import infer_model_device