BUG/CLN: Refactor model abstraction so we don't subclass LightningModel, to fix loss logging #737

Closed
Tracked by #614
NickleDave opened this issue Jan 22, 2024 · 2 comments

Comments

@NickleDave
Collaborator

When we subclass LightningModule, it seems to break logging (see #726).

I can at least get better logging (it still does some weird things) if I remove all the magic subclassing of vak.models.base.Model and instead define model families that each separately subclass LightningModule, e.g. FrameClassificationModel(lightning.LightningModule).

This can actually be fine for us: we define one class per family, and we refactor our current logic for converting definitions into models so that it instantiates the components of a model and then passes those components to the model-family class when instantiating it.
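
A minimal sketch of that split, in plain Python so it runs without Lightning installed. `LightningModuleStandIn` is a hypothetical placeholder for `lightning.LightningModule`, and `build_components`, `make_model`, `TweetyNetDefinition`, and the component classes are illustrative names, not vak's API. The family class holds no conversion logic; the conversion step only instantiates components and passes them in.

```python
from typing import Any


class LightningModuleStandIn:
    """Hypothetical placeholder for lightning.LightningModule, so the sketch runs without Lightning."""


class FrameClassificationModel(LightningModuleStandIn):
    """One class per model family; it receives components that are already instantiated."""

    def __init__(self, network: Any, optimizer: Any, loss: Any, metrics: dict[str, Any]):
        super().__init__()
        self.network = network
        self.optimizer = optimizer
        self.loss = loss
        self.metrics = metrics


# Hypothetical component classes standing in for torch modules, optimizers and metrics.
class Network:
    def __init__(self, hidden_size: int = 32):
        self.hidden_size = hidden_size


class Loss:
    pass


class Optimizer:
    # a real torch optimizer would take network.parameters() instead of the network
    def __init__(self, network: Any, lr: float = 1e-3):
        self.network = network
        self.lr = lr


class Accuracy:
    pass


class TweetyNetDefinition:
    """A model definition: class attributes that name the component classes."""

    network = Network
    loss = Loss
    optimizer = Optimizer
    metrics = {"acc": Accuracy}


def build_components(definition: type, config: dict) -> tuple:
    """Instantiate each component named in the definition, applying per-component kwargs."""
    network = definition.network(**config.get("network", {}))
    loss = definition.loss(**config.get("loss", {}))
    optimizer = definition.optimizer(network, **config.get("optimizer", {}))
    metrics = {name: metric_cls() for name, metric_cls in definition.metrics.items()}
    return network, optimizer, loss, metrics


def make_model(definition: type, family: type, config: dict) -> LightningModuleStandIn:
    """Build the components first, then pass them to the model-family class."""
    return family(*build_components(definition, config))


tweetynet = make_model(TweetyNetDefinition, FrameClassificationModel, {"optimizer": {"lr": 1e-4}})
```

Because the family class only ever sees instances, the definition never has to subclass anything Lightning-related, which is the property this issue is after.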

@NickleDave NickleDave changed the title CLN: Refactor model abstraction so we don't subclass LightningModel, to fix loss logging BUG/CLN: Refactor model abstraction so we don't subclass LightningModel, to fix loss logging Jan 22, 2024
@NickleDave
Collaborator Author

I think for now we can fix this by changing vak.models.base.Model so that it does not subclass LightningModule, and instead gives us what is basically a dataclass instance with the attributes we want, which we can pass into model-family-specific LightningModules.

My read of this code now is that I got in over my head with metaprogramming and "came in ass-first thinking [I] invented sliced bread" (to paraphrase Andy Partridge). I see no obvious reason why I can't un-spaghetti this code so that it is just a dataclass, e.g. we get back a ModelDefinition instance with the attributes we want for each model.
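
One possible reading of "just a dataclass", as a sketch: the definition becomes a frozen dataclass that holds component classes plus default keyword arguments, and a family class is built from the instances it produces. `ModelDefinition.instantiate`, the `defaults` field, and all the stand-in classes are hypothetical, not vak's current API; `FrameClassificationModel` here is a placeholder for a family class that would subclass `lightning.LightningModule`.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass(frozen=True)
class ModelDefinition:
    """Hypothetical plain-data model definition: component classes plus default kwargs."""

    network: type
    loss: type
    optimizer: type
    metrics: dict[str, type] = field(default_factory=dict)
    defaults: dict[str, dict[str, Any]] = field(default_factory=dict)

    def instantiate(self, config: Optional[dict[str, dict[str, Any]]] = None) -> dict[str, Any]:
        """Build the components; per-component config overrides the defaults."""
        config = config or {}

        def kwargs(name: str) -> dict[str, Any]:
            # merge this component's defaults with any overrides from config
            return {**self.defaults.get(name, {}), **config.get(name, {})}

        network = self.network(**kwargs("network"))
        return {
            "network": network,
            "loss": self.loss(**kwargs("loss")),
            # a real torch optimizer would take network.parameters() instead
            "optimizer": self.optimizer(network, **kwargs("optimizer")),
            "metrics": {name: metric_cls() for name, metric_cls in self.metrics.items()},
        }


# Hypothetical stand-ins for torch components and for a family LightningModule.
class Network:
    def __init__(self, hidden_size: int = 32):
        self.hidden_size = hidden_size


class Loss:
    pass


class Optimizer:
    def __init__(self, network: Any, lr: float = 1e-3):
        self.network = network
        self.lr = lr


class Accuracy:
    pass


class FrameClassificationModel:
    """Stands in for a family class that would subclass lightning.LightningModule."""

    def __init__(self, network: Any, loss: Any, optimizer: Any, metrics: dict[str, Any]):
        self.network = network
        self.loss = loss
        self.optimizer = optimizer
        self.metrics = metrics


tweetynet = ModelDefinition(
    network=Network,
    loss=Loss,
    optimizer=Optimizer,
    metrics={"acc": Accuracy},
    defaults={"network": {"hidden_size": 64}},
)
model = FrameClassificationModel(**tweetynet.instantiate({"optimizer": {"lr": 1e-4}}))
```

Since the definition is plain, immutable data, any family class that accepts these four keyword arguments can consume it.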

@NickleDave
Collaborator Author

I think a possible fix is to do something like this

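import functools

import lightning

# definition.validate, is_valid_family, and from_attributes are vak helpers
# that this sketch assumes are already imported


def model(modeldef: type, family: type[lightning.LightningModule]):
    """Decorator that creates a new class
    representing a model that belongs to a family of models,
    given a class representing the definition of the model
    and the class representing the family.
    """
    definition.validate(modeldef)
    is_valid_family(family)

    class Model:
        # class variables; the names must differ from the function parameters,
        # because `family = family` in a class body looks up `family` in
        # globals (not the enclosing function) and raises NameError
        model_definition = modeldef
        model_family = family

        def from_config_dict(self, config_dict: dict) -> lightning.LightningModule:
            # instantiate the components from the definition
            # (config_dict is not used yet in this shim)
            network, optimizer, loss, metrics = from_attributes(self.model_definition)
            # then pass them into the model-family LightningModule
            return self.model_family(
                network,
                optimizer,
                loss,
                metrics,
            )

    # copy __name__, __doc__, etc. from the definition; updated=() because a
    # class __dict__ is a read-only mappingproxy that has no .update()
    Model = functools.update_wrapper(Model, modeldef, updated=())
    return Model()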

It's weird that we return an instance of a class we just defined, one that has a single method, from_config_dict... but this seems like the easiest shim I can put in place right now.
