Skip to content

Commit

Permalink
Merge pull request #209 from ctlearn-project/issue208
Browse files Browse the repository at this point in the history
Update run_model.py
  • Loading branch information
TjarkMiener authored Oct 1, 2024
2 parents 88a1ba8 + d68832d commit 06ccaed
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions ctlearn/run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ctlearn.utils import *


def run_model(config, mode="train", debug=False, log_to_file=False):
def run_model(config, mode="train", debug=False, log_to_file=False, save_best_only=True):
# Load options relating to logging and checkpointing
root_model_dir = model_dir = config["Logging"]["model_directory"]

Expand Down Expand Up @@ -299,7 +299,7 @@ def run_model(config, mode="train", debug=False, log_to_file=False):
monitor=monitor,
verbose=1,
mode=monitor_mode,
save_best_only=True,
save_best_only=save_best_only,
initial_value_threshold=initial_value_threshold,
)
# Tensorboard callback
Expand Down Expand Up @@ -541,7 +541,14 @@ def main():
parser.add_argument(
"--debug", action="store_true", help="Print debug/logger messages"
)


parser.add_argument(
"--save_best_only",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag, it only saves when the model is considered the 'best' and the latest best model according to the quantity monitored will not be overwritten",
)

args = parser.parse_args()

# Use the default CTLearn config file if no config file is provided
Expand Down Expand Up @@ -655,6 +662,7 @@ def main():
# Exception handling for the input directory
if not os.path.isdir(abs_file_dir):
raise NotADirectoryError(f"'{abs_file_dir}' is not a directory.")

with open(training_file_list, "a") as file_list:
for pattern in args.pattern:
files = glob.glob(os.path.join(abs_file_dir, pattern))
Expand All @@ -680,7 +688,7 @@ def main():
if args.trigger_patches_from_file:
config["Data"]["trigger_settings"]["get_trigger_patch"] = "file"

run_model(config, mode="train", debug=args.debug, log_to_file=args.log_to_file)
run_model(config, mode="train", debug=args.debug, log_to_file=args.log_to_file, save_best_only= args.save_best_only)

if "predict" in args.mode:
if args.input:
Expand All @@ -689,6 +697,7 @@ def main():
# Exception handling for the input directory
if not os.path.isdir(abs_file_dir):
raise NotADirectoryError(f"'{abs_file_dir}' is not a directory.")

for pattern in args.pattern:
files = glob.glob(os.path.join(abs_file_dir, pattern))
if not files:
Expand Down

0 comments on commit 06ccaed

Please sign in to comment.