Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Black on factories and data_interface #620

Merged
merged 2 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ADNNModelRepositoryDataInterfaces(ILogger):
ResearchModelRepositoryDataInterface
"""

def __init__(self, data_connection_location: str = 'local', data_connection_credentials: str = None):
def __init__(self, data_connection_location: str = "local", data_connection_credentials: str = None):
"""
ModelCheckpointsDataInterface
:param data_connection_location: 'local' or s3 bucket 's3://my-bucket-name'
Expand All @@ -22,35 +22,33 @@ def __init__(self, data_connection_location: str = 'local', data_connection_cred
AWS_PROFILE if left empty
"""
super().__init__()
self.tb_events_file_prefix = 'events.out.tfevents'
self.log_file_prefix = 'log_'
self.latest_checkpoint_filename = 'ckpt_latest.pth'
self.best_checkpoint_filename = 'ckpt_best.pth'
self.tb_events_file_prefix = "events.out.tfevents"
self.log_file_prefix = "log_"
self.latest_checkpoint_filename = "ckpt_latest.pth"
self.best_checkpoint_filename = "ckpt_best.pth"

if data_connection_location.startswith('s3'):
assert data_connection_location.index('s3://') >= 0, 'S3 path must be formatted s3://bucket-name'
self.model_repo_bucket_name = data_connection_location.split('://')[1]
self.data_connection_source = 's3'
if data_connection_location.startswith("s3"):
assert data_connection_location.index("s3://") >= 0, "S3 path must be formatted s3://bucket-name"
self.model_repo_bucket_name = data_connection_location.split("://")[1]
self.data_connection_source = "s3"

if data_connection_credentials is None:
data_connection_credentials = os.getenv('AWS_PROFILE')
data_connection_credentials = os.getenv("AWS_PROFILE")

self.s3_connector = S3Connector(data_connection_credentials, self.model_repo_bucket_name)

@explicit_params_validation(validation_type='None')
@explicit_params_validation(validation_type="None")
def load_all_remote_log_files(self, model_name: str, model_checkpoint_local_dir: str):
"""
load_all_remote_checkpoint_files
:param model_name:
:param model_checkpoint_local_dir:
:return:
"""
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
logging_type='tensorboard')
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir,
logging_type='text')
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="tensorboard")
self.load_remote_logging_files(model_name=model_name, model_checkpoint_dir_name=model_checkpoint_local_dir, logging_type="text")

@explicit_params_validation(validation_type='None')
@explicit_params_validation(validation_type="None")
def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_local_dir: str, log_file_name: str):
"""
save_all_remote_checkpoint_files - Saves all of the local Checkpoint data into Remote Repo
Expand All @@ -64,9 +62,10 @@ def save_all_remote_checkpoint_files(self, model_name: str, model_checkpoint_loc
self.save_remote_checkpoints_file(model_name, model_checkpoint_local_dir, log_file_name)
self.save_remote_tensorboard_event_files(model_name, model_checkpoint_local_dir)

@explicit_params_validation(validation_type='None')
def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str,
ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False) -> str:
@explicit_params_validation(validation_type="None")
def load_remote_checkpoints_file(
self, ckpt_source_remote_dir: str, ckpt_destination_local_dir: str, ckpt_file_name: str, overwrite_local_checkpoints_file: bool = False
) -> str:
"""
load_remote_checkpoints_file - Loads a model's checkpoint from local/cloud file
:param ckpt_source_remote_dir: The source folder to download from
Expand All @@ -76,27 +75,26 @@ def load_remote_checkpoints_file(self, ckpt_source_remote_dir: str, ckpt_destina
is to overwrite a previous version of the same files
:return: Model Checkpoint File Path -> Depends on model architecture
"""
ckpt_file_local_full_path = ckpt_destination_local_dir + '/' + ckpt_file_name
ckpt_file_local_full_path = ckpt_destination_local_dir + "/" + ckpt_file_name

if self.data_connection_source == 's3':
if self.data_connection_source == "s3":
if overwrite_local_checkpoints_file:
# DELETE THE LOCAL VERSION ON THE MACHINE
if os.path.exists(ckpt_file_local_full_path):
os.remove(ckpt_file_local_full_path)

key_to_download = ckpt_source_remote_dir + '/' + ckpt_file_name
download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path,
key_to_download=key_to_download)
key_to_download = ckpt_source_remote_dir + "/" + ckpt_file_name
download_success = self.s3_connector.download_key(target_path=ckpt_file_local_full_path, key_to_download=key_to_download)

if not download_success:
failed_download_path = 's3://' + self.model_repo_bucket_name + '/' + key_to_download
error_msg = 'Failed to Download Model Checkpoint from ' + failed_download_path
failed_download_path = "s3://" + self.model_repo_bucket_name + "/" + key_to_download
error_msg = "Failed to Download Model Checkpoint from " + failed_download_path
self._logger.error(error_msg)
raise ModelCheckpointNotFoundException(error_msg)

return ckpt_file_local_full_path

@explicit_params_validation(validation_type='NoneOrEmpty')
@explicit_params_validation(validation_type="NoneOrEmpty")
def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name: str, logging_type: str):
"""
load_remote_tensorboard_event_files - Downloads all of the tb_events Files from remote repository
Expand All @@ -106,24 +104,23 @@ def load_remote_logging_files(self, model_name: str, model_checkpoint_dir_name:
:return:
"""
if not os.path.isdir(model_checkpoint_dir_name):
raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

