diff --git a/src/roman_datamodels/datamodels/__init__.py b/src/roman_datamodels/datamodels/__init__.py index d6ed51e5..643496a5 100644 --- a/src/roman_datamodels/datamodels/__init__.py +++ b/src/roman_datamodels/datamodels/__init__.py @@ -1,5 +1,11 @@ +""" +This module contains all the DataModel classes and supporting utilities used by the pipeline. + The DataModel classes are generated dynamically at import time from metadata contained + within the RAD schemas. +""" from ._core import * # noqa: F403 from ._datamodels import * # noqa: F403 +from ._mixins import * # noqa: F403 # rename rdm_open to open to match the current roman_datamodels API from ._utils import rdm_open as open # noqa: F403, F401 diff --git a/src/roman_datamodels/datamodels/_core.py b/src/roman_datamodels/datamodels/_core.py index 04ae3ea6..6bed5192 100644 --- a/src/roman_datamodels/datamodels/_core.py +++ b/src/roman_datamodels/datamodels/_core.py @@ -41,10 +41,6 @@ def __init_subclass__(cls, **kwargs): """Register each subclass in the MODEL_REGISTRY""" super().__init_subclass__(**kwargs) - # Allow for sub-registry classes to be defined - if cls.__name__.startswith("_"): - return - # Check the node_type is a tagged object node if not issubclass(cls._node_type, stnode.TaggedObjectNode): raise ValueError("Subclass must be a TaggedObjectNode subclass") diff --git a/src/roman_datamodels/datamodels/_datamodels.py b/src/roman_datamodels/datamodels/_datamodels.py index 047558c5..61aeb22a 100644 --- a/src/roman_datamodels/datamodels/_datamodels.py +++ b/src/roman_datamodels/datamodels/_datamodels.py @@ -1,194 +1,29 @@ """ -This module provides all the specific datamodels used by the Roman pipeline. - These models are what will be read and written by the pipeline to ASDF files. - Note that we require each model to specify a _node_type, which corresponds to - the top-level STNode type that the datamodel wraps. This STNode type is derived - from the schema manifest defined by RAD. +This module dynamically creates all the DataModels from metadata in RAD. + - These models are what will be read and written by the pipeline to ASDF files. + - Note that the DataModels which require additional functionality to the base + DataModel class will have a Mixin defined. This Mixin contains all the additional + functionality and is dynamically added to the DataModel class. + - Unfortunately, this is a dynamic process which occurs at first import time + because roman_datamodels cannot predict what DataModels will be in the version + of RAD used by the user. """ -import numpy as np - -from roman_datamodels import stnode - -from ._core import DataModel +from ._factory import datamodel_factory, datamodel_names __all__ = [] -class _DataModel(DataModel): +def _factory(node_type, datamodel_name): """ - Exists only to populate the __all__ for this file automatically - This is something which is easily missed, but is important for the automatic - documentation generation to work properly. + Wrap the __all__ append and class creation in a function to avoid the linter + getting upset """ - - def __init_subclass__(cls, **kwargs): - """Register each subclass in the __all__ for this module""" - super().__init_subclass__(**kwargs) - if cls.__name__ in __all__: - raise ValueError(f"Duplicate model type {cls.__name__}") - - __all__.append(cls.__name__) - - -class MosaicModel(_DataModel): - _node_type = stnode.WfiMosaic - - -class ImageModel(_DataModel): - _node_type = stnode.WfiImage - - -class ScienceRawModel(_DataModel): - _node_type = stnode.WfiScienceRaw - - -class MsosStackModel(_DataModel): - _node_type = stnode.MsosStack - - -class RampModel(_DataModel): - _node_type = stnode.Ramp - - @classmethod - def from_science_raw(cls, model): - """ - Construct a RampModel from a ScienceRawModel - - Parameters - ---------- - model : ScienceRawModel or RampModel - The input science raw model (a RampModel will also work) - """ - - if isinstance(model, cls): - return model - - if isinstance(model, ScienceRawModel): - from roman_datamodels.maker_utils import mk_ramp - - instance = mk_ramp(shape=model.shape) - - # Copy input_model contents into RampModel - for key in model: - # If a dictionary (like meta), overwrite entries (but keep - # required dummy entries that may not be in input_model) - if isinstance(instance[key], dict): - instance[key].update(getattr(model, key)) - elif isinstance(instance[key], np.ndarray): - # Cast input ndarray as RampModel dtype - instance[key] = getattr(model, key).astype(instance[key].dtype) - else: - instance[key] = getattr(model, key) - - return cls(instance) - - raise ValueError("Input model must be a ScienceRawModel or RampModel") - - -class RampFitOutputModel(_DataModel): - _node_type = stnode.RampFitOutput - - -class AssociationsModel(_DataModel): - # Need an init to allow instantiation from a JSON file - _node_type = stnode.Associations - - @classmethod - def is_association(cls, asn_data): - """ - Test if an object is an association by checking for required fields - - Parameters - ---------- - asn_data : - The data to be tested. - """ - return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data - - -class GuidewindowModel(_DataModel): - _node_type = stnode.Guidewindow - - -class FlatRefModel(_DataModel): - _node_type = stnode.FlatRef - - -class DarkRefModel(_DataModel): - _node_type = stnode.DarkRef - - -class DistortionRefModel(_DataModel): - _node_type = stnode.DistortionRef - - -class GainRefModel(_DataModel): - _node_type = stnode.GainRef - - -class IpcRefModel(_DataModel): - _node_type = stnode.IpcRef - - -class LinearityRefModel(_DataModel): - _node_type = stnode.LinearityRef - - def get_primary_array_name(self): - """ - Returns the name "primary" array for this model, which - controls the size of other arrays that are implicitly created. - This is intended to be overridden in the subclasses if the - primary array's name is not "data". - """ - return "coeffs" - - -class InverseLinearityRefModel(_DataModel): - _node_type = stnode.InverseLinearityRef - - def get_primary_array_name(self): - """ - Returns the name "primary" array for this model, which - controls the size of other arrays that are implicitly created. - This is intended to be overridden in the subclasses if the - primary array's name is not "data". - """ - return "coeffs" - - -class MaskRefModel(_DataModel): - _node_type = stnode.MaskRef - - def get_primary_array_name(self): - """ - Returns the name "primary" array for this model, which - controls the size of other arrays that are implicitly created. - This is intended to be overridden in the subclasses if the - primary array's name is not "data". - """ - return "dq" - - -class PixelareaRefModel(_DataModel): - _node_type = stnode.PixelareaRef - - -class ReadnoiseRefModel(_DataModel): - _node_type = stnode.ReadnoiseRef - - -class SuperbiasRefModel(_DataModel): - _node_type = stnode.SuperbiasRef - - -class SaturationRefModel(_DataModel): - _node_type = stnode.SaturationRef - - -class WfiImgPhotomRefModel(_DataModel): - _node_type = stnode.WfiImgPhotomRef + globals()[datamodel_name] = datamodel_factory(node_type, datamodel_name) # Add to namespace of module + __all__.append(datamodel_name) # add to __all__ so it's imported with `from . import *` -class RefpixRefModel(_DataModel): - _node_type = stnode.RefpixRef +# Main dynamic class creation loop +# Locates each not_type/datamodel_name pair and creates a DataModel class for it +for node_type, datamodel_name in datamodel_names(): + _factory(node_type, datamodel_name) diff --git a/src/roman_datamodels/datamodels/_factory.py b/src/roman_datamodels/datamodels/_factory.py new file mode 100644 index 00000000..acb64133 --- /dev/null +++ b/src/roman_datamodels/datamodels/_factory.py @@ -0,0 +1,55 @@ +""" +Factories for creating all the DataModel classes from RAD + These are used to dynamically create all the DataModels which are actually used. +""" +from roman_datamodels import stnode + +from . import _mixins +from ._core import DataModel + + +def datamodel_names(): + """ + A generator to grab all the datamodel names and base STNode classes from RAD + + Yields + ------ + node_type, datamodel_name + """ + for tag in stnode._stnode.DATAMODELS_MANIFEST["tags"]: + schema = stnode._factories.load_schema_from_uri(tag["schema_uri"]) + if "datamodel_name" in schema: + yield stnode.OBJECT_NODE_CLASSES_BY_TAG[tag["tag_uri"]], schema["datamodel_name"] + + +def datamodel_factory(node_type, datamodel_name): + """ + The factory for dynamically creating a DataModel class from a node_type and datamodel_name + Note: For DataModels requiring additional functionality, a Mixin must be added to ._mixins.py + with the name Mixin. + + Parameters + ---------- + node_type : type + The base STNode class to use as the base for the DataModel + datamodel_name : str + The name of the DataModel to create + + Returns + ------- + A DataModel object class + """ + if hasattr(_mixins, mixin := f"{datamodel_name}Mixin"): + class_type = (getattr(_mixins, mixin), DataModel) + else: + class_type = (DataModel,) + + return type( + datamodel_name, + class_type, + { + "_node_type": node_type, + "__module__": "roman_datamodels.datamodels", + "__doc__": f"Roman {datamodel_name} model", + }, + ) diff --git a/src/roman_datamodels/datamodels/_mixins.py b/src/roman_datamodels/datamodels/_mixins.py new file mode 100644 index 00000000..f090fb56 --- /dev/null +++ b/src/roman_datamodels/datamodels/_mixins.py @@ -0,0 +1,105 @@ +""" +This module provides all the Mixin classes which will be dynamically mixed into + the DataModel classes at import time. + The name of the mixin must be of the form Mixin in order for + this to work properly. +""" +import numpy as np + + +class RampModelMixin: + """ + Mixin class for dynamically generated RampModel + """ + + @classmethod + def from_science_raw(cls, model): + """ + Construct a RampModel from a ScienceRawModel + + Parameters + ---------- + model : ScienceRawModel or RampModel + The input science raw model (a RampModel will also work) + """ + from roman_datamodels.datamodels import ScienceRawModel + + if isinstance(model, cls): + return model + + if isinstance(model, ScienceRawModel): + from roman_datamodels.maker_utils import mk_ramp + + instance = mk_ramp(shape=model.shape) + + # Copy input_model contents into RampModel + for key in model: + # If a dictionary (like meta), overwrite entries (but keep + # required dummy entries that may not be in input_model) + if isinstance(instance[key], dict): + instance[key].update(getattr(model, key)) + elif isinstance(instance[key], np.ndarray): + # Cast input ndarray as RampModel dtype + instance[key] = getattr(model, key).astype(instance[key].dtype) + else: + instance[key] = getattr(model, key) + + return cls(instance) + + raise ValueError("Input model must be a ScienceRawModel or RampModel") + + +class AssociationsModelMixin: + """ + Mixin class for dynamically generated AssociationsModel + """ + + @classmethod + def is_association(cls, asn_data): + """ + Test if an object is an association by checking for required fields + + Parameters + ---------- + asn_data : + The data to be tested. + """ + return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data + + +class LinearityRefModelMixin: + """ + Mixin class for dynamically generated LinearityRefModel + """ + + def get_primary_array_name(self): + """ + Returns the name "primary" array for this model, which + controls the size of other arrays that are implicitly created. + This is intended to be overridden in the subclasses if the + primary array's name is not "data". + """ + return "coeffs" + + +class InverseLinearityRefModelMixin(LinearityRefModelMixin): + """ + Mixin class for dynamically generated InverseLinearityRefModel + """ + + pass + + +class MaskRefModelMixin: + """ + Mixin class for dynamically generated MaskRefModel + """ + + def get_primary_array_name(self): + """ + Returns the name "primary" array for this model, which + controls the size of other arrays that are implicitly created. + This is intended to be overridden in the subclasses if the + primary array's name is not "data". + """ + return "dq" diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py index 1b2c522b..1b448d2a 100644 --- a/src/roman_datamodels/stnode/__init__.py +++ b/src/roman_datamodels/stnode/__init__.py @@ -5,5 +5,6 @@ from ._converters import * # noqa: F403 from ._mixins import * # noqa: F403 from ._node import * # noqa: F403 +from ._registry import * # noqa: F403 from ._stnode import * # noqa: F403 from ._tagged import * # noqa: F403 diff --git a/tests/test_models.py b/tests/test_models.py index 1c858f25..6af93b6e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -387,6 +387,11 @@ def test_make_linearity(): assert linearity_model.validate() is None +def test_linearity_ref_mixin(): + linearity = utils.mk_datamodel(datamodels.LinearityRefModel, shape=(2, 8, 8)) + assert linearity.get_primary_array_name() == "coeffs" + + # InverseLinearity tests def test_make_inverse_linearity(): inverselinearity = utils.mk_inverse_linearity(shape=(2, 8, 8)) @@ -399,6 +404,11 @@ def test_make_inverse_linearity(): assert inverselinearity_model.validate() is None +def test_inverse_linearity_ref_mixin(): + inverse_linearity = utils.mk_datamodel(datamodels.InverseLinearityRefModel, shape=(2, 8, 8)) + assert inverse_linearity.get_primary_array_name() == "coeffs" + + # Mask tests def test_make_mask(): mask = utils.mk_mask(shape=(8, 8)) @@ -410,6 +420,11 @@ def test_make_mask(): assert mask_model.validate() is None +def test_mask_ref_mixin(): + mask = utils.mk_datamodel(datamodels.MaskRefModel, shape=(8, 8)) + assert mask.get_primary_array_name() == "dq" + + # Pixel Area tests def test_make_pixelarea(): pixearea = utils.mk_pixelarea(shape=(8, 8))