Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of Concept: Dynamically Generated DataModels #221

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/roman_datamodels/datamodels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""
This module contains all the DataModel classes and supporting utilities used by the pipeline.
The DataModel classes are generated dynamically at import time from metadata contained
within the RAD schemas.
"""
from ._core import * # noqa: F403
from ._datamodels import * # noqa: F403
from ._mixins import * # noqa: F403

# rename rdm_open to open to match the current roman_datamodels API
from ._utils import rdm_open as open # noqa: F403, F401
4 changes: 0 additions & 4 deletions src/roman_datamodels/datamodels/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ def __init_subclass__(cls, **kwargs):
"""Register each subclass in the MODEL_REGISTRY"""
super().__init_subclass__(**kwargs)

# Allow for sub-registry classes to be defined
if cls.__name__.startswith("_"):
return

# Check the node_type is a tagged object node
if not issubclass(cls._node_type, stnode.TaggedObjectNode):
raise ValueError("Subclass must be a TaggedObjectNode subclass")
Expand Down
201 changes: 18 additions & 183 deletions src/roman_datamodels/datamodels/_datamodels.py
Original file line number Diff line number Diff line change
@@ -1,194 +1,29 @@
"""
This module provides all the specific datamodels used by the Roman pipeline.
These models are what will be read and written by the pipeline to ASDF files.
Note that we require each model to specify a _node_type, which corresponds to
the top-level STNode type that the datamodel wraps. This STNode type is derived
from the schema manifest defined by RAD.
This module dynamically creates all the DataModels from metadata in RAD.
- These models are what will be read and written by the pipeline to ASDF files.
- Note that the DataModels which require additional functionality to the base
DataModel class will have a Mixin defined. This Mixin contains all the additional
functionality and is dynamically added to the DataModel class.
- Unfortunately, this is a dynamic process which occurs at first import time
because roman_datamodels cannot predict what DataModels will be in the version
of RAD used by the user.
"""

import numpy as np

from roman_datamodels import stnode

from ._core import DataModel
from ._factory import datamodel_factory, datamodel_names

__all__ = []


class _DataModel(DataModel):
def _factory(node_type, datamodel_name):
"""
Exists only to populate the __all__ for this file automatically
This is something which is easily missed, but is important for the automatic
documentation generation to work properly.
Wrap the __all__ append and class creation in a function to avoid the linter
getting upset
"""

def __init_subclass__(cls, **kwargs):
"""Register each subclass in the __all__ for this module"""
super().__init_subclass__(**kwargs)
if cls.__name__ in __all__:
raise ValueError(f"Duplicate model type {cls.__name__}")

__all__.append(cls.__name__)


class MosaicModel(_DataModel):
_node_type = stnode.WfiMosaic


class ImageModel(_DataModel):
_node_type = stnode.WfiImage


class ScienceRawModel(_DataModel):
_node_type = stnode.WfiScienceRaw


class MsosStackModel(_DataModel):
_node_type = stnode.MsosStack


class RampModel(_DataModel):
_node_type = stnode.Ramp

@classmethod
def from_science_raw(cls, model):
"""
Construct a RampModel from a ScienceRawModel

Parameters
----------
model : ScienceRawModel or RampModel
The input science raw model (a RampModel will also work)
"""

if isinstance(model, cls):
return model

if isinstance(model, ScienceRawModel):
from roman_datamodels.maker_utils import mk_ramp

instance = mk_ramp(shape=model.shape)

# Copy input_model contents into RampModel
for key in model:
# If a dictionary (like meta), overwrite entries (but keep
# required dummy entries that may not be in input_model)
if isinstance(instance[key], dict):
instance[key].update(getattr(model, key))
elif isinstance(instance[key], np.ndarray):
# Cast input ndarray as RampModel dtype
instance[key] = getattr(model, key).astype(instance[key].dtype)
else:
instance[key] = getattr(model, key)

return cls(instance)

raise ValueError("Input model must be a ScienceRawModel or RampModel")


class RampFitOutputModel(_DataModel):
_node_type = stnode.RampFitOutput


class AssociationsModel(_DataModel):
# Need an init to allow instantiation from a JSON file
_node_type = stnode.Associations

@classmethod
def is_association(cls, asn_data):
"""
Test if an object is an association by checking for required fields

