Allow definition of Loss and Optimizer in config file. #132

Merged
merged 1 commit on Dec 10, 2024
12 changes: 12 additions & 0 deletions src/fibad/fibad_default_config.toml
@@ -61,6 +61,18 @@ name = "ExampleAutoencoder"
base_channel_size = 32
latent_dim = 64

[criterion]
# The name of the built-in criterion to use or the libpath to an external criterion
name = "torch.nn.CrossEntropyLoss"

[optimizer]
# The name of the built-in optimizer to use or the libpath to an external optimizer
name = "torch.optim.SGD"

# Default PyTorch optimizer parameters. The keys match the names of the parameters
lr = 0.01
momentum = 0.9

[train]
weights_filepath = "example_model.pth"
epochs = 10
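
For reference, a minimal sketch (not part of this PR) of what these default sections resolve to: the name key names the class, and every other key under [optimizer] is forwarded to that class's constructor, so the defaults above are equivalent to building the objects by hand.

import torch
import torch.nn as nn

# Stand-in model; in fibad the optimizer is built against the configured model's parameters.
model = nn.Linear(4, 2)

# Equivalent of the defaults above: lr and momentum are passed straight through to SGD.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
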
7 changes: 0 additions & 7 deletions src/fibad/models/example_cnn_classifier.py
@@ -7,7 +7,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa N812
import torch.optim as optim

from .model_registry import fibad_model

@@ -63,9 +62,3 @@ def train_step(self, batch):
loss.backward()
self.optimizer.step()
return {"loss": loss.item()}

def _criterion(self):
return nn.CrossEntropyLoss()

def _optimizer(self):
return optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
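
With these methods deleted, a model decorated with @fibad_model no longer has to define _criterion or _optimizer itself; the decorator (see model_registry.py below) fills in config-driven defaults. A rough sketch of the resulting pattern, with an assumed __init__ signature and assuming the model keeps its config on self.config (neither detail is shown in this diff):

import torch.nn as nn

from fibad.models.model_registry import fibad_model  # import path inferred from the file layout

@fibad_model
class TinyClassifier(nn.Module):
    def __init__(self, config):  # signature assumed for illustration
        super().__init__()
        self.config = config
        self.fc = nn.Linear(8, 2)
        # No _criterion/_optimizer defined here; the decorator supplies the
        # defaults that read [criterion] and [optimizer] from the config.
        self.criterion = self._criterion()
        self.optimizer = self._optimizer()

    def forward(self, x):
        return self.fc(x)

    def train_step(self, batch):
        x, y = batch
        self.optimizer.zero_grad()
        loss = self.criterion(self.forward(x), y)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}
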
16 changes: 16 additions & 0 deletions src/fibad/models/model_registry.py
@@ -23,6 +23,20 @@
self.load_state_dict(state_dict, assign=True)


def _torch_criterion(self: nn.Module):
criterion_function_cls = get_or_load_class(self.config["criterion"])
arguments = dict(self.config["criterion"])
del arguments["name"]
return criterion_function_cls(**arguments)


def _torch_optimizer(self: nn.Module):
optimizer_cls = get_or_load_class(self.config["optimizer"])
arguments = dict(self.config["optimizer"])
del arguments["name"]
return optimizer_cls(self.parameters(), **arguments)


def fibad_model(cls):
"""Decorator to register a model with the model registry, and to add common interface functions

@@ -35,6 +49,8 @@
if issubclass(cls, nn.Module):
cls.save = _torch_save
cls.load = _torch_load
cls._criterion = _torch_criterion if not hasattr(cls, "_criterion") else cls._criterion
cls._optimizer = _torch_optimizer if not hasattr(cls, "_optimizer") else cls._optimizer

required_methods = ["train_step", "forward", "__init__"]
for name in required_methods:
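
A standalone sketch of what the new _torch_criterion and _torch_optimizer helpers do. The loader below is a simplified stand-in for fibad's get_or_load_class, written out so the snippet runs on its own: the name key is resolved to a class via its dotted import path, and the remaining keys become constructor arguments.

import importlib

import torch.nn as nn

def load_class(libpath: str) -> type:
    # Simplified stand-in for get_or_load_class: split "torch.optim.SGD"
    # into a module path and a class name, then import and look it up.
    module_path, class_name = libpath.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

config = {
    "criterion": {"name": "torch.nn.CrossEntropyLoss"},
    "optimizer": {"name": "torch.optim.SGD", "lr": 0.01, "momentum": 0.9},
}
model = nn.Linear(4, 2)  # stand-in model

criterion_args = {k: v for k, v in config["criterion"].items() if k != "name"}
criterion = load_class(config["criterion"]["name"])(**criterion_args)

optimizer_args = {k: v for k, v in config["optimizer"].items() if k != "name"}
optimizer = load_class(config["optimizer"]["name"])(model.parameters(), **optimizer_args)
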
4 changes: 2 additions & 2 deletions src/fibad/plugin_utils.py
@@ -1,7 +1,7 @@
import importlib


def get_or_load_class(config: dict, registry: dict) -> type:
def get_or_load_class(config: dict, registry: dict = None) -> type:
"""Given a configuration dictionary and a registry dictionary, attempt to return
the requested class either from the registry or by dynamically importing it.

@@ -28,7 +28,7 @@ def get_or_load_class(config: dict, registry: dict) -> type:
if "name" in config:
class_name = config["name"]

if class_name in registry:
if registry and class_name in registry:
returned_class = registry[class_name]
else:
returned_class = import_module_from_string(class_name)
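
A hedged usage sketch of the widened signature (the registry dict and its key below are hypothetical; only the config-dict call pattern comes from this diff). With no registry supplied, as in the new criterion and optimizer paths, the name falls through to a dynamic import; a registered name still short-circuits the import.

from fibad.plugin_utils import get_or_load_class  # import path inferred from the file layout

# Registry omitted (now allowed): "torch.optim.SGD" is imported dynamically.
sgd_cls = get_or_load_class({"name": "torch.optim.SGD"})

# Hypothetical registry lookup: a registered name is returned without importing anything.
registry = {"MyOptimizer": sgd_cls}
same_cls = get_or_load_class({"name": "MyOptimizer"}, registry)
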