DEFAULT_CONFIG_FILEPATH = "fibad_default_config.toml"


def get_runtime_config(
    runtime_config_filepath: str = None, default_config_filepath: str = DEFAULT_CONFIG_FILEPATH
) -> dict:
    """Load the default runtime configuration and overlay the user's configuration.

    The default configuration file is always read. When a user configuration
    file is supplied it is read as well, and its values override the defaults
    (nested tables are merged recursively, see ``merge_configs``). When no user
    file is supplied the defaults are returned as-is.

    The merged configuration is returned as a dictionary. It is not currently
    written to disk; the call to ``log_runtime_config`` is intentionally left
    commented out below.

    Parameters
    ----------
    runtime_config_filepath : str, optional
        The path to the user's runtime configuration file. ``None`` or an empty
        string means "use the defaults only".
    default_config_filepath : str
        The path to the default runtime configuration file.

    Returns
    -------
    dict
        The parsed and merged runtime configuration.

    Raises
    ------
    FileNotFoundError
        If ``runtime_config_filepath`` is given but does not exist, or if the
        default configuration file cannot be opened.
    """

    # Start from an empty overlay so the merge below is well defined even when
    # the caller did not supply a user configuration file.
    users_runtime_config = {}

    if runtime_config_filepath:
        if not os.path.exists(runtime_config_filepath):
            raise FileNotFoundError(f"Runtime configuration file not found: {runtime_config_filepath}")

        # TOML documents are UTF-8 by specification; do not rely on the locale.
        with open(runtime_config_filepath, "r", encoding="utf-8") as f:
            users_runtime_config = toml.load(f)

    with open(default_config_filepath, "r", encoding="utf-8") as f:
        default_runtime_config = toml.load(f)

    final_runtime_config = merge_configs(default_runtime_config, users_runtime_config)

    # ~ Uncomment when we have a better place to stash results.
    # log_runtime_config(final_runtime_config)

    return final_runtime_config


def merge_configs(default_config: dict, user_config: dict) -> dict:
    """Merge two configuration dictionaries, with ``user_config`` taking precedence.

    Keys present only in ``default_config`` are kept, keys present only in
    ``user_config`` are added, and keys present in both take the user's value.
    When both values are dictionaries they are merged recursively instead of
    being replaced wholesale.

    Neither input is modified, and the result shares no mutable state with
    them, so callers may freely mutate the returned dictionary.

    Parameters
    ----------
    default_config : dict
        The default configuration.
    user_config : dict
        The user defined configuration.

    Returns
    -------
    dict
        The merged configuration.
    """
    import copy

    # Deep copy so nested tables in the result never alias the defaults.
    final_config = copy.deepcopy(default_config)
    for k, v in user_config.items():
        if isinstance(final_config.get(k), dict) and isinstance(v, dict):
            final_config[k] = merge_configs(final_config[k], v)
        else:
            final_config[k] = copy.deepcopy(v)

    return final_config


def log_runtime_config(runtime_config: dict, output_filepath: str = "runtime_config.toml"):
    """Write a runtime configuration to disk as a TOML document.

    Any existing file at ``output_filepath`` is overwritten.

    Parameters
    ----------
    runtime_config : dict
        A dictionary containing runtime configuration values.
    output_filepath : str
        The path to the output configuration file.
    """

    with open(output_filepath, "w", encoding="utf-8") as f:
        f.write(toml.dumps(runtime_config))
@fibad_model
class ExampleCNN(nn.Module):
    """Small convolutional classifier taken from the PyTorch CIFAR10 tutorial.

    The network is two ``Conv2d`` + ReLU + ``MaxPool2d`` stages followed by
    three fully connected layers producing 10 output scores. ``fc1`` expects
    ``16 * 5 * 5`` input features, which is what the two conv/pool stages
    produce for a 3-channel 32x32 input.

    The class is registered in the model registry by the ``fibad_model``
    decorator.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run a forward pass and return the raw (unnormalised) class scores.

        Parameters
        ----------
        x : torch.Tensor
            A batch of images; the first dimension is preserved by the
            flatten step, so it is treated as the batch dimension.

        Returns
        -------
        torch.Tensor
            The output of the final linear layer (10 values per batch item).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension before the dense layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # ~ The following methods are placeholders for future work
    # ~ I don't think this will be the final API!!!
    def criterion(self):
        """Return the loss function used to train this model (cross entropy)."""
        return nn.CrossEntropyLoss()

    def optimizer(self):
        """Return an SGD optimizer over this model's parameters."""
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def save(self, path):
        """Save the model's ``state_dict`` to ``path`` using ``torch.save``."""
        torch.save(self.state_dict(), path)