Parameters
----------
asn_data :
The data to be tested.
"""
return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data


class GuidewindowModel(_DataModel):
_node_type = stnode.Guidewindow


class FlatRefModel(_DataModel):
_node_type = stnode.FlatRef


class DarkRefModel(_DataModel):
_node_type = stnode.DarkRef


class DistortionRefModel(_DataModel):
_node_type = stnode.DistortionRef


class GainRefModel(_DataModel):
_node_type = stnode.GainRef


class IpcRefModel(_DataModel):
_node_type = stnode.IpcRef


class LinearityRefModel(_DataModel):
_node_type = stnode.LinearityRef

def get_primary_array_name(self):
"""
Returns the name "primary" array for this model, which
controls the size of other arrays that are implicitly created.
This is intended to be overridden in the subclasses if the
primary array's name is not "data".
"""
return "coeffs"


class InverseLinearityRefModel(_DataModel):
_node_type = stnode.InverseLinearityRef

def get_primary_array_name(self):
"""
Returns the name "primary" array for this model, which
controls the size of other arrays that are implicitly created.
This is intended to be overridden in the subclasses if the
primary array's name is not "data".
"""
return "coeffs"


class MaskRefModel(_DataModel):
_node_type = stnode.MaskRef

def get_primary_array_name(self):
"""
Returns the name "primary" array for this model, which
controls the size of other arrays that are implicitly created.
This is intended to be overridden in the subclasses if the
primary array's name is not "data".
"""
return "dq"


class PixelareaRefModel(_DataModel):
_node_type = stnode.PixelareaRef


class ReadnoiseRefModel(_DataModel):
_node_type = stnode.ReadnoiseRef


class SuperbiasRefModel(_DataModel):
_node_type = stnode.SuperbiasRef


class SaturationRefModel(_DataModel):
_node_type = stnode.SaturationRef


class WfiImgPhotomRefModel(_DataModel):
_node_type = stnode.WfiImgPhotomRef
globals()[datamodel_name] = datamodel_factory(node_type, datamodel_name) # Add to namespace of module
__all__.append(datamodel_name) # add to __all__ so it's imported with `from . import *`


class RefpixRefModel(_DataModel):
_node_type = stnode.RefpixRef
# Main dynamic class creation loop
# Locates each node_type/datamodel_name pair and creates a DataModel class for it
for node_type, datamodel_name in datamodel_names():
_factory(node_type, datamodel_name)
55 changes: 55 additions & 0 deletions src/roman_datamodels/datamodels/_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""
Factories for creating all the DataModel classes from RAD
These are used to dynamically create all the DataModels which are actually used.
"""
from roman_datamodels import stnode

from . import _mixins
from ._core import DataModel


def datamodel_names():
    """
    Generate the (node class, datamodel name) pairs found in the RAD manifest.

    Every tag listed in the datamodels manifest is visited and the schema it
    points to is loaded. Schemas without a ``datamodel_name`` entry are
    skipped.

    Yields
    ------
    node_type, datamodel_name
        The class registered in ``stnode.OBJECT_NODE_CLASSES_BY_TAG`` under
        the tag's URI, and the ``datamodel_name`` value read from its schema.
    """
    manifest_tags = stnode._stnode.DATAMODELS_MANIFEST["tags"]
    for entry in manifest_tags:
        schema = stnode._factories.load_schema_from_uri(entry["schema_uri"])
        if "datamodel_name" not in schema:
            # This tag does not describe a top-level datamodel.
            continue
        yield stnode.OBJECT_NODE_CLASSES_BY_TAG[entry["tag_uri"]], schema["datamodel_name"]


