diff --git a/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py b/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
index 0021fcff56..0d578c4fdd 100755
--- a/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
+++ b/src/super_gradients/common/data_interface/adnn_model_repository_data_interface.py
@@ -13,7 +13,7 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
     ResearchModelRepositoryDataInterface
     """

-    def __init__(self, data_connection_location: str = 'local', data_connection_credentials: str = None):
+    def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
         """
         ModelCheckpointsDataInterface
         :param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
@@ -22,22 +22,22 @@ def __init__(self, data_connection_location: str = 'local', data_connection_cred
                                              AWS_PROFILE if left empty
         """
         super().__init__()
-        self.tb_events_file_prefix = 'events.out.tfevents'
-        self.log_file_prefix = 'log_'
-        self.latest_checkpoint_filename = 'ckpt_latest.pth'
-        self.best_checkpoint_filename = 'ckpt_best.pth'
+        self.tb_events_file_prefix = "events.out.tfevents"
+        self.log_file_prefix = "log_"
+        self.latest_checkpoint_filename = "ckpt_latest.pth"
+        self.best_checkpoint_filename = "ckpt_best.pth"

-        if data_connection_location.startswith('s3'):
-            assert data_connection_location.index('s3://') >= 0, 'S3 path must be formatted s3://bucket-name'
-            self.model_repo_bucket_name = data_connection_location.split('://')[1]
-            self.data_connection_source = 's3'
+        if data_connection_location.startswith("s3"):
+            assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
+            self.model_repo_bucket_name = data_connection_location.split("://")[1]
+            self.data_connection_source = "s3"

             if data_connection_credentials is None:
-                data_connection_credentials = os.getenv('AWS_PROFILE')
+                data_connection_credentials = os.getenv("AWS_PROFILE")

             self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)

-    @explicit_params_validation(validation_type='None')
+    @explicit_params_validation(validation_type="None")
     def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
         """
         load_all_remote_checkpoint_files
@@ -45,12 +45,10 @@ def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir:
         :param model_checkpoint_local_dir:
         :return:
         """
-        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
-                                       logging_type='tensorboard')
-        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
-                                       logging_type='text')
+        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
+        self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")

-    @explicit_params_validation(validation_type='None')
+    @explicit_params_validation(validation_type="None")
     def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
         """
         save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo
@@ -64,9 +62,10 @@ def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_loc
         self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
         self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)

-    @explicit_params_validation(validation_type='None')
-    def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str,
-                                     ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False) -> str:
+    @explicit_params_validation(validation_type="None")
+    def load_remote_checkpoints_file(
+        self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
+    ) -> str:
         """
         load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
         :param ckpt_source_remote_dir: The source folder to download from
@@ -76,27 +75,26 @@ def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destina
                                                  is to overwrite a previous version of the same files
         :return: Model Checkpoint File Path -> Depends on model architecture
         """
-        ckpt_file_local_full_path = ckpt_destination_local_dir + '/' + ckpt_file_name
+        ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name

-        if self.data_connection_source == 's3':
+        if self.data_connection_source == "s3":
             if overwrite_local_checkpoints_file:
                 # DELETE THE LOCAL VERSION ON THE MACHINE
                 if os.path.exists(ckpt_file_local_full_path):
                     os.remove(ckpt_file_local_full_path)

-            key_to_download = ckpt_source_remote_dir + '/' + ckpt_file_name
-            download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path,
-                                                              key_to_download=key_to_download)
+            key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
+            download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)

             if not download_success:
-                failed_download_path = 's3://' + self.model_repo_bucket_name + '/' + key_to_download
-                error_msg = 'Failed to Download Model Checkpoint from ' + failed_download_path
+                failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
+                error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
                 self._logger.error(error_msg)
                 raise ModelCheckpointNotFoundException(error_msg)

         return ckpt_file_local_full_path

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
         """
         load_remote_tensorboard_event_files - Downloads all of the tb_events Files from remote repository
@@ -106,24 +104,23 @@ def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name:
         :return:
         """
         if not os.path.isdir(model_checkpoint_dir_name):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

         # LOADS THE DATA FROM THE REMOTE REPOSITORY
         s3_bucket_path_prefix = model_name
-        if logging_type == 'tensorboard':
-            if self.data_connection_source == 's3':
-                self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
-                                                          local_download_dir=model_checkpoint_dir_name,
-                                                          s3_file_path_prefix=self.tb_events_file_prefix)
-        elif logging_type == 'text':
-            if self.data_connection_source == 's3':
-                self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
-                                                          local_download_dir=model_checkpoint_dir_name,
-                                                          s3_file_path_prefix=self.log_file_prefix)
-
-    @explicit_params_validation(validation_type='NoneOrEmpty')
-    def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str,
-                                     checkpoints_file_name: str) -> bool:
+        if logging_type == "tensorboard":
+            if self.data_connection_source == "s3":
+                self.s3_connector.download_keys_by_prefix(
+                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
+                )
+        elif logging_type == "text":
+            if self.data_connection_source == "s3":
+                self.s3_connector.download_keys_by_prefix(
+                    s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
+                )
+
+    @explicit_params_validation(validation_type="NoneOrEmpty")
+    def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
         """
         save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
         :param model_name: The Model Name for S3 Prefix
@@ -132,14 +129,14 @@ def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_d
         :return: True/False for Operation Success/Failure
         """
         # LOAD THE LOCAL VERSION
-        model_checkpoint_file_full_path = model_checkpoint_local_dir + '/' + checkpoints_file_name
+        model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name

         # SAVE ON THE REMOTE S3 REPOSITORY
-        if self.data_connection_source == 's3':
-            model_checkpoint_s3_in_bucket_path = model_name + '/' + checkpoints_file_name
+        if self.data_connection_source == "s3":
+            model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
             return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
         """
         save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely
@@ -147,18 +144,18 @@ def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_
         :param model_checkpoint_dir_name: The directory where the files are stored in
         """
         if not os.path.isdir(model_checkpoint_dir_name):
-            raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
+            raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

         for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
             if tb_events_file_name.startswith(self.tb_events_file_prefix):
-                upload_success = self.save_remote_checkpoints_file(model_name=model_name,
-                                                                   model_checkpoint_local_dir=model_checkpoint_dir_name,
-                                                                   checkpoints_file_name=tb_events_file_name)
+                upload_success = self.save_remote_checkpoints_file(
+                    model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
+                )
                 if not upload_success:
-                    self._logger.error('Failed to upload tb_events_file: ' + tb_events_file_name)
+                    self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)

-    @explicit_params_validation(validation_type='NoneOrEmpty')
+    @explicit_params_validation(validation_type="NoneOrEmpty")
     def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
         """
         __update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path
@@ -169,10 +166,10 @@ def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
         # DELETE KEY TO UPDATE THE FILE IN S3
         delete_response = self.s3_connector.delete_key(s3_key_path)
         if delete_response:
-            self._logger.info('Removed previous checkpoint from S3')
+            self._logger.info("Removed previous checkpoint from S3")

         upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
         if not upload_success:
-            self._logger.error('Failed to upload model checkpoint')
+            self._logger.error("Failed to upload model checkpoint")

         return upload_success
diff --git a/src/super_gradients/common/factories/callbacks_factory.py b/src/super_gradients/common/factories/callbacks_factory.py
index 14f079653a..f0aa09719d 100644
--- a/src/super_gradients/common/factories/callbacks_factory.py
+++ b/src/super_gradients/common/factories/callbacks_factory.py
@@ -3,6 +3,5 @@


 class CallbacksFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(CALLBACKS)
diff --git a/src/super_gradients/common/factories/list_factory.py b/src/super_gradients/common/factories/list_factory.py
index 14a20fb5a7..c5cf967983 100644
--- a/src/super_gradients/common/factories/list_factory.py
+++ b/src/super_gradients/common/factories/list_factory.py
@@ -4,7 +4,6 @@


 class ListFactory(AbstractFactory):
-
     def __init__(self, factry: AbstractFactory):
         self.factry = factry

diff --git a/src/super_gradients/common/factories/losses_factory.py b/src/super_gradients/common/factories/losses_factory.py
index d986817c20..923cc0592b 100644
--- a/src/super_gradients/common/factories/losses_factory.py
+++ b/src/super_gradients/common/factories/losses_factory.py
@@ -3,6 +3,5 @@


 class LossesFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(LOSSES)
diff --git a/src/super_gradients/common/factories/metrics_factory.py b/src/super_gradients/common/factories/metrics_factory.py
index fee767a762..e94b54f20c 100644
--- a/src/super_gradients/common/factories/metrics_factory.py
+++ b/src/super_gradients/common/factories/metrics_factory.py
@@ -3,6 +3,5 @@


 class MetricsFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(METRICS)
diff --git a/src/super_gradients/common/factories/samplers_factory.py b/src/super_gradients/common/factories/samplers_factory.py
index 8bb9e6803f..72439bfe2b 100644
--- a/src/super_gradients/common/factories/samplers_factory.py
+++ b/src/super_gradients/common/factories/samplers_factory.py
@@ -3,6 +3,5 @@


 class SamplersFactory(BaseFactory):
-
     def __init__(self):
         super().__init__(SAMPLERS)
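Reviewer note: the changes above are formatting only (double quotes, black-style line wrapping, removal of blank lines after class headers); the behavior of ADNNModelRepositoryDataInterfaces is unchanged. For context, here is a minimal usage sketch of that class as it appears in this diff. The bucket name, model name and local directory are placeholder values, the import path is inferred from the file location under src/, and a configured AWS profile (AWS_PROFILE) is assumed.

import os

from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces

# Placeholder values - replace with a real bucket, model name and local directory.
model_name = "my_model"
local_ckpt_dir = "/tmp/checkpoints/" + model_name
os.makedirs(local_ckpt_dir, exist_ok=True)

# 'local' keeps everything on the local filesystem; an "s3://..." location switches to the S3 connector.
data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location="s3://my-bucket-name")

# Download the latest checkpoint, overwriting any stale local copy of the same file.
ckpt_path = data_interface.load_remote_checkpoints_file(
    ckpt_source_remote_dir=model_name,
    ckpt_destination_local_dir=local_ckpt_dir,
    ckpt_file_name="ckpt_latest.pth",
    overwrite_local_checkpoints_file=True,
)

# Pull the TensorBoard event files and text logs stored under the same model prefix.
data_interface.load_all_remote_log_files(model_name=model_name, model_checkpoint_local_dir=local_ckpt_dir)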