+ """ + update_model_registry(cls.__name__, cls) + return cls + + +def update_model_registry(name: str, model_class: type): + """Add a model to the model registry. + + Parameters + ---------- + name : str + The name of the model. + model_class : type + The model class. + """ + + MODEL_REGISTRY.update({name: model_class}) diff --git a/src/fibad/plugin_utils.py b/src/fibad/plugin_utils.py new file mode 100644 index 0000000..9405c81 --- /dev/null +++ b/src/fibad/plugin_utils.py @@ -0,0 +1,100 @@ +import importlib + +from fibad.models import * # noqa: F403 +from fibad.models import MODEL_REGISTRY + + +def fetch_model_class(runtime_config: dict) -> type: + """Fetch the model class from the model registry. + + Parameters + ---------- + runtime_config : dict + The runtime configuration dictionary. + + Returns + ------- + type + The model class. + + Raises + ------ + ValueError + If a built in model was requested, but not found in the model registry. + ValueError + If no model was specified in the runtime configuration. + """ + + training_config = runtime_config.get("train", {}) + model_cls = None + + # User specifies one of the built in models by name + if "model_name" in training_config: + model_name = training_config.get("model_name", None) + + if model_name not in MODEL_REGISTRY: # noqa: F405 + raise ValueError(f"Model not found in model registry: {model_name}") + + model_cls = MODEL_REGISTRY[model_name] # noqa: F405 + + # User provides a custom model, attempt to import it with the module spec + elif "model_cls" in training_config: + model_cls = _import_module_from_string(training_config["model_cls"]) + + # User failed to define a model to load + else: + raise ValueError("No model specified in the runtime configuration") + + return model_cls + + +def _import_module_from_string(module_path: str) -> type: + """Dynamically import a module from a string. + + Parameters + ---------- + module_path : str + The import spec for the model class. 
Should be of the form: + "module.submodule.class_name" + + Returns + ------- + model_cls : type + The model class. + + Raises + ------ + AttributeError + If the model class is not found in the module that is loaded. + ModuleNotFoundError + If the module is not found using the provided import spec. + """ + + module_name, class_name = module_path.rsplit(".", 1) + model_cls = None + + try: + # Attempt to find the module spec, i.e. `module.submodule.`. + # Will raise exception if `submodule`, 'subsubmodule', etc. is not found. + importlib.util.find_spec(module_name) + + # `importlib.util.find_spec()` will return None if `module` is not found. + if (importlib.util.find_spec(module_name)) is not None: + # Load the requested module + module = importlib.import_module(module_name) + + # Check if the requested class is in the module + if hasattr(module, class_name): + model_cls = getattr(module, class_name) + else: + raise AttributeError(f"Model class {class_name} not found in module {module_name}") + + # Raise an exception if the base module of the spec is not found + else: + raise ModuleNotFoundError(f"Module {module_name} not found") + + # Exception raised when a submodule of the spec is not found + except ModuleNotFoundError as exc: + raise ModuleNotFoundError(f"Module {module_name} not found") from exc + + return model_cls diff --git a/src/fibad/train.py b/src/fibad/train.py index 70ba414..01be346 100644 --- a/src/fibad/train.py +++ b/src/fibad/train.py @@ -1,4 +1,7 @@ -"""Scaffolding placeholder for training code.""" +import torch + +from fibad.config_utils import get_runtime_config +from fibad.plugin_utils import fetch_model_class def run(args, config): @@ -14,5 +17,23 @@ def run(args, config): dict """ - print("Prending to run training...") - print(f"Runtime config: {args.runtime_config}") + runtime_config = get_runtime_config(args.runtime_config) + + model_cls = fetch_model_class(runtime_config) + model = model_cls() + + device = torch.device("cuda:0" if 
def test_merge_configs():
    """Check that merge_configs combines a default and a user dictionary so that:
    1) user values replace default values for shared keys,
    2) default-only keys are kept untouched,
    3) user-only keys are added,
    4) nested dictionaries are merged key by key rather than replaced.
    """
    defaults = {
        "a": 1,
        "b": 2,  # case 2: only in the defaults
        "c": {"d": 3, "e": 4},
    }
    overrides = {
        "a": 5,  # case 1: overrides the default
        "c": {"d": 6},  # case 4: nested override, "e" must survive
        "f": 7,  # case 3: only in the user config
    }

    merged = merge_configs(defaults, overrides)

    assert merged == {"a": 5, "b": 2, "c": {"d": 6, "e": 4}, "f": 7}


def test_get_runtime_config():
    """Check that get_runtime_config reads the default and user TOML files from
    the test data directory and returns their merged contents as a dictionary.
    """
    test_data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data")
    user_config_path = os.path.join(test_data_dir, "test_user_config.toml")
    default_config_path = os.path.join(test_data_dir, "test_default_config.toml")

    runtime_config = get_runtime_config(
        runtime_config_filepath=user_config_path,
        default_config_filepath=default_config_path,
    )

    expected_train_section = {
        "model_name": "example_model",
        "model_class": "new_thing.cool_model.CoolModel",
        "model": {"model_weights_filepath": "final_best.pth", "layers": 3},
    }
    expected = {
        "general": {"use_gpu": False},
        "train": expected_train_section,
        "predict": {"batch_size": 8},
    }

    assert runtime_config == expected
def test_import_module_from_string():
    """A valid "module.class" spec resolves to the named class."""
    resolved = plugin_utils._import_module_from_string("builtins.BaseException")

    assert resolved.__name__ == "BaseException"


def test_import_module_from_string_no_base_module():
    """A spec whose top-level module does not exist raises ModuleNotFoundError."""
    spec = "nonexistent.BaseException"

    with pytest.raises(ModuleNotFoundError) as excinfo:
        plugin_utils._import_module_from_string(spec)

    message = str(excinfo.value)
    assert "Module nonexistent not found" in message


def test_import_module_from_string_no_submodule():
    """A spec whose submodule does not exist raises ModuleNotFoundError."""
    spec = "builtins.nonexistent.BaseException"

    with pytest.raises(ModuleNotFoundError) as excinfo:
        plugin_utils._import_module_from_string(spec)

    message = str(excinfo.value)
    assert "Module builtins.nonexistent not found" in message


def test_import_module_from_string_no_class():
    """A spec naming a class that is missing from the module raises AttributeError."""
    spec = "builtins.Nonexistent"

    with pytest.raises(AttributeError) as excinfo:
        plugin_utils._import_module_from_string(spec)

    message = str(excinfo.value)
    assert "Model class Nonexistent not found" in message


def test_fetch_model_class():
    """A `model_cls` import spec in the train section is resolved to its class."""
    runtime_config = {"train": {"model_cls": "builtins.BaseException"}}

    resolved = plugin_utils.fetch_model_class(runtime_config)

    assert resolved.__name__ == "BaseException"


def test_fetch_model_class_no_model():
    """An empty train section raises ValueError because no model was specified."""
    runtime_config = {"train": {}}

    with pytest.raises(ValueError) as excinfo:
        plugin_utils.fetch_model_class(runtime_config)

    message = str(excinfo.value)
    assert "No model specified in the runtime configuration" in message


def test_fetch_model_class_no_model_cls():
    """A `model_cls` spec naming a non-existent class raises AttributeError."""
    runtime_config = {"train": {"model_cls": "builtins.Nonexistent"}}

    with pytest.raises(AttributeError) as excinfo:
        plugin_utils.fetch_model_class(runtime_config)

    message = str(excinfo.value)
    assert "Model class Nonexistent not found" in message


def test_fetch_model_class_not_in_registry():
    """A `model_name` that is absent from the registry raises ValueError."""
    runtime_config = {"train": {"model_name": "Nonexistent"}}

    with pytest.raises(ValueError) as excinfo:
        plugin_utils.fetch_model_class(runtime_config)

    message = str(excinfo.value)
    assert "Model not found in model registry: Nonexistent" in message


def test_fetch_model_class_in_registry():
    """A class registered with `fibad_model` can be fetched by its name."""

    # Register a do-nothing class so that its name is present in the registry.
    @fibad_model
    class NewClass:
        pass

    runtime_config = {"train": {"model_name": "NewClass"}}

    resolved = plugin_utils.fetch_model_class(runtime_config)

    assert resolved.__name__ == "NewClass"