Proof of Concept: Dynamically Generated DataModels #221
Closed: WilliamJamieson wants to merge 5 commits into spacetelescope:main from WilliamJamieson:feature/dynamic_datamodels
Commits (5), all by WilliamJamieson:
bd8e32c Move into a dynamically generated Datamodel API
d4baf35 Move factory methods into their own module
30d1f53 Add documentation to the code for dynamic datamodels
25dfe47 Add tests for get_primary_array_name
5ca92e1 Fix mixin ordering for datamodels
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,11 @@ | ||
""" | ||
This module contains all the DataModel classes and supporting utilities used by the pipeline. | ||
The DataModel classes are generated dynamically at import time from metadata contained | ||
within the RAD schemas. | ||
""" | ||
from ._core import * # noqa: F403 | ||
from ._datamodels import * # noqa: F403 | ||
from ._mixins import * # noqa: F403 | ||
|
||
# rename rdm_open to open to match the current roman_datamodels API | ||
from ._utils import rdm_open as open # noqa: F403, F401 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,194 +1,29 @@ | ||
""" | ||
This module provides all the specific datamodels used by the Roman pipeline. | ||
These models are what will be read and written by the pipeline to ASDF files. | ||
Note that we require each model to specify a _node_type, which corresponds to | ||
the top-level STNode type that the datamodel wraps. This STNode type is derived | ||
from the schema manifest defined by RAD. | ||
This module dynamically creates all the DataModels from metadata in RAD. | ||
- These models are what will be read and written by the pipeline to ASDF files. | ||
- Note that the DataModels which require additional functionality to the base | ||
DataModel class will have a Mixin defined. This Mixin contains all the additional | ||
functionality and is dynamically added to the DataModel class. | ||
- Unfortunately, this is a dynamic process which occurs at first import time | ||
because roman_datamodels cannot predict what DataModels will be in the version | ||
of RAD used by the user. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from roman_datamodels import stnode | ||
|
||
from ._core import DataModel | ||
from ._factory import datamodel_factory, datamodel_names | ||
|
||
__all__ = [] | ||
|
||
|
||
class _DataModel(DataModel): | ||
def _factory(node_type, datamodel_name): | ||
""" | ||
Exists only to populate the __all__ for this file automatically | ||
This is something which is easily missed, but is important for the automatic | ||
documentation generation to work properly. | ||
Wrap the __all__ append and class creation in a function to avoid the linter | ||
getting upset | ||
""" | ||
|
||
def __init_subclass__(cls, **kwargs): | ||
"""Register each subclass in the __all__ for this module""" | ||
super().__init_subclass__(**kwargs) | ||
if cls.__name__ in __all__: | ||
raise ValueError(f"Duplicate model type {cls.__name__}") | ||
|
||
__all__.append(cls.__name__) | ||
|
||
|
||
class MosaicModel(_DataModel): | ||
_node_type = stnode.WfiMosaic | ||
|
||
|
||
class ImageModel(_DataModel): | ||
_node_type = stnode.WfiImage | ||
|
||
|
||
class ScienceRawModel(_DataModel): | ||
_node_type = stnode.WfiScienceRaw | ||
|
||
|
||
class MsosStackModel(_DataModel): | ||
_node_type = stnode.MsosStack | ||
|
||
|
||
class RampModel(_DataModel): | ||
_node_type = stnode.Ramp | ||
|
||
@classmethod | ||
def from_science_raw(cls, model): | ||
""" | ||
Construct a RampModel from a ScienceRawModel | ||
|
||
Parameters | ||
---------- | ||
model : ScienceRawModel or RampModel | ||
The input science raw model (a RampModel will also work) | ||
""" | ||
|
||
if isinstance(model, cls): | ||
return model | ||
|
||
if isinstance(model, ScienceRawModel): | ||
from roman_datamodels.maker_utils import mk_ramp | ||
|
||
instance = mk_ramp(shape=model.shape) | ||
|
||
# Copy input_model contents into RampModel | ||
for key in model: | ||
# If a dictionary (like meta), overwrite entries (but keep | ||
# required dummy entries that may not be in input_model) | ||
if isinstance(instance[key], dict): | ||
instance[key].update(getattr(model, key)) | ||
elif isinstance(instance[key], np.ndarray): | ||
# Cast input ndarray as RampModel dtype | ||
instance[key] = getattr(model, key).astype(instance[key].dtype) | ||
else: | ||
instance[key] = getattr(model, key) | ||
|
||
return cls(instance) | ||
|
||
raise ValueError("Input model must be a ScienceRawModel or RampModel") | ||
|
||
|
||
class RampFitOutputModel(_DataModel): | ||
_node_type = stnode.RampFitOutput | ||
|
||
|
||
class AssociationsModel(_DataModel): | ||
# Need an init to allow instantiation from a JSON file | ||
_node_type = stnode.Associations | ||
|
||
@classmethod | ||
def is_association(cls, asn_data): | ||
""" | ||
Test if an object is an association by checking for required fields | ||
|
||
Parameters | ||
---------- | ||
asn_data : | ||
The data to be tested. | ||
""" | ||
return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data | ||
|
||
|
||
class GuidewindowModel(_DataModel): | ||
_node_type = stnode.Guidewindow | ||
|
||
|
||
class FlatRefModel(_DataModel): | ||
_node_type = stnode.FlatRef | ||
|
||
|
||
class DarkRefModel(_DataModel): | ||
_node_type = stnode.DarkRef | ||
|
||
|
||
class DistortionRefModel(_DataModel): | ||
_node_type = stnode.DistortionRef | ||
|
||
|
||
class GainRefModel(_DataModel): | ||
_node_type = stnode.GainRef | ||
|
||
|
||
class IpcRefModel(_DataModel): | ||
_node_type = stnode.IpcRef | ||
|
||
|
||
class LinearityRefModel(_DataModel): | ||
_node_type = stnode.LinearityRef | ||
|
||
def get_primary_array_name(self): | ||
""" | ||
Returns the name "primary" array for this model, which | ||
controls the size of other arrays that are implicitly created. | ||
This is intended to be overridden in the subclasses if the | ||
primary array's name is not "data". | ||
""" | ||
return "coeffs" | ||
|
||
|
||
class InverseLinearityRefModel(_DataModel): | ||
_node_type = stnode.InverseLinearityRef | ||
|
||
def get_primary_array_name(self): | ||
""" | ||
Returns the name "primary" array for this model, which | ||
controls the size of other arrays that are implicitly created. | ||
This is intended to be overridden in the subclasses if the | ||
primary array's name is not "data". | ||
""" | ||
return "coeffs" | ||
|
||
|
||
class MaskRefModel(_DataModel): | ||
_node_type = stnode.MaskRef | ||
|
||
def get_primary_array_name(self): | ||
""" | ||
Returns the name "primary" array for this model, which | ||
controls the size of other arrays that are implicitly created. | ||
This is intended to be overridden in the subclasses if the | ||
primary array's name is not "data". | ||
""" | ||
return "dq" | ||
|
||
|
||
class PixelareaRefModel(_DataModel): | ||
_node_type = stnode.PixelareaRef | ||
|
||
|
||
class ReadnoiseRefModel(_DataModel): | ||
_node_type = stnode.ReadnoiseRef | ||
|
||
|
||
class SuperbiasRefModel(_DataModel): | ||
_node_type = stnode.SuperbiasRef | ||
|
||
|
||
class SaturationRefModel(_DataModel): | ||
_node_type = stnode.SaturationRef | ||
|
||
|
||
class WfiImgPhotomRefModel(_DataModel): | ||
_node_type = stnode.WfiImgPhotomRef | ||
globals()[datamodel_name] = datamodel_factory(node_type, datamodel_name) # Add to namespace of module | ||
__all__.append(datamodel_name) # add to __all__ so it's imported with `from . import *` | ||
|
||
|
||
class RefpixRefModel(_DataModel): | ||
_node_type = stnode.RefpixRef | ||
# Main dynamic class creation loop | ||
# Locates each node_type/datamodel_name pair and creates a DataModel class for it | ||
for node_type, datamodel_name in datamodel_names(): | ||
_factory(node_type, datamodel_name) |
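For readers unfamiliar with the pattern, the registration loop above can be sketched in isolation. This is a minimal, self-contained illustration and not the code from this PR: `Base` and the hard-coded `(node_type, name)` pairs are hypothetical stand-ins for `DataModel` and the output of `datamodel_names()`, and the node types are plain strings rather than STNode classes.

```python
# Hypothetical stand-ins: Base plays the role of DataModel, and the list of
# pairs plays the role of datamodel_names(); neither comes from the PR.
__all__ = []


class Base:
    """Stand-in base class for the generated models."""


def _register(node_type, name):
    """Create a class with type() and expose it at module level."""
    cls = type(name, (Base,), {"_node_type": node_type})
    globals()[name] = cls  # make the class importable by name
    __all__.append(name)  # make it visible to `from module import *`


for node_type, name in [("WfiImage", "ImageModel"), ("Ramp", "RampModel")]:
    _register(node_type, name)
```

After the loop, `ImageModel` and `RampModel` exist as module-level names even though no `class` statement mentions them. Note that the removed `__init_subclass__` hook raised a `ValueError` on duplicate names, while the loop in the diff appends to `__all__` without that check.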
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
""" | ||
Factories for creating all the DataModel classes from RAD | ||
These are used to dynamically create all the DataModels which are actually used. | ||
""" | ||
from roman_datamodels import stnode | ||
|
||
from . import _mixins | ||
from ._core import DataModel | ||
|
||
|
||
def datamodel_names(): | ||
""" | ||
A generator to grab all the datamodel names and base STNode classes from RAD | ||
|
||
Yields | ||
------ | ||
node_type, datamodel_name | ||
""" | ||
for tag in stnode._stnode.DATAMODELS_MANIFEST["tags"]: | ||
schema = stnode._factories.load_schema_from_uri(tag["schema_uri"]) | ||
if "datamodel_name" in schema: | ||
yield stnode.OBJECT_NODE_CLASSES_BY_TAG[tag["tag_uri"]], schema["datamodel_name"] | ||
|
||
|
||
def datamodel_factory(node_type, datamodel_name): | ||
""" | ||
The factory for dynamically creating a DataModel class from a node_type and datamodel_name | ||
Note: For DataModels requiring additional functionality, a Mixin must be added to ._mixins.py | ||
with the name <datamodel_name>Mixin. | ||
|
||
Parameters | ||
---------- | ||
node_type : type | ||
The base STNode class to use as the base for the DataModel | ||
datamodel_name : str | ||
The name of the DataModel to create | ||
|
||
Returns | ||
------- | ||
A DataModel object class | ||
""" | ||
if hasattr(_mixins, mixin := f"{datamodel_name}Mixin"): | ||
class_type = (getattr(_mixins, mixin), DataModel) | ||
else: | ||
class_type = (DataModel,) | ||
|
||
return type( | ||
datamodel_name, | ||
class_type, | ||
{ | ||
"_node_type": node_type, | ||
"__module__": "roman_datamodels.datamodels", | ||
"__doc__": f"Roman {datamodel_name} model", | ||
}, | ||
) |
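The mixin lookup and base-class ordering used by `datamodel_factory` can be tried out without RAD installed. The sketch below is an illustration under stated assumptions: `DataModel` here is a hypothetical stand-in whose `get_primary_array_name` returns "data" (inferred from the docstrings in this PR, not checked against `_core`), the `mixins` namespace stands in for the `_mixins` module, and node types are plain strings.

```python
import types


class DataModel:
    """Hypothetical stand-in for roman_datamodels' DataModel."""

    def get_primary_array_name(self):
        return "data"


class MaskRefModelMixin:
    """Mixin that overrides the primary array name."""

    def get_primary_array_name(self):
        return "dq"


# Stand-in for the _mixins module: any object with attributes works here.
mixins = types.SimpleNamespace(MaskRefModelMixin=MaskRefModelMixin)


def datamodel_factory(node_type, datamodel_name, mixin_source=mixins):
    """Build a model class, prepending <datamodel_name>Mixin if it exists."""
    mixin = getattr(mixin_source, f"{datamodel_name}Mixin", None)
    # The mixin must come first so its methods win in the MRO.
    bases = (mixin, DataModel) if mixin is not None else (DataModel,)
    return type(
        datamodel_name,
        bases,
        {"_node_type": node_type, "__doc__": f"Roman {datamodel_name} model"},
    )


MaskRefModel = datamodel_factory("MaskRef", "MaskRefModel")
FlatRefModel = datamodel_factory("FlatRef", "FlatRefModel")
```

With the bases in the order `(mixin, DataModel)`, the mixin's `get_primary_array_name` shadows the base implementation. With the reverse order, the base method would be found first and the override would be ignored, which is presumably what the "Fix mixin ordering for datamodels" commit addresses.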
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
""" | ||
This module provides all the Mixin classes which will be dynamically mixed into | ||
the DataModel classes at import time. | ||
The name of the mixin must be of the form <DataModelName>Mixin in order for | ||
this to work properly. | ||
""" | ||
import numpy as np | ||
|
||
|
||
class RampModelMixin: | ||
""" | ||
Mixin class for dynamically generated RampModel | ||
""" | ||
|
||
@classmethod | ||
def from_science_raw(cls, model): | ||
""" | ||
Construct a RampModel from a ScienceRawModel | ||
|
||
Parameters | ||
---------- | ||
model : ScienceRawModel or RampModel | ||
The input science raw model (a RampModel will also work) | ||
""" | ||
from roman_datamodels.datamodels import ScienceRawModel | ||
|
||
if isinstance(model, cls): | ||
return model | ||
|
||
if isinstance(model, ScienceRawModel): | ||
from roman_datamodels.maker_utils import mk_ramp | ||
|
||
instance = mk_ramp(shape=model.shape) | ||
|
||
# Copy input_model contents into RampModel | ||
for key in model: | ||
# If a dictionary (like meta), overwrite entries (but keep | ||
# required dummy entries that may not be in input_model) | ||
if isinstance(instance[key], dict): | ||
instance[key].update(getattr(model, key)) | ||
elif isinstance(instance[key], np.ndarray): | ||
# Cast input ndarray as RampModel dtype | ||
instance[key] = getattr(model, key).astype(instance[key].dtype) | ||
else: | ||
instance[key] = getattr(model, key) | ||
|
||
return cls(instance) | ||
|
||
raise ValueError("Input model must be a ScienceRawModel or RampModel") | ||
|
||
|
||
class AssociationsModelMixin: | ||
""" | ||
Mixin class for dynamically generated AssociationsModel | ||
""" | ||
|
||
@classmethod | ||
def is_association(cls, asn_data): | ||
""" | ||
Test if an object is an association by checking for required fields | ||
|
||
Parameters | ||
---------- | ||
asn_data : | ||
The data to be tested. | ||
""" | ||
return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data | ||
|
||
|
||
class LinearityRefModelMixin: | ||
""" | ||
Mixin class for dynamically generated LinearityRefModel | ||
""" | ||
|
||
def get_primary_array_name(self): | ||
""" | ||
Returns the name of the "primary" array for this model, which | ||
controls the size of other arrays that are implicitly created. | ||
This is intended to be overridden in the subclasses if the | ||
primary array's name is not "data". | ||
""" | ||
return "coeffs" | ||
|
||
|
||
class InverseLinearityRefModelMixin(LinearityRefModelMixin): | ||
""" | ||
Mixin class for dynamically generated InverseLinearityRefModel | ||
""" | ||
|
||
pass | ||
|
||
|
||
class MaskRefModelMixin: | ||
""" | ||
Mixin class for dynamically generated MaskRefModel | ||
""" | ||
|
||
def get_primary_array_name(self): | ||
""" | ||
Returns the name of the "primary" array for this model, which | ||
Review comment: Likewise. |
||
controls the size of other arrays that are implicitly created. | ||
This is intended to be overridden in the subclasses if the | ||
primary array's name is not "data". | ||
""" | ||
return "dq" |
Review comment: This docstring should be changed to reflect that this method is the override, rather than one intended to be overridden.
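One possible rewording along the lines the review comment asks for is sketched below. The docstring text is a suggestion only and does not appear in the PR; it also assumes that the base `DataModel` method returns "data", as the original docstring implies.

```python
class LinearityRefModelMixin:
    """Mixin class for dynamically generated LinearityRefModel"""

    def get_primary_array_name(self):
        """
        Return the name of the "primary" array for this model, which
        controls the size of other arrays that are implicitly created.
        Overrides the base implementation because the primary array of
        this model is "coeffs" rather than "data".
        """
        return "coeffs"


class InverseLinearityRefModelMixin(LinearityRefModelMixin):
    """Mixin class for dynamically generated InverseLinearityRefModel"""
```

Because `InverseLinearityRefModelMixin` inherits from `LinearityRefModelMixin`, as it does in the diff, the reworded docstring and the "coeffs" return value apply to both models without duplication.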