Group env variables #618

Merged · 9 commits · Jan 18, 2023
@@ -3,6 +3,7 @@
 import logging.config
 from typing import Union
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig

@@ -11,7 +12,8 @@ def get_logger(logger_name: str, log_level: Union[str, None] = None) -> logging.Logger:
     logger: logging.Logger = logging.getLogger(logger_name)
     if log_level is not None:
         logger.setLevel(log_level)
-    if int(os.getenv("LOCAL_RANK", -1)) > 0:
+
+    if int(env_variables.LOCAL_RANK) > 0:
         mute_current_process()
     return logger
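
Review note: env_variables.LOCAL_RANK already returns an int with a default of -1 (see env_variables.py below), so the int(...) wrapper here is redundant; the condition could be written simply as:

    if env_variables.LOCAL_RANK > 0:
        mute_current_process()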

--------------------------------------------------------------------------------

@@ -1,6 +1,6 @@
-import os
 import atexit
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.crash_handler.exception import ExceptionInfo
 from super_gradients.common.abstractions.abstract_logger import get_logger

@@ -15,7 +15,7 @@ def crash_tip_handler():
 
 
 def setup_crash_tips() -> bool:
-    if os.getenv("CRASH_HANDLER", "TRUE") != "FALSE":
+    if env_variables.CRASH_HANDLER != "FALSE":
         logger.info("Crash tips is enabled. You can set your environment variable to CRASH_HANDLER=FALSE to disable it")
         atexit.register(crash_tip_handler)
         return True
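
For illustration, disabling crash tips from code rather than the shell (a minimal sketch; the property re-reads os.environ on every access, so the variable must be set before setup_crash_tips() is called):

    import os

    # Must happen before setup_crash_tips() runs; env_variables.CRASH_HANDLER
    # reads os.environ at access time, not at import time.
    os.environ["CRASH_HANDLER"] = "FALSE"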

--------------------------------------------------------------------------------

@@ -2,7 +2,7 @@
 import logging
 import atexit
 
-
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.environment.ddp_utils import multi_process_safe, is_distributed
 from super_gradients.common.crash_handler.exception import ExceptionInfo
 from super_gradients.common.auto_logging.console_logging import ConsoleSink

@@ -41,7 +41,8 @@ def exception_upload_handler(platform_client):
 def setup_pro_user_monitoring() -> bool:
     """Setup the pro user environment for error logging and monitoring"""
     if _imported_deci_lab_failure is None:
-        upload_console_logs = os.getenv("UPLOAD_LOGS", "TRUE") == "TRUE"
+
+        upload_console_logs = env_variables.UPLOAD_LOGS == "TRUE"
         if upload_console_logs:
             logger.info("deci-lab-client package detected. activating automatic log uploading")
             logger.info("If you do not have a deci-lab-client credentials or you wish to disable this feature, please set the env variable UPLOAD_LOGS=FALSE")

@@ -52,7 +53,8 @@ def setup_pro_user_monitoring() -> bool:
 
         logger.info("Connecting to the deci platform ...")
         platform_client = DeciPlatformClient()
-        platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
+
+        platform_client.login(token=env_variables.DECI_PLATFORM_TOKEN)
         logger.info("Connection to the deci platform successful!")
 
         atexit.register(exception_upload_handler, platform_client)

--------------------------------------------------------------------------------

@@ -1,6 +1,8 @@
 import os
 import sys
 
 from super_gradients.common import S3Connector, explicit_params_validation
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import ILogger
 
 

@@ -33,7 +35,8 @@ def __init__(self, data_connection_location: str = "local", data_connection_cred
             self.data_connection_source = "s3"
 
             if data_connection_credentials is None:
-                data_connection_credentials = os.getenv("AWS_PROFILE")
+
+                data_connection_credentials = env_variables.AWS_PROFILE
 
         self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)
8 changes: 4 additions & 4 deletions src/super_gradients/common/environment/device_utils.py
@@ -1,8 +1,8 @@
-import os
 import dataclasses
 
 import torch
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.environment.argparse_utils import pop_local_rank
 
 

@@ -11,8 +11,8 @@
 
 def _get_assigned_rank() -> int:
     """Get the rank assigned by DDP launcher. If not DDP subprocess, return -1."""
-    if os.getenv("LOCAL_RANK") is not None:
-        return int(os.getenv("LOCAL_RANK"))
+    if env_variables.LOCAL_RANK != -1:
+        return env_variables.LOCAL_RANK
     else:
         return pop_local_rank()
 

@@ -21,7 +21,7 @@ def _get_assigned_rank() -> int:
 class DeviceConfig:
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     multi_gpu: str = None
-    assigned_rank: str = dataclasses.field(default=_get_assigned_rank(), init=False)
+    assigned_rank: int = dataclasses.field(default=_get_assigned_rank(), init=False)
 
 
 # Singleton holding the device information
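
One subtle behavioral change in _get_assigned_rank: the old code treated any set LOCAL_RANK as a DDP launch, while the new code uses -1 as a sentinel, so an explicit LOCAL_RANK=-1 now falls through to pop_local_rank(). A minimal sketch of the new resolution, assuming the env_variables class introduced below:

    import os
    from super_gradients.common.environment.env_variables import env_variables

    os.environ.pop("LOCAL_RANK", None)
    assert env_variables.LOCAL_RANK == -1  # unset -> sentinel -1 -> falls back to pop_local_rank()

    os.environ["LOCAL_RANK"] = "2"
    assert env_variables.LOCAL_RANK == 2   # set by a DDP launcher -> returned directly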
35 changes: 35 additions & 0 deletions src/super_gradients/common/environment/env_variables.py
@@ -0,0 +1,35 @@
+import os
+
+
+class EnvironmentVariables:
+    """Class to dynamically get any environment variables."""
+
+    # Infra
+    @property
+    def DECI_PLATFORM_TOKEN(self):
+        return os.getenv("DECI_PLATFORM_TOKEN")
+
+    @property
+    def WANDB_BASE_URL(self):
+        return os.getenv("WANDB_BASE_URL")
+
+    @property
+    def AWS_PROFILE(self):
+        return os.getenv("AWS_PROFILE")
+
+    # DDP
+    @property
+    def LOCAL_RANK(self):
+        return int(os.getenv("LOCAL_RANK", -1))
+
+    # Turn ON/OFF features
+    @property
+    def CRASH_HANDLER(self):
+        return os.getenv("CRASH_HANDLER", "TRUE")
+
+    @property
+    def UPLOAD_LOGS(self):
+        return os.getenv("UPLOAD_LOGS", "TRUE")
+
+
+env_variables = EnvironmentVariables()
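
A quick sanity check of the dynamic-lookup design — each property re-reads os.environ on access, so values set or changed after import time are picked up, unlike a module-level constant (a sketch; the values are illustrative):

    import os
    from super_gradients.common.environment.env_variables import env_variables

    os.environ["UPLOAD_LOGS"] = "FALSE"
    assert env_variables.UPLOAD_LOGS == "FALSE"  # read at access time, not at import time

    del os.environ["UPLOAD_LOGS"]
    assert env_variables.UPLOAD_LOGS == "TRUE"   # falls back to the declared default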
4 changes: 3 additions & 1 deletion src/super_gradients/common/plugins/deci_client.py
@@ -12,6 +12,7 @@
 from omegaconf import DictConfig
 from torch import nn
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training.utils.hydra_utils import normalize_path
 

@@ -140,7 +141,8 @@ def upload_model(self, model: nn.Module, model_meta_data, optimization_request_f
             model_meta_data: Metadata to accompany the model
             optimization_request_form: The optimization parameters
         """
-        self.lab_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
+
+        self.lab_client.login(token=env_variables.DECI_PLATFORM_TOKEN)
         self.lab_client.add_model(
             add_model_request=model_meta_data,
             optimization_request=optimization_request_form,

--------------------------------------------------------------------------------

@@ -1,6 +1,7 @@
 import os
 from typing import Optional
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger
 from super_gradients.common.environment.ddp_utils import multi_process_safe

@@ -60,7 +61,7 @@ def __init__(
         )
 
         self.platform_client = DeciPlatformClient()
-        self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
+        self.platform_client.login(token=env_variables.DECI_PLATFORM_TOKEN)
         if model_name is None:
             logger.warning(
                 "'model_name' parameter not passed. "
4 changes: 3 additions & 1 deletion src/super_gradients/common/sg_loggers/wandb_sg_logger.py
@@ -6,6 +6,8 @@
 from PIL import Image
 import matplotlib.pyplot as plt
 import torch
+
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
 
 from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger

@@ -79,7 +81,7 @@ def __init__(
         )
 
         if api_server is not None:
-            if api_server != os.getenv("WANDB_BASE_URL"):
+            if api_server != env_variables.WANDB_BASE_URL:
                 logger.warning(f"WANDB_BASE_URL environment parameter not set to {api_server}. Setting the parameter")
                 os.environ["WANDB_BASE_URL"] = api_server
4 changes: 3 additions & 1 deletion src/super_gradients/training/utils/callbacks/callbacks.py
@@ -11,6 +11,7 @@
 import onnxruntime
 import torch
 
+from super_gradients.common.environment.env_variables import env_variables
 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase
 from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback

@@ -153,7 +154,8 @@ def __init__(self, model_meta_data, optimization_request_form, ckpt_name="ckpt_b
         self.conversion_kwargs = kwargs
         self.ckpt_name = ckpt_name
         self.platform_client = DeciPlatformClient("api.deci.ai", 443, https=True)
-        self.platform_client.login(token=os.getenv("DECI_PLATFORM_TOKEN"))
+
+        self.platform_client.login(token=env_variables.DECI_PLATFORM_TOKEN)
 
     @staticmethod
     def log_optimization_failed():