Scaffolding for config file handling and model selection (#6)
* Initial commit for config file handling.
* Initial commit for pluggable model scaffolding.
* More support and tests for dynamic model loading and default/user config merging.
* Addressing PR feedback.
* Fixing pyproject.toml syntax error.
* Fixing circular import issue.
Showing 14 changed files with 485 additions and 42 deletions.
@@ -1,3 +1,4 @@
-from .example_module import greetings, meaning
+from .config_utils import get_runtime_config, log_runtime_config, merge_configs
+from .plugin_utils import fetch_model_class

-__all__ = ["greetings", "meaning"]
+__all__ = ["get_runtime_config", "merge_configs", "log_runtime_config", "fetch_model_class"]
@@ -0,0 +1,90 @@
import os

import toml

DEFAULT_CONFIG_FILEPATH = "fibad_default_config.toml"


def get_runtime_config(
    runtime_config_filepath: str = None, default_config_filepath: str = DEFAULT_CONFIG_FILEPATH
) -> dict:
    """This function will load the default runtime configuration file, as well
    as the user defined runtime configuration file.

    The two configurations will be merged with values in the user defined config
    overriding the values of the default configuration.

    The final merged config will be returned as a dictionary and saved as a file
    in the results directory.

    Parameters
    ----------
    runtime_config_filepath : str
        The path to the runtime configuration file.
    default_config_filepath : str
        The path to the default runtime configuration file.

    Returns
    -------
    dict
        The parsed runtime configuration.
    """

    # Start from an empty user config so the merge below also works when no user file is given.
    users_runtime_config = {}
    if runtime_config_filepath:
        if not os.path.exists(runtime_config_filepath):
            raise FileNotFoundError(f"Runtime configuration file not found: {runtime_config_filepath}")

        with open(runtime_config_filepath, "r") as f:
            users_runtime_config = toml.load(f)

    with open(default_config_filepath, "r") as f:
        default_runtime_config = toml.load(f)

    final_runtime_config = merge_configs(default_runtime_config, users_runtime_config)

    # ~ Uncomment when we have a better place to stash results.
    # log_runtime_config(final_runtime_config)

    return final_runtime_config


def merge_configs(default_config: dict, user_config: dict) -> dict:
    """Merge two configuration dictionaries with the user_config values overriding
    the default_config values.

    Parameters
    ----------
    default_config : dict
        The default configuration.
    user_config : dict
        The user defined configuration.

    Returns
    -------
    dict
        The merged configuration.
    """

    final_config = default_config.copy()
    for k, v in user_config.items():
        if k in final_config and isinstance(final_config[k], dict) and isinstance(v, dict):
            final_config[k] = merge_configs(default_config.get(k, {}), v)
        else:
            final_config[k] = v

    return final_config


def log_runtime_config(runtime_config: dict, output_filepath: str = "runtime_config.toml"):
    """Log a runtime configuration.

    Parameters
    ----------
    runtime_config : dict
        A dictionary containing runtime configuration values.
    output_filepath : str
        The path to the output configuration file.
    """

    with open(output_filepath, "w") as f:
        f.write(toml.dumps(runtime_config))
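For orientation, here is a minimal usage sketch of these helpers. It is not part of the commit; it assumes the package __init__ shown above is fibad/__init__.py, that fibad_default_config.toml is present in the working directory, and that a hypothetical user_config.toml exists with a [train] table.

# --- usage sketch (illustrative, not part of the commit) ---
# Assume a hypothetical user config at ./user_config.toml containing, e.g.:
#
#   [train]
#   model_name = "ExampleCNN"

from fibad import get_runtime_config, merge_configs

# Merge the packaged defaults with the user's overrides; user values win.
config = get_runtime_config(runtime_config_filepath="user_config.toml")

# merge_configs can also be used directly on plain dictionaries; nested tables merge recursively.
merged = merge_configs({"train": {"epochs": 10, "lr": 0.001}}, {"train": {"lr": 0.01}})
assert merged == {"train": {"epochs": 10, "lr": 0.01}}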
This file was deleted.
@@ -0,0 +1,7 @@
from .example_cnn_classifier import ExampleCNN

# rethink the location of this module. If we're not careful, we end up with circular imports
# when using the `fibad_model` decorator on models in this module.
from .model_registry import MODEL_REGISTRY, fibad_model

__all__ = ["fibad_model", "MODEL_REGISTRY", "ExampleCNN"]
@@ -0,0 +1,44 @@
# ruff: noqa: D101, D102

# This example model is taken from the PyTorch CIFAR10 tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-convolutional-neural-network

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa N812
import torch.optim as optim

# extra long import here to address a circular import issue
from fibad.models.model_registry import fibad_model


@fibad_model
class ExampleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    # ~ The following methods are placeholders for future work
    # ~ I don't think this will be the final API!!!
    def criterion(self):
        return nn.CrossEntropyLoss()

    def optimizer(self):
        return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    def save(self, path):
        torch.save(self.state_dict(), path)
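As a quick sanity check (not part of the commit), the model can be exercised on a CIFAR-10-shaped batch; the shapes below simply follow the tutorial the class is taken from.

# --- usage sketch (illustrative, not part of the commit) ---
import torch

from fibad.models import ExampleCNN

model = ExampleCNN()

# CIFAR-10 images are 3x32x32; a batch of 4 yields logits over 10 classes.
dummy_batch = torch.randn(4, 3, 32, 32)
logits = model(dummy_batch)
assert logits.shape == (4, 10)

# The placeholder helpers return a loss function and an optimizer bound to the model's parameters.
loss_fn = model.criterion()
optimizer = model.optimizer()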
@@ -0,0 +1,27 @@
MODEL_REGISTRY = {}


def fibad_model(cls):
    """Decorator to register a model with the model registry.

    Returns
    -------
    type
        The original, unmodified class.
    """
    update_model_registry(cls.__name__, cls)
    return cls


def update_model_registry(name: str, model_class: type):
    """Add a model to the model registry.

    Parameters
    ----------
    name : str
        The name of the model.
    model_class : type
        The model class.
    """

    MODEL_REGISTRY.update({name: model_class})
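The decorator registers a class under its own name at import time. A hedged sketch of registering a custom model follows (the MyAutoencoder class is hypothetical, not part of the commit):

# --- usage sketch (illustrative, not part of the commit) ---
import torch.nn as nn

from fibad.models import MODEL_REGISTRY, fibad_model


@fibad_model
class MyAutoencoder(nn.Module):  # hypothetical user-defined model
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)

    def forward(self, x):
        return self.layer(x)


# The class is now discoverable by name, which is what fetch_model_class relies on.
assert MODEL_REGISTRY["MyAutoencoder"] is MyAutoencoder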
@@ -0,0 +1,100 @@
import importlib
import importlib.util  # needed for importlib.util.find_spec below

from fibad.models import *  # noqa: F403
from fibad.models import MODEL_REGISTRY


def fetch_model_class(runtime_config: dict) -> type:
    """Fetch the model class from the model registry.

    Parameters
    ----------
    runtime_config : dict
        The runtime configuration dictionary.

    Returns
    -------
    type
        The model class.

    Raises
    ------
    ValueError
        If a built in model was requested, but not found in the model registry.
    ValueError
        If no model was specified in the runtime configuration.
    """

    training_config = runtime_config.get("train", {})
    model_cls = None

    # User specifies one of the built in models by name
    if "model_name" in training_config:
        model_name = training_config.get("model_name", None)

        if model_name not in MODEL_REGISTRY:  # noqa: F405
            raise ValueError(f"Model not found in model registry: {model_name}")

        model_cls = MODEL_REGISTRY[model_name]  # noqa: F405

    # User provides a custom model, attempt to import it with the module spec
    elif "model_cls" in training_config:
        model_cls = _import_module_from_string(training_config["model_cls"])

    # User failed to define a model to load
    else:
        raise ValueError("No model specified in the runtime configuration")

    return model_cls


def _import_module_from_string(module_path: str) -> type:
    """Dynamically import a module from a string.

    Parameters
    ----------
    module_path : str
        The import spec for the model class. Should be of the form:
        "module.submodule.class_name"

    Returns
    -------
    model_cls : type
        The model class.

    Raises
    ------
    AttributeError
        If the model class is not found in the module that is loaded.
    ModuleNotFoundError
        If the module is not found using the provided import spec.
    """

    module_name, class_name = module_path.rsplit(".", 1)
    model_cls = None

    try:
        # Attempt to find the module spec, i.e. `module.submodule`.
        # Will raise an exception if `submodule`, `subsubmodule`, etc. is not found.
        importlib.util.find_spec(module_name)

        # `importlib.util.find_spec()` will return None if `module` is not found.
        if importlib.util.find_spec(module_name) is not None:
            # Load the requested module
            module = importlib.import_module(module_name)

            # Check if the requested class is in the module
            if hasattr(module, class_name):
                model_cls = getattr(module, class_name)
            else:
                raise AttributeError(f"Model class {class_name} not found in module {module_name}")

        # Raise an exception if the base module of the spec is not found
        else:
            raise ModuleNotFoundError(f"Module {module_name} not found")

    # Exception raised when a submodule of the spec is not found
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(f"Module {module_name} not found") from exc

    return model_cls
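Tying the pieces together, again just a sketch rather than part of the commit: a runtime config can select either a registered built-in by model_name, or an external class by dotted path via model_cls. The external package path shown below is hypothetical.

# --- usage sketch (illustrative, not part of the commit) ---
from fibad import fetch_model_class

# Built-in model, looked up in MODEL_REGISTRY by class name.
config = {"train": {"model_name": "ExampleCNN"}}
model_cls = fetch_model_class(config)
model = model_cls()

# External model, imported from a dotted "module.submodule.ClassName" spec (hypothetical package).
config = {"train": {"model_cls": "my_package.models.MyAutoencoder"}}
# model_cls = fetch_model_class(config)  # raises ModuleNotFoundError if my_package is not installed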