Skip to content

Commit

Permalink
allow model to be loaded locally by the new settings use_local_model …
Browse files Browse the repository at this point in the history
…and local_model_path
  • Loading branch information
2320sharon committed Nov 22, 2024
1 parent 5ee2749 commit f414db4
Showing 1 changed file with 38 additions and 4 deletions.
42 changes: 38 additions & 4 deletions src/coastseg/zoo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,8 @@ def set_settings(self, **kwargs):
"use_GPU": "0",
"implementation": "BEST",
"model_type": "global_segformer_RGB_4class_14036903",
"local_model_path": "", # local path to the directory containing the model
"use_local_model": True, # Use local model (not one from zeneodo)
"otsu": False,
"tta": False,
"cloud_thresh": 0.5, # threshold on maximum cloud cover
Expand Down Expand Up @@ -995,6 +997,39 @@ def postprocess_data(
file_utilities.move_files(outputs_path, session_path, delete_src=True)
session.save(session.path)

def get_weights_directory(self,model_implementation:str, model_id: str) -> str:
"""
Retrieves the directory path where the model weights are stored.
This method determines whether to use a local model path or to download the model
from a remote source based on the settings provided. If the local model path is
specified and exists, it will use that path. Otherwise, it will create a directory
for the model and download the weights.
Args:
model_implementation (str): The implementation type of the model either 'BEST' or 'ENSEMBLE'
model_id (str): The identifier for the model. This is the zenodo ID located at the end of the URL
Returns:
str: The directory path where the model weights are stored.
Raises:
FileNotFoundError: If the local model path is specified but does not exist.
"""

USE_LOCAL_MODEL = self.settings.get("use_local_model", False)
LOCAL_MODEL_PATH = self.settings.get("local_model_path", "")

if USE_LOCAL_MODEL and not os.path.exists(LOCAL_MODEL_PATH):
raise FileNotFoundError(f"The local model path does not exist at {LOCAL_MODEL_PATH}")

# check if a local model should be loaded or not
if USE_LOCAL_MODEL == False or LOCAL_MODEL_PATH == "":
# create the model directory & download the model
weights_directory = self.get_model_directory(model_id)
self.download_model(model_implementation, model_id, weights_directory)
else:
# load the model from the local model path
weights_directory = LOCAL_MODEL_PATH

return weights_directory

def prepare_model(self, model_implementation: str, model_id: str):
"""
Prepares the model for use by downloading the required files and loading the model.
Expand All @@ -1003,12 +1038,10 @@ def prepare_model(self, model_implementation: str, model_id: str):
model_implementation (str): The model implementation either 'BEST' or 'ENSEMBLE'
model_id (str): The ID of the model.
"""
self.clear_zoo_model()
# create the model directory
self.weights_directory = self.get_model_directory(model_id)
# weights_directory is the directory that contains the model weights, the model card json files and the BEST_MODEL.txt file
self.weights_directory = self.get_weights_directory(model_implementation, model_id)
logger.info(f"self.weights_directory:{self.weights_directory}")

self.download_model(model_implementation, model_id, self.weights_directory)
weights_list = self.get_weights_list(model_implementation)

# Load the model from the config files
Expand Down Expand Up @@ -1096,6 +1129,7 @@ def run_model(
logger.info(f"use_tta: {use_tta}")

print(f"Running model {model_name}")
# print(f"self.settings: {self.settings}")
self.prepare_model(model_implementation, model_name)

# create a session
Expand Down

0 comments on commit f414db4

Please sign in to comment.