# LOADS THE DATA FROM THE REMOTE REPOSITORY
s3_bucket_path_prefix = model_name
if logging_type == 'tensorboard':
if self.data_connection_source == 's3':
self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
local_download_dir=model_checkpoint_dir_name,
s3_file_path_prefix=self.tb_events_file_prefix)
elif logging_type == 'text':
if self.data_connection_source == 's3':
self.s3_connector.download_keys_by_prefix(s3_bucket_path_prefix=s3_bucket_path_prefix,
local_download_dir=model_checkpoint_dir_name,
s3_file_path_prefix=self.log_file_prefix)

@explicit_params_validation(validation_type='NoneOrEmpty')
def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str,
checkpoints_file_name: str) -> bool:
if logging_type == "tensorboard":
if self.data_connection_source == "s3":
self.s3_connector.download_keys_by_prefix(
s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.tb_events_file_prefix
)
elif logging_type == "text":
if self.data_connection_source == "s3":
self.s3_connector.download_keys_by_prefix(
s3_bucket_path_prefix=s3_bucket_path_prefix, local_download_dir=model_checkpoint_dir_name, s3_file_path_prefix=self.log_file_prefix
)

@explicit_params_validation(validation_type="NoneOrEmpty")
def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_dir: str, checkpoints_file_name: str) -> bool:
"""
save_remote_checkpoints_file - Saves a Checkpoints file in the Remote Repo
:param model_name: The Model Name for S3 Prefix
Expand All @@ -132,33 +129,33 @@ def save_remote_checkpoints_file(self, model_name: str, model_checkpoint_local_d
:return: True/False for Operation Success/Failure
"""
# LOAD THE LOCAL VERSION
model_checkpoint_file_full_path = model_checkpoint_local_dir + '/' + checkpoints_file_name
model_checkpoint_file_full_path = model_checkpoint_local_dir + "/" + checkpoints_file_name

# SAVE ON THE REMOTE S3 REPOSITORY
if self.data_connection_source == 's3':
model_checkpoint_s3_in_bucket_path = model_name + '/' + checkpoints_file_name
if self.data_connection_source == "s3":
model_checkpoint_s3_in_bucket_path = model_name + "/" + checkpoints_file_name
return self.__update_or_upload_s3_key(model_checkpoint_file_full_path, model_checkpoint_s3_in_bucket_path)

@explicit_params_validation(validation_type='NoneOrEmpty')
@explicit_params_validation(validation_type="NoneOrEmpty")
def save_remote_tensorboard_event_files(self, model_name: str, model_checkpoint_dir_name: str):
"""
save_remote_tensorboard_event_files - Saves all of the tensorboard files remotely
:param model_name: Prefix for Cloud Storage
:param model_checkpoint_dir_name: The directory where the files are stored in
"""
if not os.path.isdir(model_checkpoint_dir_name):
raise ValueError('[' + sys._getframe().f_code.co_name + '] - Provided directory does not exist')
raise ValueError("[" + sys._getframe().f_code.co_name + "] - Provided directory does not exist")

for tb_events_file_name in os.listdir(model_checkpoint_dir_name):
if tb_events_file_name.startswith(self.tb_events_file_prefix):
upload_success = self.save_remote_checkpoints_file(model_name=model_name,
model_checkpoint_local_dir=model_checkpoint_dir_name,
checkpoints_file_name=tb_events_file_name)
upload_success = self.save_remote_checkpoints_file(
model_name=model_name, model_checkpoint_local_dir=model_checkpoint_dir_name, checkpoints_file_name=tb_events_file_name
)

if not upload_success:
self._logger.error('Failed to upload tb_events_file: ' + tb_events_file_name)
self._logger.error("Failed to upload tb_events_file: " + tb_events_file_name)

@explicit_params_validation(validation_type='NoneOrEmpty')
@explicit_params_validation(validation_type="NoneOrEmpty")
def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
"""
__update_or_upload_s3_key - Uploads/Updates an S3 Key based on a local file path
Expand All @@ -169,10 +166,10 @@ def __update_or_upload_s3_key(self, local_file_path: str, s3_key_path: str):
# DELETE KEY TO UPDATE THE FILE IN S3
delete_response = self.s3_connector.delete_key(s3_key_path)
if delete_response:
self._logger.info('Removed previous checkpoint from S3')
self._logger.info("Removed previous checkpoint from S3")

upload_success = self.s3_connector.upload_file(local_file_path, s3_key_path)
if not upload_success:
self._logger.error('Failed to upload model checkpoint')
self._logger.error("Failed to upload model checkpoint")

return upload_success
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class CallbacksFactory(BaseFactory):

def __init__(self):
super().__init__(CALLBACKS)
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/list_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class ListFactory(AbstractFactory):

def __init__(self, factry: AbstractFactory):
self.factry = factry

Expand Down
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/losses_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class LossesFactory(BaseFactory):

def __init__(self):
super().__init__(LOSSES)
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/metrics_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class MetricsFactory(BaseFactory):

def __init__(self):
super().__init__(METRICS)
1 change: 0 additions & 1 deletion src/super_gradients/common/factories/samplers_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@


class SamplersFactory(BaseFactory):

def __init__(self):
super().__init__(SAMPLERS)