Skip to content

Commit

Permalink
black on factories and data_interface (#620)
Browse files Browse the repository at this point in the history
Co-authored-by: Ofri Masad <ofrimasad@users.noreply.github.com>
  • Loading branch information
Louis-Dupont and ofrimasad authored Jan 12, 2023
1 parent d6f84a7 commit b31f3a8
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
ResearchModelRepositoryDataInterface
"""

def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
    """
    Initialize the model-checkpoints data interface.

    :param data_connection_location: "local" or an S3 bucket formatted "s3://my-bucket-name"
    :param data_connection_credentials: AWS profile name used to open the S3 connection;
                                        falls back to the AWS_PROFILE environment variable when None.
    """
    super().__init__()
    # Well-known file-name prefixes/names used when locating artifacts locally and remotely.
    self.tb_events_file_prefix = "events.out.tfevents"
    self.log_file_prefix = "log_"
    self.latest_checkpoint_filename = "ckpt_latest.pth"
    self.best_checkpoint_filename = "ckpt_best.pth"

    if data_connection_location.startswith("s3"):
        assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
        # Bucket name is everything after the "s3://" scheme.
        self.model_repo_bucket_name = data_connection_location.split("://")[1]
        self.data_connection_source = "s3"

        if data_connection_credentials is None:
            # Default to the ambient AWS profile when no explicit credentials were given.
            data_connection_credentials = os.getenv("AWS_PROFILE")

        self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)

@explicit_params_validation(validation_type="None")
def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
    """
    Download all remote logging artifacts for a model: tensorboard event files and text logs.

    :param model_name: model name used as the remote (S3) prefix
    :param model_checkpoint_local_dir: existing local directory to download the files into
    """
    self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
    self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")

@explicit_params_validation(validation_type="None")
def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
    """
    Save all of the local checkpoint data into the remote repo: the checkpoints file
    and every tensorboard event file found in the local directory.

    :param model_name: model name used as the remote (S3) prefix
    :param model_checkpoint_local_dir: local directory holding the checkpoint and tensorboard files
    :param log_file_name: name of the checkpoints file to upload
    """
    self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
    self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)

@explicit_params_validation(validation_type="None")
def load_remote_checkpoints_file(
    self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
) -> str:
    """
    Load a model's checkpoint from the local/cloud repository into a local directory.

    :param ckpt_source_remote_dir: the source folder to download from
    :param ckpt_destination_local_dir: the destination folder to download to
    :param ckpt_file_name: the checkpoint file name
    :param overwrite_local_checkpoints_file: whether to delete a previous local version of the same file first
    :return: local full path of the checkpoint file
    :raises ModelCheckpointNotFoundException: if the S3 download fails
    """
    ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name

    if self.data_connection_source == "s3":
        if overwrite_local_checkpoints_file:
            # DELETE THE LOCAL VERSION ON THE MACHINE
            if os.path.exists(ckpt_file_local_full_path):
                os.remove(ckpt_file_local_full_path)

        key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
        download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)

        if not download_success:
            failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
            error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
            self._logger.error(error_msg)
            raise ModelCheckpointNotFoundException(error_msg)

    return ckpt_file_local_full_path

@explicit_params_validation(validation_type="NoneOrEmpty")
def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
    """
    Download the logging files (tensorboard event files or text logs) from the remote repository.

    :param model_name: remote (S3) prefix the files are stored under
    :param model_checkpoint_dir_name: existing local directory to download into
    :param logging_type: "tensorboard" or "text" — any other value is silently a no-op
    :raises ValueError: if model_checkpoint_dir_name does not exist
    """
    if not os.path.isdir(model_checkpoint_dir_name):
        raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

    # LOADS THE DATA FROM THE REMOTE REPOSITORY
    s3_bucket_path_prefix = model_name
    if logging_type == "tensorboard":
        if self.data_connection_source == "s3":
            self.s3_connector.download_keys_by_prefix(
                s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
            )
    elif logging_type == "text":
        if self.data_connection_source == "s3":
            self.s3_connector.download_keys_by_prefix(
                s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
            )

@explicit_params_validation(validation_type="NoneOrEmpty")
def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
"""
save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
:param model_name: The Model Name for S3 Prefix
Expand All @@ -132,33 +129,33 @@ def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_d
:return: True/False for Operation Success/Failure
"""
# LOAD THE LOCAL VERSION
model_checkpoint_file_full_path = model_checkpoint_local_dir + '/' + checkpoints_file_name
model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name

# SAVE ON THE REMOTE S3 REPOSITORY
if self.data_connection_source == 's3':
model_checkpoint_s3_in_bucket_path = model_name + '/' + checkpoints_file_name
if self.data_connection_source == "s3":
model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)

@explicit_params_validation(validation_type="NoneOrEmpty")
def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
    """
    Save all of the local tensorboard event files remotely.

    :param model_name: prefix for cloud storage
    :param model_checkpoint_dir_name: the directory where the files are stored in
    :raises ValueError: if model_checkpoint_dir_name does not exist
    """
    if not os.path.isdir(model_checkpoint_dir_name):
        raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

    # Upload every file whose name matches the tensorboard events prefix.
    for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
        if tb_events_file_name.startswith(self.tb_events_file_prefix):
            upload_success = self.save_remote_checkpoints_file(
                model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
            )

            if not upload_success:
                self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)

@explicit_params_validation(validation_type="NoneOrEmpty")
def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
    """
    Upload/update an S3 key based on a local file path, deleting any previous
    version of the key first so the upload acts as an overwrite.

    :param local_file_path: local file to upload
    :param s3_key_path: destination key inside the bucket
    :return: True/False for upload success/failure
    """
    # DELETE KEY TO UPDATE THE FILE IN S3
    delete_response = self.s3_connector.delete_key(s3_key_path)
    if delete_response:
        self._logger.info("Removed previous checkpoint from S3")

    upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
    if not upload_success:
        self._logger.error("Failed to upload model checkpoint")

    return upload_success
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/callbacks_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class CallbacksFactory(BaseFactory):
    """Factory that instantiates training callbacks from the CALLBACKS type registry."""

    def __init__(self):
        super().__init__(CALLBACKS)
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/list_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class ListFactory(AbstractFactory):

def __init__(self, factry: AbstractFactory):
    """
    :param factry: the element factory applied to each item of a list.
                   NOTE(review): "factry" looks like a typo for "factory", but both the
                   parameter and the attribute are part of the public interface — kept as-is.
    """
    self.factry = factry

Expand Down
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/losses_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class LossesFactory(BaseFactory):
    """Factory that instantiates loss functions from the LOSSES type registry."""

    def __init__(self):
        super().__init__(LOSSES)
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/metrics_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class MetricsFactory(BaseFactory):
    """Factory that instantiates metrics from the METRICS type registry."""

    def __init__(self):
        super().__init__(METRICS)
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/samplers_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class SamplersFactory(BaseFactory):
    """Factory that instantiates data samplers from the SAMPLERS type registry."""

    def __init__(self):
        super().__init__(SAMPLERS)

0 comments on commit b31f3a8

Please sign in to comment.