def datamodel_factory(node_type, datamodel_name):
    """
    Build a DataModel subclass for the given node type and name.

    When ``_mixins`` has an attribute called ``<datamodel_name>Mixin``, that
    mixin is listed ahead of ``DataModel`` in the bases of the new class.
    Otherwise the new class derives from ``DataModel`` alone.

    Parameters
    ----------
    node_type : type
        The STNode class stored on the new class as ``_node_type``
    datamodel_name : str
        The name given to the new class

    Returns
    -------
    type
        The newly created class named ``datamodel_name``
    """
    mixin_name = f"{datamodel_name}Mixin"
    try:
        bases = (getattr(_mixins, mixin_name), DataModel)
    except AttributeError:
        # No dedicated mixin exists for this datamodel.
        bases = (DataModel,)

    namespace = {
        "_node_type": node_type,
        "__module__": "roman_datamodels.datamodels",
        "__doc__": f"Roman {datamodel_name} model",
    }
    return type(datamodel_name, bases, namespace)
105 changes: 105 additions & 0 deletions src/roman_datamodels/datamodels/_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
This module provides all the Mixin classes which will be dynamically mixed into
the DataModel classes at import time.
The name of the mixin must be of the form <DataModelName>Mixin in order for
this to work properly.
"""
import numpy as np


class RampModelMixin:
    """
    Extra behaviour mixed into the dynamically generated RampModel
    """

    @classmethod
    def from_science_raw(cls, model):
        """
        Build a RampModel out of a ScienceRawModel

        Parameters
        ----------
        model : ScienceRawModel or RampModel
            The model to convert. An instance of ``cls`` is returned as-is.

        Returns
        -------
        ``model`` itself when it is already an instance of ``cls``, otherwise
        ``cls`` wrapping a ramp node populated from ``model``.

        Raises
        ------
        ValueError
            If ``model`` is neither an instance of ``cls`` nor a ScienceRawModel.
        """
        # Imported at call time rather than module level
        # (presumably to avoid a circular import -- TODO confirm).
        from roman_datamodels.datamodels import ScienceRawModel

        if isinstance(model, cls):
            return model

        if not isinstance(model, ScienceRawModel):
            raise ValueError("Input model must be a ScienceRawModel or RampModel")

        from roman_datamodels.maker_utils import mk_ramp

        ramp = mk_ramp(shape=model.shape)

        # Transfer every entry of the input model onto the fresh ramp node.
        for name in model:
            if isinstance(ramp[name], dict):
                # Merge instead of replace, so dummy entries that the input
                # does not provide are kept.
                ramp[name].update(getattr(model, name))
            elif isinstance(ramp[name], np.ndarray):
                # Cast the incoming array to the dtype the ramp node uses.
                ramp[name] = getattr(model, name).astype(ramp[name].dtype)
            else:
                ramp[name] = getattr(model, name)

        return cls(ramp)


class AssociationsModelMixin:
    """
    Extra behaviour mixed into the dynamically generated AssociationsModel
    """

    @classmethod
    def is_association(cls, asn_data):
        """
        Report whether an object looks like an association.

        Parameters
        ----------
        asn_data :
            The object to inspect.

        Returns
        -------
        bool
            True only when ``asn_data`` is a dict containing both the
            "asn_id" and "asn_pool" keys.
        """
        if not isinstance(asn_data, dict):
            return False
        return all(field in asn_data for field in ("asn_id", "asn_pool"))


class LinearityRefModelMixin:
    """
    Mixin class for dynamically generated LinearityRefModel
    """

    def get_primary_array_name(self):
        """
        Return the name of the "primary" array for this model, which
        controls the size of other arrays that are implicitly created.

        This is the override for this model: its primary array is named
        "coeffs" rather than "data".
        """
        return "coeffs"


class InverseLinearityRefModelMixin(LinearityRefModelMixin):
    """
    Mixin for the dynamically generated InverseLinearityRefModel.

    Defines nothing of its own; all behaviour is inherited from
    LinearityRefModelMixin.
    """


class MaskRefModelMixin:
    """
    Mixin class for dynamically generated MaskRefModel
    """

    def get_primary_array_name(self):
        """
        Return the name of the "primary" array for this model, which
        controls the size of other arrays that are implicitly created.

        This is the override for this model: its primary array is named
        "dq" rather than "data".
        """
        return "dq"
Loading