diff --git a/src/super_gradients/common/abstractions/abstract_logger.py b/src/super_gradients/common/abstractions/abstract_logger.py
index 18e5602f45..5ad58890e9 100755
--- a/src/super_gradients/common/abstractions/abstract_logger.py
+++ b/src/super_gradients/common/abstractions/abstract_logger.py
@@ -3,6 +3,7 @@
 import logging.config
 from typing import Union
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
 
 
@@ -11,7 +12,8 @@ def get_logger(logger_name: str, log_level: Union[str, None] = None) -> logging.
     logger: logging.Logger = logging.getLogger(logger_name)
     if log_level is not None:
         logger.setLevel(log_level)
-    if int(os.getenv("LOCAL_RANK", -1)) > 0:
+
+    if int(env_variables.LOCAL_RANK) > 0:
         mute_current_process()
     return logger
 
diff --git a/src/super_gradients/common/crash_handler/crash_tips_setup.py b/src/super_gradients/common/crash_handler/crash_tips_setup.py
index 88307674b6..0c52c603a5 100644
--- a/src/super_gradients/common/crash_handler/crash_tips_setup.py
+++ b/src/super_gradients/common/crash_handler/crash_tips_setup.py
@@ -1,6 +1,6 @@
-import os
 import atexit
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.crash_handler.exception import ExceptionInfo
 from super_gradients.common.abstractions.abstract_logger import get_logger
 
@@ -15,7 +15,7 @@ def crash_tip_handler():
 
 
 def setup_crash_tips() -> bool:
-    if os.getenv("CRASH_HANDLER", "TRUE") != "FALSE":
+    if env_variables.CRASH_HANDLER != "FALSE":
         logger.info("Crash tips is enabled. You can set your environment variable to CRASH_HANDLER=FALSE to disable it")
         atexit.register(crash_tip_handler)
         return True
diff --git a/src/super_gradients/common/crash_handler/exception_monitoring_setup.py b/src/super_gradients/common/crash_handler/exception_monitoring_setup.py
index bad4fbddc6..9bb88a6895 100644
--- a/src/super_gradients/common/crash_handler/exception_monitoring_setup.py
+++ b/src/super_gradients/common/crash_handler/exception_monitoring_setup.py
@@ -2,7 +2,7 @@
 import logging
 import atexit
 
-
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.environment.ddp_utils import multi_process_safe, is_distributed
 from super_gradients.common.crash_handler.exception import ExceptionInfo
 from super_gradients.common.auto_logging.console_logging import ConsoleSink
@@ -41,7 +41,8 @@ def exception_upload_handler(platform_client):
 def setup_pro_user_monitoring() -> bool:
     """Setup the pro user environment for error logging and monitoring"""
     if _imported_deci_lab_failure is None:
-        upload_console_logs = os.getenv("UPLOAD_LOGS", "TRUE") == "TRUE"
+
+        upload_console_logs = env_variables.UPLOAD_LOGS == "TRUE"
         if upload_console_logs:
             logger.info("deci-lab-client package detected. activating automatic log uploading")
             logger.info("If you do not have a deci-lab-client credentials or you wish to disable this feature, please set the env variable UPLOAD_LOGS=FALSE")
@@ -52,7 +53,8 @@ def setup_pro_user_monitoring() -> bool:
 
         logger.info("Connecting to the deci platform ...")
         platform_client = DeciPlatformClient()
-        platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
+
+        platform_client.login(token=env_variables.DECI_PLATFORM_TOKEN)
         logger.info("Connection to the deci platform successful!")
         atexit.register(exception_upload_handler, platform_client)
 
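A note on the pattern used above: because CRASH_HANDLER and UPLOAD_LOGS are exposed as properties, each access re-reads the process environment instead of freezing the value at import time. A minimal sketch of the difference (illustration only, not part of the patch; it assumes CRASH_HANDLER is unset before the snippet runs):

    import os

    frozen = os.getenv("CRASH_HANDLER", "TRUE")  # captured once at this point
    os.environ["CRASH_HANDLER"] = "FALSE"        # flip the toggle at runtime

    from super_gradients.common.environment.env_variables import env_variables

    assert env_variables.CRASH_HANDLER == "FALSE"  # property re-reads os.environ
    assert frozen == "TRUE"                        # the earlier capture is stale
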
diff --git a/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py b/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
index 0d578c4fdd..bf92420b53 100755
--- a/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
+++ b/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@@ -1,6 +1,8 @@
 import os
 import sys
+
 from super_gradients.common import S3Connector, explicit_params_validation
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import ILogger
 
 
@@ -33,7 +35,8 @@ def __init__(self, data_connection_location: str = "local", data_connection_cred
             self.data_connection_source = "s3"
 
         if data_connection_credentials is None:
-            data_connection_credentials = os.getenv("AWS_PROFILE")
+
+            data_connection_credentials = env_variables.AWS_PROFILE
 
         self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)
 
diff --git a/src/super_gradients/common/environment/device_utils.py b/src/super_gradients/common/environment/device_utils.py
index 3b05fc4c65..519310c4cb 100644
--- a/src/super_gradients/common/environment/device_utils.py
+++ b/src/super_gradients/common/environment/device_utils.py
@@ -1,8 +1,8 @@
-import os
 import dataclasses
 
 import torch
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.environment.argparse_utils import pop_local_rank
 
 
@@ -11,8 +11,8 @@ def _get_assigned_rank() -> int:
     """Get the rank assigned by DDP launcher. If not DDP subprocess, return -1."""
-    if os.getenv("LOCAL_RANK") is not None:
-        return int(os.getenv("LOCAL_RANK"))
+    if env_variables.LOCAL_RANK != -1:
+        return env_variables.LOCAL_RANK
     else:
         return pop_local_rank()
 
 
@@ -21,7 +21,7 @@
 class DeviceConfig:
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     multi_gpu: str = None
-    assigned_rank: str = dataclasses.field(default=_get_assigned_rank(), init=False)
+    assigned_rank: int = dataclasses.field(default=_get_assigned_rank(), init=False)
 
 
 # Singleton holding the device information
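The rewritten check in _get_assigned_rank leans on the -1 default baked into env_variables.LOCAL_RANK. Since torch's DDP launchers only assign non-negative local ranks, comparing against the sentinel is equivalent to the old `is not None` test. A sketch of the equivalence (illustration only, under that non-negative-rank assumption):

    import os

    def old_style() -> int:
        # Unset variable -> fall through to the launcher-independent path.
        if os.getenv("LOCAL_RANK") is not None:
            return int(os.getenv("LOCAL_RANK"))
        return -1

    def new_style() -> int:
        # env_variables.LOCAL_RANK evaluates int(os.getenv("LOCAL_RANK", -1)),
        # so "unset" and the sentinel -1 collapse into the same value.
        return int(os.getenv("LOCAL_RANK", -1))

    assert old_style() == new_style()  # holds whenever LOCAL_RANK is unset or >= 0
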
diff --git a/src/super_gradients/common/environment/env_variables.py b/src/super_gradients/common/environment/env_variables.py
new file mode 100644
index 0000000000..800e114062
--- /dev/null
+++ b/src/super_gradients/common/environment/env_variables.py
@@ -0,0 +1,35 @@
+import os
+
+
+class EnvironmentVariables:
+    """Class to dynamically get any environment variables."""
+
+    # Infra
+    @property
+    def DECI_PLATFORM_TOKEN(self):
+        return os.getenv("DECI_PLATFORM_TOKEN")
+
+    @property
+    def WANDB_BASE_URL(self):
+        return os.getenv("WANDB_BASE_URL")
+
+    @property
+    def AWS_PROFILE(self):
+        return os.getenv("AWS_PROFILE")
+
+    # DDP
+    @property
+    def LOCAL_RANK(self):
+        return int(os.getenv("LOCAL_RANK", -1))
+
+    # Turn ON/OFF features
+    @property
+    def CRASH_HANDLER(self):
+        return os.getenv("CRASH_HANDLER", "TRUE")
+
+    @property
+    def UPLOAD_LOGS(self):
+        return os.getenv("UPLOAD_LOGS", "TRUE")
+
+
+env_variables = EnvironmentVariables()
diff --git a/src/super_gradients/common/plugins/deci_client.py b/src/super_gradients/common/plugins/deci_client.py
index 0d1a9fdb6b..b4b84d368b 100644
--- a/src/super_gradients/common/plugins/deci_client.py
+++ b/src/super_gradients/common/plugins/deci_client.py
@@ -12,6 +12,7 @@
 from omegaconf import DictConfig
 from torch import nn
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training.utils.hydra_utils import normalize_path
 
@@ -140,7 +141,8 @@ def upload_model(self, model: nn.Module, model_meta_data, optimization_request_f
             model_meta_data: Metadata to accompany the model
             optimization_request_form: The optimization parameters
         """
-        self.lab_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
+
+        self.lab_client.login(token=env_variables.DECI_PLATFORM_TOKEN)
         self.lab_client.add_model(
             add_model_request=model_meta_data,
             optimization_request=optimization_request_form,
diff --git a/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py b/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py
index 67fbaea870..818df79d36 100644
--- a/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py
+++ b/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py
@@ -1,6 +1,7 @@
 import os
 from typing import Optional
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.environment.ddp_utils import multi_process_safe
@@ -60,7 +61,7 @@ def __init__(
         )
 
         self.platform_client = DeciPlatformClient()
-        self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
+        self.platform_client.login(token=env_variables.DECI_PLATFORM_TOKEN)
         if model_name is None:
             logger.warning(
                 "'model_name' parameter not passed. "
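Extending the new accessor later follows the same recipe: one property per variable on EnvironmentVariables, with the lookup and its default kept in a single place. A self-contained sketch of the pattern (CONSOLE_LOG_LEVEL is an invented name used only for illustration; in the repo it would simply be another property on EnvironmentVariables):

    import os

    class ExtendedEnvironmentVariables:
        @property
        def CONSOLE_LOG_LEVEL(self) -> str:
            # Hypothetical variable, shown only to illustrate the pattern:
            # the default lives next to the lookup, not at every call site.
            return os.getenv("CONSOLE_LOG_LEVEL", "INFO")

    extended = ExtendedEnvironmentVariables()
    print(extended.CONSOLE_LOG_LEVEL)  # "INFO" unless the variable is set
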
" diff --git a/src/super_gradients/common/sg_loggers/wandb_sg_logger.py b/src/super_gradients/common/sg_loggers/wandb_sg_logger.py index 8c9d90573f..5cdf9196d3 100644 --- a/src/super_gradients/common/sg_loggers/wandb_sg_logger.py +++ b/src/super_gradients/common/sg_loggers/wandb_sg_logger.py @@ -6,6 +6,8 @@ from PIL import Image import matplotlib.pyplot as plt import torch + +from super_gradients.common.environment.env_variables import env_variables from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger @@ -79,7 +81,7 @@ def __init__( ) if api_server is not None: - if api_server != os.getenv("WANDB_BASE_URL"): + if api_server != env_variables.WANDB_BASE_URL: logger.warning(f"WANDB_BASE_URL environment parameter not set to {api_server}. Setting the parameter") os.environ["WANDB_BASE_URL"] = api_server diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py index 3cfe568462..ff3e8b63d3 100644 --- a/src/super_gradients/training/utils/callbacks/callbacks.py +++ b/src/super_gradients/training/utils/callbacks/callbacks.py @@ -11,6 +11,7 @@ import onnxruntime import torch +from super_gradients.common.environment.env_variables import env_variables from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback @@ -153,7 +154,8 @@ def __init__(self, model_meta_data, optimization_request_form, ckpt_name="ckpt_b self.conversion_kwargs = kwargs self.ckpt_name = ckpt_name self.platform_client = DeciPlatformClient("api.deci.ai", 443, https=True) - self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN")) + + self.platform_client.login(token=env_variables.DECI_PLATFORM_TOKEN) @staticmethod def log_optimization_failed():