diff --git a/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py b/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py index c7729704e4..ebd9c99f18 100644 --- a/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py +++ b/src/super_gradients/common/sg_loggers/deci_platform_sg_logger.py @@ -1,10 +1,13 @@ import os +import io +from contextlib import contextmanager from typing import Optional from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger, EXPERIMENT_LOGS_PREFIX, LOGGER_LOGS_PREFIX, CONSOLE_LOGS_PREFIX from super_gradients.common.environment.ddp_utils import multi_process_safe from super_gradients.common.plugins.deci_client import DeciClient +from contextlib import redirect_stdout logger = get_logger(__name__) @@ -91,8 +94,7 @@ def _upload_latest_file_starting_with(self, start_with: str): ] most_recent_file_path = max(files_path, key=os.path.getctime) - self.platform_client.save_experiment_file(file_path=most_recent_file_path) - logger.info(f"File saved to Deci platform: {most_recent_file_path}") + self._save_experiment_file(file_path=most_recent_file_path) @multi_process_safe def _upload_folder_files(self, folder_name: str): @@ -107,5 +109,21 @@ def _upload_folder_files(self, folder_name: str): return for file in os.listdir(folder_path): - self.platform_client.save_experiment_file(file_path=f"{folder_path}/{file}") - logger.info(f"File saved to Deci platform: {folder_path}/{file}") + self._save_experiment_file(file_path=f"{folder_path}/{file}") + + def _save_experiment_file(self, file_path: str): + with log_stdout(): # TODO: remove when platform_client remove prints from save_experiment_file + self.platform_client.save_experiment_file(file_path=file_path) + logger.info(f"File saved to Deci platform: {file_path}") + + +@contextmanager +def log_stdout(): + """Redirect stdout to DEBUG.""" + buffer = io.StringIO() + with redirect_stdout(buffer): + yield + + redirected_str = buffer.getvalue() + if redirected_str: + logger.debug(msg=redirected_str)