From c47f4937df3368a79c0f0e50dd932a9f2f1c10aa Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 13:22:51 -0400 Subject: [PATCH 01/14] Move stnode into a sub-package --- src/roman_datamodels/stnode/__init__.py | 1 + src/roman_datamodels/{stnode.py => stnode/_stnode.py} | 5 ++++- tests/test_stnode.py | 8 ++++---- 3 files changed, 9 insertions(+), 5 deletions(-) create mode 100644 src/roman_datamodels/stnode/__init__.py rename src/roman_datamodels/{stnode.py => stnode/_stnode.py} (98%) diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py new file mode 100644 index 00000000..c9661ace --- /dev/null +++ b/src/roman_datamodels/stnode/__init__.py @@ -0,0 +1 @@ +from ._stnode import * # noqa: F403 diff --git a/src/roman_datamodels/stnode.py b/src/roman_datamodels/stnode/_stnode.py similarity index 98% rename from src/roman_datamodels/stnode.py rename to src/roman_datamodels/stnode/_stnode.py index 3e978384..2164b809 100644 --- a/src/roman_datamodels/stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -23,7 +23,7 @@ from astropy.time import Time from astropy.units import Unit # noqa: F401 -from .validate import ValidationWarning, _check_type, _error_message, will_strict_validate, will_validate +from roman_datamodels.validate import ValidationWarning, _check_type, _error_message, will_strict_validate, will_validate __all__ = [ "WfiMode", @@ -35,6 +35,9 @@ "TaggedObjectNode", "TaggedListNode", "TaggedScalarNode", + "TaggedListNodeConverter", + "TaggedObjectNodeConverter", + "TaggedScalarNodeConverter", ] diff --git a/tests/test_stnode.py b/tests/test_stnode.py index a759c443..5d1900e8 100644 --- a/tests/test_stnode.py +++ b/tests/test_stnode.py @@ -11,15 +11,15 @@ def test_generated_node_classes(manifest): for tag in manifest["tags"]: - class_name = stnode._class_name_from_tag_uri(tag["tag_uri"]) - node_class = getattr(stnode, class_name) + class_name = stnode._stnode._class_name_from_tag_uri(tag["tag_uri"]) 
+ node_class = getattr(stnode._stnode, class_name) assert issubclass(node_class, (stnode.TaggedObjectNode, stnode.TaggedListNode, stnode.TaggedScalarNode)) assert node_class._tag == tag["tag_uri"] assert tag["description"] in node_class.__doc__ assert tag["tag_uri"] in node_class.__doc__ - assert node_class.__module__ == stnode.__name__ - assert node_class.__name__ in stnode.__all__ + assert node_class.__module__.startswith(stnode.__name__) + assert node_class.__name__ in stnode._stnode.__all__ @pytest.mark.parametrize("node_class", stnode.NODE_CLASSES) From 6800502a64f7f7804f12f0721ae4b867d2daa06d Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 13:44:57 -0400 Subject: [PATCH 02/14] Separate base node classes --- src/roman_datamodels/stnode/__init__.py | 1 + src/roman_datamodels/stnode/_node.py | 245 ++++++++++++++++++ src/roman_datamodels/stnode/_registry.py | 4 + src/roman_datamodels/stnode/_stnode.py | 309 +++-------------------- 4 files changed, 282 insertions(+), 277 deletions(-) create mode 100644 src/roman_datamodels/stnode/_node.py create mode 100644 src/roman_datamodels/stnode/_registry.py diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py index c9661ace..9cf1f099 100644 --- a/src/roman_datamodels/stnode/__init__.py +++ b/src/roman_datamodels/stnode/__init__.py @@ -1 +1,2 @@ +from ._node import * # noqa: F403 from ._stnode import * # noqa: F403 diff --git a/src/roman_datamodels/stnode/_node.py b/src/roman_datamodels/stnode/_node.py new file mode 100644 index 00000000..2eb0b81c --- /dev/null +++ b/src/roman_datamodels/stnode/_node.py @@ -0,0 +1,245 @@ +import datetime +import re +import warnings +from collections import UserList +from collections.abc import MutableMapping + +import asdf +import asdf.schema as asdfschema +import asdf.yamlutil as yamlutil +import jsonschema +import numpy as np +from asdf.tags.core import ndarray +from asdf.util import HashableDict +from astropy.time import 
Time + +from roman_datamodels.validate import ValidationWarning, _check_type, _error_message, will_strict_validate, will_validate + +from ._registry import SCALAR_NODE_CLASSES_BY_KEY + +validator_callbacks = HashableDict(asdfschema.YAML_VALIDATORS) +validator_callbacks.update({"type": _check_type}) + + +def _value_change(path, value, schema, pass_invalid_values, strict_validation, ctx): + """ + Validate a change in value against a schema. + Trap error and return a flag. + """ + try: + _check_value(value, schema, ctx) + update = True + + except jsonschema.ValidationError as error: + update = False + errmsg = _error_message(path, error) + if pass_invalid_values: + update = True + if strict_validation: + raise jsonschema.ValidationError(errmsg) + else: + warnings.warn(errmsg, ValidationWarning) + return update + + +def _check_value(value, schema, validator_context): + """ + Perform the actual validation. + """ + + temp_schema = {"$schema": "http://stsci.edu/schemas/asdf-schema/0.1.0/asdf-schema"} + temp_schema.update(schema) + validator = asdfschema.get_validator(temp_schema, validator_context, validator_callbacks) + + validator.validate(value, _schema=temp_schema) + validator_context.close() + + +def _validate(attr, instance, schema, ctx): + tagged_tree = yamlutil.custom_tree_to_tagged_tree(instance, ctx) + return _value_change(attr, tagged_tree, schema, False, will_strict_validate(), ctx) + + +def _get_schema_for_property(schema, attr): + # Check if attr is a property + subschema = schema.get("properties", {}).get(attr, None) + + # Check if attr is a pattern property + props = schema.get("patternProperties", {}) + for key, value in props.items(): + if re.match(key, attr): + subschema = value + break + + if subschema is not None: + return subschema + for combiner in ["allOf", "anyOf"]: + for subschema in schema.get(combiner, []): + subsubschema = _get_schema_for_property(subschema, attr) + if subsubschema != {}: + return subsubschema + + return {} + + +class 
DNode(MutableMapping): + _tag = None + _ctx = None + + def __init__(self, node=None, parent=None, name=None): + if node is None: + self.__dict__["_data"] = {} + elif isinstance(node, dict): + self.__dict__["_data"] = node + else: + raise ValueError("Initializer only accepts dicts") + self._x_schema = None + self._schema_uri = None + self._parent = parent + self._name = name + + @property + def ctx(self): + if self._ctx is None: + DNode._ctx = asdf.AsdfFile() + return self._ctx + + @staticmethod + def _convert_to_scalar(key, value): + if key in SCALAR_NODE_CLASSES_BY_KEY: + value = SCALAR_NODE_CLASSES_BY_KEY[key](value) + + return value + + def __getattr__(self, key): + """ + Permit accessing dict keys as attributes, assuming they are legal Python + variable names. + """ + if key.startswith("_"): + raise AttributeError(f"No attribute {key}") + if key in self._data: + value = self._convert_to_scalar(key, self._data[key]) + if isinstance(value, dict): + return DNode(value, parent=self, name=key) + elif isinstance(value, list): + return LNode(value) + else: + return value + else: + raise AttributeError(f"No such attribute ({key}) found in node") + + def __setattr__(self, key, value): + """ + Permit assigning dict keys as attributes. + """ + if key[0] != "_": + value = self._convert_to_scalar(key, value) + if key in self._data: + if will_validate(): + schema = _get_schema_for_property(self._schema(), key) + + if schema == {} or _validate(key, value, schema, self.ctx): + self._data[key] = value + self.__dict__["_data"][key] = value + else: + raise AttributeError(f"No such attribute ({key}) found in node") + else: + self.__dict__[key] = value + + def to_flat_dict(self, include_arrays=True): + """ + Returns a dictionary of all of the schema items as a flat dictionary. + + Each dictionary key is a dot-separated name. 
For example, the + schema element ``meta.observation.date`` will end up in the + dictionary as:: + + { "meta.observation.date": "2012-04-22T03:22:05.432" } + + """ + + def convert_val(val): + if isinstance(val, datetime.datetime): + return val.isoformat() + elif isinstance(val, Time): + return str(val) + return val + + if include_arrays: + return {key: convert_val(val) for (key, val) in self.items()} + else: + return { + key: convert_val(val) for (key, val) in self.items() if not isinstance(val, (np.ndarray, ndarray.NDArrayType)) + } + + def _schema(self): + """ + If not overridden by a subclass, it will search for a schema from + the parent class, recursing if necessary until one is found. + """ + if self._x_schema is None: + parent_schema = self._parent._schema() + # Extract the subschema corresponding to this node. + subschema = _get_schema_for_property(parent_schema, self._name) + self._x_schema = subschema + + return self._x_schema + + def __asdf_traverse__(self): + return dict(self) + + def __len__(self): + return len(self._data) + + def __getitem__(self, key): + if key in self._data: + return self._data[key] + + raise KeyError(f"No such key ({key}) found in node") + + def __setitem__(self, key, value): + value = self._convert_to_scalar(key, value) + if isinstance(value, dict): + for sub_key, sub_value in value.items(): + value[sub_key] = self._convert_to_scalar(sub_key, sub_value) + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + def __iter__(self): + return iter(self._data) + + def copy(self): + instance = self.__class__.__new__(self.__class__) + instance.__dict__.update(self.__dict__.copy()) + instance.__dict__["_data"] = self.__dict__["_data"].copy() + + return instance + + +class LNode(UserList): + _tag = None + + def __init__(self, node=None): + if node is None: + self.data = [] + elif isinstance(node, list): + self.data = node + elif isinstance(node, self.__class__): + self.data = node.data + else: + raise 
ValueError("Initializer only accepts lists") + + def __getitem__(self, index): + value = self.data[index] + if isinstance(value, dict): + return DNode(value) + elif isinstance(value, list): + return LNode(value) + else: + return value + + def __asdf_traverse__(self): + return list(self) diff --git a/src/roman_datamodels/stnode/_registry.py b/src/roman_datamodels/stnode/_registry.py new file mode 100644 index 00000000..cf025ee2 --- /dev/null +++ b/src/roman_datamodels/stnode/_registry.py @@ -0,0 +1,4 @@ +OBJECT_NODE_CLASSES_BY_TAG = {} +LIST_NODE_CLASSES_BY_TAG = {} +SCALAR_NODE_CLASSES_BY_TAG = {} +SCALAR_NODE_CLASSES_BY_KEY = {} diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index 2164b809..323c002c 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -2,36 +2,29 @@ Proof of concept of using tags with the data model framework """ -import datetime import importlib.resources -import re -import warnings from abc import ABCMeta -from collections import UserList -from collections.abc import MutableMapping import asdf import asdf.schema as asdfschema -import asdf.yamlutil as yamlutil -import jsonschema -import numpy as np import rad.resources import yaml from asdf.extension import Converter -from asdf.tags.core import ndarray -from asdf.util import HashableDict from astropy.time import Time -from astropy.units import Unit # noqa: F401 -from roman_datamodels.validate import ValidationWarning, _check_type, _error_message, will_strict_validate, will_validate +from ._node import DNode, LNode +from ._registry import ( + LIST_NODE_CLASSES_BY_TAG, + OBJECT_NODE_CLASSES_BY_TAG, + SCALAR_NODE_CLASSES_BY_KEY, + SCALAR_NODE_CLASSES_BY_TAG, +) __all__ = [ "WfiMode", "NODE_CLASSES", "CalLogs", "FileDate", - "DNode", - "LNode", "TaggedObjectNode", "TaggedListNode", "TaggedScalarNode", @@ -41,237 +34,6 @@ ] -validator_callbacks = HashableDict(asdfschema.YAML_VALIDATORS) 
-validator_callbacks.update({"type": _check_type}) - - -def _value_change(path, value, schema, pass_invalid_values, strict_validation, ctx): - """ - Validate a change in value against a schema. - Trap error and return a flag. - """ - try: - _check_value(value, schema, ctx) - update = True - - except jsonschema.ValidationError as error: - update = False - errmsg = _error_message(path, error) - if pass_invalid_values: - update = True - if strict_validation: - raise jsonschema.ValidationError(errmsg) - else: - warnings.warn(errmsg, ValidationWarning) - return update - - -def _check_value(value, schema, validator_context): - """ - Perform the actual validation. - """ - - temp_schema = {"$schema": "http://stsci.edu/schemas/asdf-schema/0.1.0/asdf-schema"} - temp_schema.update(schema) - validator = asdfschema.get_validator(temp_schema, validator_context, validator_callbacks) - - validator.validate(value, _schema=temp_schema) - validator_context.close() - - -def _validate(attr, instance, schema, ctx): - tagged_tree = yamlutil.custom_tree_to_tagged_tree(instance, ctx) - return _value_change(attr, tagged_tree, schema, False, will_strict_validate(), ctx) - - -def _get_schema_for_property(schema, attr): - # Check if attr is a property - subschema = schema.get("properties", {}).get(attr, None) - - # Check if attr is a pattern property - props = schema.get("patternProperties", {}) - for key, value in props.items(): - if re.match(key, attr): - subschema = value - break - - if subschema is not None: - return subschema - for combiner in ["allOf", "anyOf"]: - for subschema in schema.get(combiner, []): - subsubschema = _get_schema_for_property(subschema, attr) - if subsubschema != {}: - return subsubschema - - return {} - - -class DNode(MutableMapping): - _tag = None - _ctx = None - - def __init__(self, node=None, parent=None, name=None): - if node is None: - self.__dict__["_data"] = {} - elif isinstance(node, dict): - self.__dict__["_data"] = node - else: - raise 
ValueError("Initializer only accepts dicts") - self._x_schema = None - self._schema_uri = None - self._parent = parent - self._name = name - - @property - def ctx(self): - if self._ctx is None: - DNode._ctx = asdf.AsdfFile() - return self._ctx - - @staticmethod - def _convert_to_scalar(key, value): - if key in _SCALAR_NODE_CLASSES_BY_KEY: - value = _SCALAR_NODE_CLASSES_BY_KEY[key](value) - - return value - - def __getattr__(self, key): - """ - Permit accessing dict keys as attributes, assuming they are legal Python - variable names. - """ - if key.startswith("_"): - raise AttributeError(f"No attribute {key}") - if key in self._data: - value = self._convert_to_scalar(key, self._data[key]) - if isinstance(value, dict): - return DNode(value, parent=self, name=key) - elif isinstance(value, list): - return LNode(value) - else: - return value - else: - raise AttributeError(f"No such attribute ({key}) found in node") - - def __setattr__(self, key, value): - """ - Permit assigning dict keys as attributes. - """ - if key[0] != "_": - value = self._convert_to_scalar(key, value) - if key in self._data: - if will_validate(): - schema = _get_schema_for_property(self._schema(), key) - - if schema == {} or _validate(key, value, schema, self.ctx): - self._data[key] = value - self.__dict__["_data"][key] = value - else: - raise AttributeError(f"No such attribute ({key}) found in node") - else: - self.__dict__[key] = value - - def to_flat_dict(self, include_arrays=True): - """ - Returns a dictionary of all of the schema items as a flat dictionary. - - Each dictionary key is a dot-separated name. 
For example, the - schema element ``meta.observation.date`` will end up in the - dictionary as:: - - { "meta.observation.date": "2012-04-22T03:22:05.432" } - - """ - - def convert_val(val): - if isinstance(val, datetime.datetime): - return val.isoformat() - elif isinstance(val, Time): - return str(val) - return val - - if include_arrays: - return {key: convert_val(val) for (key, val) in self.items()} - else: - return { - key: convert_val(val) for (key, val) in self.items() if not isinstance(val, (np.ndarray, ndarray.NDArrayType)) - } - - def _schema(self): - """ - If not overridden by a subclass, it will search for a schema from - the parent class, recursing if necessary until one is found. - """ - if self._x_schema is None: - parent_schema = self._parent._schema() - # Extract the subschema corresponding to this node. - subschema = _get_schema_for_property(parent_schema, self._name) - self._x_schema = subschema - - return self._x_schema - - def __asdf_traverse__(self): - return dict(self) - - def __len__(self): - return len(self._data) - - def __getitem__(self, key): - if key in self._data: - return self._data[key] - - raise KeyError(f"No such key ({key}) found in node") - - def __setitem__(self, key, value): - value = self._convert_to_scalar(key, value) - if isinstance(value, dict): - for sub_key, sub_value in value.items(): - value[sub_key] = self._convert_to_scalar(sub_key, sub_value) - self._data[key] = value - - def __delitem__(self, key): - del self._data[key] - - def __iter__(self): - return iter(self._data) - - def copy(self): - instance = self.__class__.__new__(self.__class__) - instance.__dict__.update(self.__dict__.copy()) - instance.__dict__["_data"] = self.__dict__["_data"].copy() - - return instance - - -class LNode(UserList): - _tag = None - - def __init__(self, node=None): - if node is None: - self.data = [] - elif isinstance(node, list): - self.data = node - elif isinstance(node, self.__class__): - self.data = node.data - else: - raise 
ValueError("Initializer only accepts lists") - - def __getitem__(self, index): - value = self.data[index] - if isinstance(value, dict): - return DNode(value) - elif isinstance(value, list): - return LNode(value) - else: - return value - - def __asdf_traverse__(self): - return list(self) - - -_OBJECT_NODE_CLASSES_BY_TAG = {} - - class TaggedObjectNodeMeta(ABCMeta): """ Metaclass for TaggedObjectNode that maintains a registry @@ -281,9 +43,9 @@ class TaggedObjectNodeMeta(ABCMeta): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.__name__ != "TaggedObjectNode": - if self._tag in _OBJECT_NODE_CLASSES_BY_TAG: + if self._tag in OBJECT_NODE_CLASSES_BY_TAG: raise RuntimeError(f"TaggedObjectNode class for tag '{self._tag}' has been defined twice") - _OBJECT_NODE_CLASSES_BY_TAG[self._tag] = self + OBJECT_NODE_CLASSES_BY_TAG[self._tag] = self class TaggedObjectNode(DNode, metaclass=TaggedObjectNodeMeta): @@ -309,9 +71,6 @@ def get_schema(self): return schema -_LIST_NODE_CLASSES_BY_TAG = {} - - class TaggedListNodeMeta(ABCMeta): """ Metaclass for TaggedListNode that maintains a registry @@ -321,9 +80,9 @@ class TaggedListNodeMeta(ABCMeta): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.__name__ != "TaggedListNode": - if self._tag in _LIST_NODE_CLASSES_BY_TAG: + if self._tag in LIST_NODE_CLASSES_BY_TAG: raise RuntimeError(f"TaggedListNode class for tag '{self._tag}' has been defined twice") - _LIST_NODE_CLASSES_BY_TAG[self._tag] = self + LIST_NODE_CLASSES_BY_TAG[self._tag] = self class TaggedListNode(LNode, metaclass=TaggedListNodeMeta): @@ -332,10 +91,6 @@ def tag(self): return self._tag -_SCALAR_NODE_CLASSES_BY_TAG = {} -_SCALAR_NODE_CLASSES_BY_KEY = {} - - def _scalar_tag_to_key(tag): return tag.split("/")[-1].split("-")[0] @@ -349,10 +104,10 @@ class TaggedScalarNodeMeta(ABCMeta): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.__name__ != "TaggedScalarNode": - if self._tag in 
_SCALAR_NODE_CLASSES_BY_TAG: + if self._tag in SCALAR_NODE_CLASSES_BY_TAG: raise RuntimeError(f"TaggedScalarNode class for tag '{self._tag}' has been defined twice") - _SCALAR_NODE_CLASSES_BY_TAG[self._tag] = self - _SCALAR_NODE_CLASSES_BY_KEY[_scalar_tag_to_key(self._tag)] = self + SCALAR_NODE_CLASSES_BY_TAG[self._tag] = self + SCALAR_NODE_CLASSES_BY_KEY[_scalar_tag_to_key(self._tag)] = self class TaggedScalarNode(metaclass=TaggedScalarNodeMeta): @@ -424,11 +179,11 @@ class TaggedObjectNodeConverter(Converter): @property def tags(self): - return list(_OBJECT_NODE_CLASSES_BY_TAG.keys()) + return list(OBJECT_NODE_CLASSES_BY_TAG.keys()) @property def types(self): - return list(_OBJECT_NODE_CLASSES_BY_TAG.values()) + return list(OBJECT_NODE_CLASSES_BY_TAG.values()) def select_tag(self, obj, tags, ctx): return obj.tag @@ -437,7 +192,7 @@ def to_yaml_tree(self, obj, tag, ctx): return obj._data def from_yaml_tree(self, node, tag, ctx): - return _OBJECT_NODE_CLASSES_BY_TAG[tag](node) + return OBJECT_NODE_CLASSES_BY_TAG[tag](node) class TaggedListNodeConverter(Converter): @@ -447,11 +202,11 @@ class TaggedListNodeConverter(Converter): @property def tags(self): - return list(_LIST_NODE_CLASSES_BY_TAG.keys()) + return list(LIST_NODE_CLASSES_BY_TAG.keys()) @property def types(self): - return list(_LIST_NODE_CLASSES_BY_TAG.values()) + return list(LIST_NODE_CLASSES_BY_TAG.values()) def select_tag(self, obj, tags, ctx): return obj.tag @@ -460,7 +215,7 @@ def to_yaml_tree(self, obj, tag, ctx): return list(obj) def from_yaml_tree(self, node, tag, ctx): - return _LIST_NODE_CLASSES_BY_TAG[tag](node) + return LIST_NODE_CLASSES_BY_TAG[tag](node) class TaggedScalarNodeConverter(Converter): @@ -470,11 +225,11 @@ class TaggedScalarNodeConverter(Converter): @property def tags(self): - return list(_SCALAR_NODE_CLASSES_BY_TAG.keys()) + return list(SCALAR_NODE_CLASSES_BY_TAG.keys()) @property def types(self): - return list(_SCALAR_NODE_CLASSES_BY_TAG.values()) + return 
list(SCALAR_NODE_CLASSES_BY_TAG.values()) def select_tag(self, obj, tags, ctx): return obj.tag @@ -493,7 +248,7 @@ def from_yaml_tree(self, node, tag, ctx): converter = ctx.extension_manager.get_converter_for_type(Time) node = converter.from_yaml_tree(node, tag, ctx) - return _SCALAR_NODE_CLASSES_BY_TAG[tag](node) + return SCALAR_NODE_CLASSES_BY_TAG[tag](node) _DATAMODELS_MANIFEST_PATH = importlib.resources.files(rad.resources) / "manifests" / "datamodels-1.0.yaml" @@ -535,12 +290,12 @@ def _class_from_tag(tag, docstring): docstring = tag["description"] + "\n\n" docstring = docstring + f"Class generated from tag '{tag['tag_uri']}'" - if tag["tag_uri"] in _OBJECT_NODE_CLASSES_BY_TAG: - _OBJECT_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring - elif tag["tag_uri"] in _LIST_NODE_CLASSES_BY_TAG: - _LIST_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring - elif tag["tag_uri"] in _SCALAR_NODE_CLASSES_BY_TAG: - _SCALAR_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring + if tag["tag_uri"] in OBJECT_NODE_CLASSES_BY_TAG: + OBJECT_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring + elif tag["tag_uri"] in LIST_NODE_CLASSES_BY_TAG: + LIST_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring + elif tag["tag_uri"] in SCALAR_NODE_CLASSES_BY_TAG: + SCALAR_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring else: _class_from_tag(tag, docstring) @@ -548,7 +303,7 @@ def _class_from_tag(tag, docstring): # List of node classes made available by this library. This is part # of the public API. 
NODE_CLASSES = ( - list(_OBJECT_NODE_CLASSES_BY_TAG.values()) - + list(_LIST_NODE_CLASSES_BY_TAG.values()) - + list(_SCALAR_NODE_CLASSES_BY_TAG.values()) + list(OBJECT_NODE_CLASSES_BY_TAG.values()) + + list(LIST_NODE_CLASSES_BY_TAG.values()) + + list(SCALAR_NODE_CLASSES_BY_TAG.values()) ) From 6c7f6b7674ae57c5a5e77c8e43cd185bc7507979 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 13:54:16 -0400 Subject: [PATCH 03/14] Pull tagged objects into its own module --- src/roman_datamodels/stnode/__init__.py | 1 + src/roman_datamodels/stnode/_stnode.py | 125 +----------------------- src/roman_datamodels/stnode/_tagged.py | 122 +++++++++++++++++++++++ 3 files changed, 125 insertions(+), 123 deletions(-) create mode 100644 src/roman_datamodels/stnode/_tagged.py diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py index 9cf1f099..86815e8b 100644 --- a/src/roman_datamodels/stnode/__init__.py +++ b/src/roman_datamodels/stnode/__init__.py @@ -1,2 +1,3 @@ from ._node import * # noqa: F403 from ._stnode import * # noqa: F403 +from ._tagged import * # noqa: F403 diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index 323c002c..f9b80fdc 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -3,147 +3,26 @@ """ import importlib.resources -from abc import ABCMeta -import asdf -import asdf.schema as asdfschema import rad.resources import yaml from asdf.extension import Converter from astropy.time import Time -from ._node import DNode, LNode -from ._registry import ( - LIST_NODE_CLASSES_BY_TAG, - OBJECT_NODE_CLASSES_BY_TAG, - SCALAR_NODE_CLASSES_BY_KEY, - SCALAR_NODE_CLASSES_BY_TAG, -) +from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG +from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode __all__ = [ "WfiMode", "NODE_CLASSES", "CalLogs", "FileDate", - 
"TaggedObjectNode", - "TaggedListNode", - "TaggedScalarNode", "TaggedListNodeConverter", "TaggedObjectNodeConverter", "TaggedScalarNodeConverter", ] -class TaggedObjectNodeMeta(ABCMeta): - """ - Metaclass for TaggedObjectNode that maintains a registry - of subclasses. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.__name__ != "TaggedObjectNode": - if self._tag in OBJECT_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedObjectNode class for tag '{self._tag}' has been defined twice") - OBJECT_NODE_CLASSES_BY_TAG[self._tag] = self - - -class TaggedObjectNode(DNode, metaclass=TaggedObjectNodeMeta): - """ - Expects subclass to define a class instance of _tag - """ - - @property - def tag(self): - return self._tag - - def _schema(self): - if self._x_schema is None: - self._x_schema = self.get_schema() - return self._x_schema - - def get_schema(self): - """Retrieve the schema associated with this tag""" - extension_manager = self.ctx.extension_manager - tag_def = extension_manager.get_tag_definition(self.tag) - schema_uri = tag_def.schema_uris[0] - schema = asdfschema.load_schema(schema_uri, resolve_references=True) - return schema - - -class TaggedListNodeMeta(ABCMeta): - """ - Metaclass for TaggedListNode that maintains a registry - of subclasses. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.__name__ != "TaggedListNode": - if self._tag in LIST_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedListNode class for tag '{self._tag}' has been defined twice") - LIST_NODE_CLASSES_BY_TAG[self._tag] = self - - -class TaggedListNode(LNode, metaclass=TaggedListNodeMeta): - @property - def tag(self): - return self._tag - - -def _scalar_tag_to_key(tag): - return tag.split("/")[-1].split("-")[0] - - -class TaggedScalarNodeMeta(ABCMeta): - """ - Metaclass for TaggedScalarNode that maintains a registry - of subclasses. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.__name__ != "TaggedScalarNode": - if self._tag in SCALAR_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedScalarNode class for tag '{self._tag}' has been defined twice") - SCALAR_NODE_CLASSES_BY_TAG[self._tag] = self - SCALAR_NODE_CLASSES_BY_KEY[_scalar_tag_to_key(self._tag)] = self - - -class TaggedScalarNode(metaclass=TaggedScalarNodeMeta): - _tag = None - _ctx = None - - @property - def ctx(self): - if self._ctx is None: - TaggedScalarNode._ctx = asdf.AsdfFile() - return self._ctx - - def __asdf_traverse__(self): - return self - - @property - def tag(self): - return self._tag - - @property - def key(self): - return _scalar_tag_to_key(self._tag) - - def get_schema(self): - extension_manager = self.ctx.extension_manager - tag_def = extension_manager.get_tag_definition(self.tag) - schema_uri = tag_def.schema_uris[0] - schema = asdf.schema.load_schema(schema_uri, resolve_references=True) - return schema - - def copy(self): - import copy - - return copy.copy(self) - - class WfiMode(TaggedObjectNode): _tag = "asdf://stsci.edu/datamodels/roman/tags/wfi_mode-1.0.0" diff --git a/src/roman_datamodels/stnode/_tagged.py b/src/roman_datamodels/stnode/_tagged.py new file mode 100644 index 00000000..e8ab5eea --- /dev/null +++ b/src/roman_datamodels/stnode/_tagged.py @@ -0,0 +1,122 @@ +from abc import ABCMeta + +import asdf +import asdf.schema as asdfschema + +from ._node import DNode, LNode +from ._registry import ( + LIST_NODE_CLASSES_BY_TAG, + OBJECT_NODE_CLASSES_BY_TAG, + SCALAR_NODE_CLASSES_BY_KEY, + SCALAR_NODE_CLASSES_BY_TAG, +) + + +class TaggedObjectNodeMeta(ABCMeta): + """ + Metaclass for TaggedObjectNode that maintains a registry + of subclasses. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.__name__ != "TaggedObjectNode": + if self._tag in OBJECT_NODE_CLASSES_BY_TAG: + raise RuntimeError(f"TaggedObjectNode class for tag '{self._tag}' has been defined twice") + OBJECT_NODE_CLASSES_BY_TAG[self._tag] = self + + +class TaggedObjectNode(DNode, metaclass=TaggedObjectNodeMeta): + """ + Expects subclass to define a class instance of _tag + """ + + @property + def tag(self): + return self._tag + + def _schema(self): + if self._x_schema is None: + self._x_schema = self.get_schema() + return self._x_schema + + def get_schema(self): + """Retrieve the schema associated with this tag""" + extension_manager = self.ctx.extension_manager + tag_def = extension_manager.get_tag_definition(self.tag) + schema_uri = tag_def.schema_uris[0] + schema = asdfschema.load_schema(schema_uri, resolve_references=True) + return schema + + +class TaggedListNodeMeta(ABCMeta): + """ + Metaclass for TaggedListNode that maintains a registry + of subclasses. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.__name__ != "TaggedListNode": + if self._tag in LIST_NODE_CLASSES_BY_TAG: + raise RuntimeError(f"TaggedListNode class for tag '{self._tag}' has been defined twice") + LIST_NODE_CLASSES_BY_TAG[self._tag] = self + + +class TaggedListNode(LNode, metaclass=TaggedListNodeMeta): + @property + def tag(self): + return self._tag + + +def _scalar_tag_to_key(tag): + return tag.split("/")[-1].split("-")[0] + + +class TaggedScalarNodeMeta(ABCMeta): + """ + Metaclass for TaggedScalarNode that maintains a registry + of subclasses. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.__name__ != "TaggedScalarNode": + if self._tag in SCALAR_NODE_CLASSES_BY_TAG: + raise RuntimeError(f"TaggedScalarNode class for tag '{self._tag}' has been defined twice") + SCALAR_NODE_CLASSES_BY_TAG[self._tag] = self + SCALAR_NODE_CLASSES_BY_KEY[_scalar_tag_to_key(self._tag)] = self + + +class TaggedScalarNode(metaclass=TaggedScalarNodeMeta): + _tag = None + _ctx = None + + @property + def ctx(self): + if self._ctx is None: + TaggedScalarNode._ctx = asdf.AsdfFile() + return self._ctx + + def __asdf_traverse__(self): + return self + + @property + def tag(self): + return self._tag + + @property + def key(self): + return _scalar_tag_to_key(self._tag) + + def get_schema(self): + extension_manager = self.ctx.extension_manager + tag_def = extension_manager.get_tag_definition(self.tag) + schema_uri = tag_def.schema_uris[0] + schema = asdf.schema.load_schema(schema_uri, resolve_references=True) + return schema + + def copy(self): + import copy + + return copy.copy(self) From abdd9c5c33084cb7bfd4b06d28cbfbd002e1aa92 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 14:15:57 -0400 Subject: [PATCH 04/14] Moved fixed node creation into its own module --- src/roman_datamodels/stnode/__init__.py | 3 +++ src/roman_datamodels/stnode/_fixed.py | 31 ++++++++++++++++++++++ src/roman_datamodels/stnode/_stnode.py | 34 ++----------------------- tests/test_stnode.py | 4 +-- 4 files changed, 38 insertions(+), 34 deletions(-) create mode 100644 src/roman_datamodels/stnode/_fixed.py diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py index 86815e8b..c3eb4b74 100644 --- a/src/roman_datamodels/stnode/__init__.py +++ b/src/roman_datamodels/stnode/__init__.py @@ -1,3 +1,6 @@ +from ._fixed import * # noqa: F403 from ._node import * # noqa: F403 from ._stnode import * # noqa: F403 from ._tagged import * # noqa: F403 + +__all__ = [v.__name__ 
for v in globals().values() if hasattr(v, "__name__")] diff --git a/src/roman_datamodels/stnode/_fixed.py b/src/roman_datamodels/stnode/_fixed.py new file mode 100644 index 00000000..773df242 --- /dev/null +++ b/src/roman_datamodels/stnode/_fixed.py @@ -0,0 +1,31 @@ +from astropy.time import Time + +from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode + + +class WfiMode(TaggedObjectNode): + _tag = "asdf://stsci.edu/datamodels/roman/tags/wfi_mode-1.0.0" + + _GRATING_OPTICAL_ELEMENTS = {"GRISM", "PRISM"} + + @property + def filter(self): + if self.optical_element in self._GRATING_OPTICAL_ELEMENTS: + return None + else: + return self.optical_element + + @property + def grating(self): + if self.optical_element in self._GRATING_OPTICAL_ELEMENTS: + return self.optical_element + else: + return None + + +class CalLogs(TaggedListNode): + _tag = "asdf://stsci.edu/datamodels/roman/tags/cal_logs-1.0.0" + + +class FileDate(Time, TaggedScalarNode): + _tag = "asdf://stsci.edu/datamodels/roman/tags/file_date-1.0.0" diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index f9b80fdc..2a0c5c53 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -9,48 +9,18 @@ from asdf.extension import Converter from astropy.time import Time +from ._fixed import FileDate from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG -from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode +from ._tagged import TaggedObjectNode, TaggedScalarNode __all__ = [ - "WfiMode", "NODE_CLASSES", - "CalLogs", - "FileDate", "TaggedListNodeConverter", "TaggedObjectNodeConverter", "TaggedScalarNodeConverter", ] -class WfiMode(TaggedObjectNode): - _tag = "asdf://stsci.edu/datamodels/roman/tags/wfi_mode-1.0.0" - - _GRATING_OPTICAL_ELEMENTS = {"GRISM", "PRISM"} - - @property - def filter(self): - if self.optical_element in self._GRATING_OPTICAL_ELEMENTS: - 
return None - else: - return self.optical_element - - @property - def grating(self): - if self.optical_element in self._GRATING_OPTICAL_ELEMENTS: - return self.optical_element - else: - return None - - -class CalLogs(TaggedListNode): - _tag = "asdf://stsci.edu/datamodels/roman/tags/cal_logs-1.0.0" - - -class FileDate(Time, TaggedScalarNode): - _tag = "asdf://stsci.edu/datamodels/roman/tags/file_date-1.0.0" - - class TaggedObjectNodeConverter(Converter): """ Converter for all subclasses of TaggedObjectNode. diff --git a/tests/test_stnode.py b/tests/test_stnode.py index 5d1900e8..0fb715fc 100644 --- a/tests/test_stnode.py +++ b/tests/test_stnode.py @@ -12,14 +12,14 @@ def test_generated_node_classes(manifest): for tag in manifest["tags"]: class_name = stnode._stnode._class_name_from_tag_uri(tag["tag_uri"]) - node_class = getattr(stnode._stnode, class_name) + node_class = getattr(stnode, class_name) assert issubclass(node_class, (stnode.TaggedObjectNode, stnode.TaggedListNode, stnode.TaggedScalarNode)) assert node_class._tag == tag["tag_uri"] assert tag["description"] in node_class.__doc__ assert tag["tag_uri"] in node_class.__doc__ assert node_class.__module__.startswith(stnode.__name__) - assert node_class.__name__ in stnode._stnode.__all__ + assert node_class.__name__ in stnode.__all__ @pytest.mark.parametrize("node_class", stnode.NODE_CLASSES) From 9b67d0143e43fcc70ed8704e55bdee8a813c960e Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 14:26:44 -0400 Subject: [PATCH 05/14] Separate converters --- src/roman_datamodels/stnode/__init__.py | 1 + src/roman_datamodels/stnode/_converters.py | 84 +++++++++++++++++++++ src/roman_datamodels/stnode/_stnode.py | 85 ---------------------- 3 files changed, 85 insertions(+), 85 deletions(-) create mode 100644 src/roman_datamodels/stnode/_converters.py diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py index c3eb4b74..2a186aa4 100644 --- 
a/src/roman_datamodels/stnode/__init__.py +++ b/src/roman_datamodels/stnode/__init__.py @@ -1,3 +1,4 @@ +from ._converters import * # noqa: F403 from ._fixed import * # noqa: F403 from ._node import * # noqa: F403 from ._stnode import * # noqa: F403 diff --git a/src/roman_datamodels/stnode/_converters.py b/src/roman_datamodels/stnode/_converters.py new file mode 100644 index 00000000..997f7e7b --- /dev/null +++ b/src/roman_datamodels/stnode/_converters.py @@ -0,0 +1,84 @@ +from asdf.extension import Converter +from astropy.time import Time + +from ._fixed import FileDate +from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG + + +class TaggedObjectNodeConverter(Converter): + """ + Converter for all subclasses of TaggedObjectNode. + """ + + @property + def tags(self): + return list(OBJECT_NODE_CLASSES_BY_TAG.keys()) + + @property + def types(self): + return list(OBJECT_NODE_CLASSES_BY_TAG.values()) + + def select_tag(self, obj, tags, ctx): + return obj.tag + + def to_yaml_tree(self, obj, tag, ctx): + return obj._data + + def from_yaml_tree(self, node, tag, ctx): + return OBJECT_NODE_CLASSES_BY_TAG[tag](node) + + +class TaggedListNodeConverter(Converter): + """ + Converter for all subclasses of TaggedListNode. + """ + + @property + def tags(self): + return list(LIST_NODE_CLASSES_BY_TAG.keys()) + + @property + def types(self): + return list(LIST_NODE_CLASSES_BY_TAG.values()) + + def select_tag(self, obj, tags, ctx): + return obj.tag + + def to_yaml_tree(self, obj, tag, ctx): + return list(obj) + + def from_yaml_tree(self, node, tag, ctx): + return LIST_NODE_CLASSES_BY_TAG[tag](node) + + +class TaggedScalarNodeConverter(Converter): + """ + Converter for all subclasses of TaggedScalarNode. 
+ """ + + @property + def tags(self): + return list(SCALAR_NODE_CLASSES_BY_TAG.keys()) + + @property + def types(self): + return list(SCALAR_NODE_CLASSES_BY_TAG.values()) + + def select_tag(self, obj, tags, ctx): + return obj.tag + + def to_yaml_tree(self, obj, tag, ctx): + node = obj.__class__.__bases__[0](obj) + + if tag == FileDate._tag: + converter = ctx.extension_manager.get_converter_for_type(type(node)) + node = converter.to_yaml_tree(node, tag, ctx) + + return node + + def from_yaml_tree(self, node, tag, ctx): + if tag == FileDate._tag: + converter = ctx.extension_manager.get_converter_for_type(Time) + node = converter.from_yaml_tree(node, tag, ctx) + + return SCALAR_NODE_CLASSES_BY_TAG[tag](node) diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index 2a0c5c53..764a20d9 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -6,100 +6,15 @@ import rad.resources import yaml -from asdf.extension import Converter -from astropy.time import Time -from ._fixed import FileDate from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG from ._tagged import TaggedObjectNode, TaggedScalarNode __all__ = [ "NODE_CLASSES", - "TaggedListNodeConverter", - "TaggedObjectNodeConverter", - "TaggedScalarNodeConverter", ] -class TaggedObjectNodeConverter(Converter): - """ - Converter for all subclasses of TaggedObjectNode. - """ - - @property - def tags(self): - return list(OBJECT_NODE_CLASSES_BY_TAG.keys()) - - @property - def types(self): - return list(OBJECT_NODE_CLASSES_BY_TAG.values()) - - def select_tag(self, obj, tags, ctx): - return obj.tag - - def to_yaml_tree(self, obj, tag, ctx): - return obj._data - - def from_yaml_tree(self, node, tag, ctx): - return OBJECT_NODE_CLASSES_BY_TAG[tag](node) - - -class TaggedListNodeConverter(Converter): - """ - Converter for all subclasses of TaggedListNode. 
- """ - - @property - def tags(self): - return list(LIST_NODE_CLASSES_BY_TAG.keys()) - - @property - def types(self): - return list(LIST_NODE_CLASSES_BY_TAG.values()) - - def select_tag(self, obj, tags, ctx): - return obj.tag - - def to_yaml_tree(self, obj, tag, ctx): - return list(obj) - - def from_yaml_tree(self, node, tag, ctx): - return LIST_NODE_CLASSES_BY_TAG[tag](node) - - -class TaggedScalarNodeConverter(Converter): - """ - Converter for all subclasses of TaggedScalarNode. - """ - - @property - def tags(self): - return list(SCALAR_NODE_CLASSES_BY_TAG.keys()) - - @property - def types(self): - return list(SCALAR_NODE_CLASSES_BY_TAG.values()) - - def select_tag(self, obj, tags, ctx): - return obj.tag - - def to_yaml_tree(self, obj, tag, ctx): - node = obj.__class__.__bases__[0](obj) - - if tag == FileDate._tag: - converter = ctx.extension_manager.get_converter_for_type(type(node)) - node = converter.to_yaml_tree(node, tag, ctx) - - return node - - def from_yaml_tree(self, node, tag, ctx): - if tag == FileDate._tag: - converter = ctx.extension_manager.get_converter_for_type(Time) - node = converter.from_yaml_tree(node, tag, ctx) - - return SCALAR_NODE_CLASSES_BY_TAG[tag](node) - - _DATAMODELS_MANIFEST_PATH = importlib.resources.files(rad.resources) / "manifests" / "datamodels-1.0.yaml" _DATAMODELS_MANIFEST = yaml.safe_load(_DATAMODELS_MANIFEST_PATH.read_bytes()) From f3b88956fee943dc7faacec4d6b893723d59350e Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 15:13:00 -0400 Subject: [PATCH 06/14] Minor cleanups of tests and node dynamic module --- src/roman_datamodels/stnode/_fixed.py | 6 ++++++ src/roman_datamodels/stnode/_stnode.py | 9 ++++----- src/roman_datamodels/stnode/_tagged.py | 22 +++++++++------------- tests/conftest.py | 4 +++- tests/test_stnode.py | 24 +++++++++++++----------- 5 files changed, 35 insertions(+), 30 deletions(-) diff --git a/src/roman_datamodels/stnode/_fixed.py b/src/roman_datamodels/stnode/_fixed.py index 
773df242..c7f3c850 100644 --- a/src/roman_datamodels/stnode/_fixed.py +++ b/src/roman_datamodels/stnode/_fixed.py @@ -8,6 +8,8 @@ class WfiMode(TaggedObjectNode): _GRATING_OPTICAL_ELEMENTS = {"GRISM", "PRISM"} + __module__ = "roman_datamodels.stnode" + @property def filter(self): if self.optical_element in self._GRATING_OPTICAL_ELEMENTS: @@ -26,6 +28,10 @@ def grating(self): class CalLogs(TaggedListNode): _tag = "asdf://stsci.edu/datamodels/roman/tags/cal_logs-1.0.0" + __module__ = "roman_datamodels.stnode" + class FileDate(Time, TaggedScalarNode): _tag = "asdf://stsci.edu/datamodels/roman/tags/file_date-1.0.0" + + __module__ = "roman_datamodels.stnode" diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index 764a20d9..a9eccd1a 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -15,8 +15,8 @@ ] -_DATAMODELS_MANIFEST_PATH = importlib.resources.files(rad.resources) / "manifests" / "datamodels-1.0.yaml" -_DATAMODELS_MANIFEST = yaml.safe_load(_DATAMODELS_MANIFEST_PATH.read_bytes()) +DATAMODELS_MANIFEST_PATH = importlib.resources.files(rad.resources) / "manifests" / "datamodels-1.0.yaml" +DATAMODELS_MANIFEST = yaml.safe_load(DATAMODELS_MANIFEST_PATH.read_bytes()) def _class_name_from_tag_uri(tag_uri): @@ -30,8 +30,7 @@ def _class_name_from_tag_uri(tag_uri): def _class_from_tag(tag, docstring): class_name = _class_name_from_tag_uri(tag["tag_uri"]) - schema_uri = tag["schema_uri"] - if "tagged_scalar" in schema_uri: + if "tagged_scalar" in tag["schema_uri"]: cls = type( class_name, (str, TaggedScalarNode), @@ -48,7 +47,7 @@ def _class_from_tag(tag, docstring): __all__.append(class_name) -for tag in _DATAMODELS_MANIFEST["tags"]: +for tag in DATAMODELS_MANIFEST["tags"]: docstring = "" if "description" in tag: docstring = tag["description"] + "\n\n" diff --git a/src/roman_datamodels/stnode/_tagged.py b/src/roman_datamodels/stnode/_tagged.py index e8ab5eea..ff7815cb 100644 --- 
a/src/roman_datamodels/stnode/_tagged.py +++ b/src/roman_datamodels/stnode/_tagged.py @@ -1,7 +1,7 @@ +import copy from abc import ABCMeta import asdf -import asdf.schema as asdfschema from ._node import DNode, LNode from ._registry import ( @@ -12,6 +12,12 @@ ) +def get_schema_from_tag(ctx, tag): + schema_uri = ctx.extension_manager.get_tag_definition(tag).schema_uris[0] + + return asdf.schema.load_schema(schema_uri, resolve_references=True) + + class TaggedObjectNodeMeta(ABCMeta): """ Metaclass for TaggedObjectNode that maintains a registry @@ -42,11 +48,7 @@ def _schema(self): def get_schema(self): """Retrieve the schema associated with this tag""" - extension_manager = self.ctx.extension_manager - tag_def = extension_manager.get_tag_definition(self.tag) - schema_uri = tag_def.schema_uris[0] - schema = asdfschema.load_schema(schema_uri, resolve_references=True) - return schema + return get_schema_from_tag(self.ctx, self._tag) class TaggedListNodeMeta(ABCMeta): @@ -110,13 +112,7 @@ def key(self): return _scalar_tag_to_key(self._tag) def get_schema(self): - extension_manager = self.ctx.extension_manager - tag_def = extension_manager.get_tag_definition(self.tag) - schema_uri = tag_def.schema_uris[0] - schema = asdf.schema.load_schema(schema_uri, resolve_references=True) - return schema + return get_schema_from_tag(self.ctx, self._tag) def copy(self): - import copy - return copy.copy(self) diff --git a/tests/conftest.py b/tests/conftest.py index 41e2ed45..767bb2cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,12 @@ from roman_datamodels import datamodels from roman_datamodels import maker_utils as utils +MANIFEST = yaml.safe_load(asdf.get_config().resource_manager["asdf://stsci.edu/datamodels/roman/manifests/datamodels-1.0"]) + @pytest.fixture(scope="session") def manifest(): - return yaml.safe_load(asdf.get_config().resource_manager["asdf://stsci.edu/datamodels/roman/manifests/datamodels-1.0"]) + return MANIFEST 
@pytest.fixture(name="set_up_list_of_l2_files") diff --git a/tests/test_stnode.py b/tests/test_stnode.py index 0fb715fc..9f3621d6 100644 --- a/tests/test_stnode.py +++ b/tests/test_stnode.py @@ -8,18 +8,20 @@ from roman_datamodels import datamodels, maker_utils, stnode, validate from roman_datamodels.testing import assert_node_equal, assert_node_is_copy, wraps_hashable +from .conftest import MANIFEST -def test_generated_node_classes(manifest): - for tag in manifest["tags"]: - class_name = stnode._stnode._class_name_from_tag_uri(tag["tag_uri"]) - node_class = getattr(stnode, class_name) - - assert issubclass(node_class, (stnode.TaggedObjectNode, stnode.TaggedListNode, stnode.TaggedScalarNode)) - assert node_class._tag == tag["tag_uri"] - assert tag["description"] in node_class.__doc__ - assert tag["tag_uri"] in node_class.__doc__ - assert node_class.__module__.startswith(stnode.__name__) - assert node_class.__name__ in stnode.__all__ + +@pytest.mark.parametrize("tag", MANIFEST["tags"]) +def test_generated_node_classes(tag): + class_name = stnode._stnode._class_name_from_tag_uri(tag["tag_uri"]) + node_class = getattr(stnode, class_name) + + assert issubclass(node_class, (stnode.TaggedObjectNode, stnode.TaggedListNode, stnode.TaggedScalarNode)) + assert node_class._tag == tag["tag_uri"] + assert tag["description"] in node_class.__doc__ + assert tag["tag_uri"] in node_class.__doc__ + assert node_class.__module__ == stnode.__name__ + assert hasattr(stnode, node_class.__name__) @pytest.mark.parametrize("node_class", stnode.NODE_CLASSES) From 37825039fa81ee99548f37d85b6d3fcfef40c24b Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Fri, 23 Jun 2023 12:13:39 -0400 Subject: [PATCH 07/14] Replace the metaclasses with `__init_subclass__` --- src/roman_datamodels/stnode/_stnode.py | 4 +- src/roman_datamodels/stnode/_tagged.py | 95 ++++++++++++++------------ 2 files changed, 55 insertions(+), 44 deletions(-) diff --git a/src/roman_datamodels/stnode/_stnode.py 
b/src/roman_datamodels/stnode/_stnode.py index a9eccd1a..21f134a8 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -8,7 +8,7 @@ import yaml from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG -from ._tagged import TaggedObjectNode, TaggedScalarNode +from ._tagged import TaggedObjectNode, TaggedScalarNode, name_from_tag_uri __all__ = [ "NODE_CLASSES", @@ -20,7 +20,7 @@ def _class_name_from_tag_uri(tag_uri): - tag_name = tag_uri.split("/")[-1].split("-")[0] + tag_name = name_from_tag_uri(tag_uri) class_name = "".join([p.capitalize() for p in tag_name.split("_")]) if tag_uri.startswith("asdf://stsci.edu/datamodels/roman/tags/reference_files/"): class_name += "Ref" diff --git a/src/roman_datamodels/stnode/_tagged.py b/src/roman_datamodels/stnode/_tagged.py index ff7815cb..a715159f 100644 --- a/src/roman_datamodels/stnode/_tagged.py +++ b/src/roman_datamodels/stnode/_tagged.py @@ -1,5 +1,4 @@ import copy -from abc import ABCMeta import asdf @@ -13,30 +12,47 @@ def get_schema_from_tag(ctx, tag): + """ + Look up and load ASDF's schema corresponding to the tag_uri. + + Parameters + ---------- + ctx : + An ASDF file context. + tag : str + The tag_uri of the schema to load. + """ schema_uri = ctx.extension_manager.get_tag_definition(tag).schema_uris[0] return asdf.schema.load_schema(schema_uri, resolve_references=True) -class TaggedObjectNodeMeta(ABCMeta): - """ - Metaclass for TaggedObjectNode that maintains a registry - of subclasses. +def name_from_tag_uri(tag_uri): """ + Compute the name of the schema from the tag_uri. 
- def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.__name__ != "TaggedObjectNode": - if self._tag in OBJECT_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedObjectNode class for tag '{self._tag}' has been defined twice") - OBJECT_NODE_CLASSES_BY_TAG[self._tag] = self + Parameters + ---------- + tag_uri : str + The tag_uri to find the name from + """ + return tag_uri.split("/")[-1].split("-")[0] -class TaggedObjectNode(DNode, metaclass=TaggedObjectNodeMeta): +class TaggedObjectNode(DNode): """ - Expects subclass to define a class instance of _tag + Base class for all tagged objects defined by RAD + There will be one of these for any tagged object defined by RAD, which has + base type: object. """ + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + if cls.__name__ != "TaggedObjectNode": + if cls._tag in OBJECT_NODE_CLASSES_BY_TAG: + raise RuntimeError(f"TaggedObjectNode class for tag '{cls._tag}' has been defined twice") + OBJECT_NODE_CLASSES_BY_TAG[cls._tag] = cls + @property def tag(self): return self._tag @@ -51,49 +67,44 @@ def get_schema(self): return get_schema_from_tag(self.ctx, self._tag) -class TaggedListNodeMeta(ABCMeta): +class TaggedListNode(LNode): """ - Metaclass for TaggedListNode that maintains a registry - of subclasses. + Base class for all tagged list defined by RAD + There will be one of these for any tagged object defined by RAD, which has + base type: array. 
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.__name__ != "TaggedListNode": - if self._tag in LIST_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedListNode class for tag '{self._tag}' has been defined twice") - LIST_NODE_CLASSES_BY_TAG[self._tag] = self - + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + if cls.__name__ != "TaggedListNode": + if cls._tag in LIST_NODE_CLASSES_BY_TAG: + raise RuntimeError(f"TaggedListNode class for tag '{cls._tag}' has been defined twice") + LIST_NODE_CLASSES_BY_TAG[cls._tag] = cls -class TaggedListNode(LNode, metaclass=TaggedListNodeMeta): @property def tag(self): return self._tag -def _scalar_tag_to_key(tag): - return tag.split("/")[-1].split("-")[0] - - -class TaggedScalarNodeMeta(ABCMeta): +class TaggedScalarNode: """ - Metaclass for TaggedScalarNode that maintains a registry - of subclasses. + Base class for all tagged scalars defined by RAD + There will be one of these for any tagged object defined by RAD, which has + a scalar base type, or wraps a scalar base type. + These will all be in the tagged_scalars directory. 
""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.__name__ != "TaggedScalarNode": - if self._tag in SCALAR_NODE_CLASSES_BY_TAG: - raise RuntimeError(f"TaggedScalarNode class for tag '{self._tag}' has been defined twice") - SCALAR_NODE_CLASSES_BY_TAG[self._tag] = self - SCALAR_NODE_CLASSES_BY_KEY[_scalar_tag_to_key(self._tag)] = self - - -class TaggedScalarNode(metaclass=TaggedScalarNodeMeta): _tag = None _ctx = None + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + if cls.__name__ != "TaggedScalarNode": + if cls._tag in SCALAR_NODE_CLASSES_BY_TAG: + raise RuntimeError(f"TaggedScalarNode class for tag '{cls._tag}' has been defined twice") + SCALAR_NODE_CLASSES_BY_TAG[cls._tag] = cls + SCALAR_NODE_CLASSES_BY_KEY[name_from_tag_uri(cls._tag)] = cls + @property def ctx(self): if self._ctx is None: @@ -109,7 +120,7 @@ def tag(self): @property def key(self): - return _scalar_tag_to_key(self._tag) + return name_from_tag_uri(self._tag) def get_schema(self): return get_schema_from_tag(self.ctx, self._tag) From e62d54a6f574ed04baeed8112b37ef89c8b722b1 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 16:23:43 -0400 Subject: [PATCH 08/14] Remove unnecessary `fixed` classes --- src/roman_datamodels/stnode/_converters.py | 3 +- src/roman_datamodels/stnode/_fixed.py | 16 +---- src/roman_datamodels/stnode/_stnode.py | 70 ++++++++++++++++++---- 3 files changed, 62 insertions(+), 27 deletions(-) diff --git a/src/roman_datamodels/stnode/_converters.py b/src/roman_datamodels/stnode/_converters.py index 997f7e7b..0acf1965 100644 --- a/src/roman_datamodels/stnode/_converters.py +++ b/src/roman_datamodels/stnode/_converters.py @@ -1,8 +1,9 @@ from asdf.extension import Converter from astropy.time import Time -from ._fixed import FileDate +from ._fixed import WfiMode # noqa: F401 from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG +from 
._stnode import FileDate class TaggedObjectNodeConverter(Converter): diff --git a/src/roman_datamodels/stnode/_fixed.py b/src/roman_datamodels/stnode/_fixed.py index c7f3c850..d9a12258 100644 --- a/src/roman_datamodels/stnode/_fixed.py +++ b/src/roman_datamodels/stnode/_fixed.py @@ -1,6 +1,4 @@ -from astropy.time import Time - -from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode +from ._tagged import TaggedObjectNode class WfiMode(TaggedObjectNode): @@ -23,15 +21,3 @@ def grating(self): return self.optical_element else: return None - - -class CalLogs(TaggedListNode): - _tag = "asdf://stsci.edu/datamodels/roman/tags/cal_logs-1.0.0" - - __module__ = "roman_datamodels.stnode" - - -class FileDate(Time, TaggedScalarNode): - _tag = "asdf://stsci.edu/datamodels/roman/tags/file_date-1.0.0" - - __module__ = "roman_datamodels.stnode" diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index 21f134a8..acd7bcc0 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -6,18 +6,40 @@ import rad.resources import yaml +from astropy.time import Time from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG -from ._tagged import TaggedObjectNode, TaggedScalarNode, name_from_tag_uri +from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri __all__ = [ "NODE_CLASSES", ] +SCALAR_TYPE_MAP = { + "string": str, + "http://stsci.edu/schemas/asdf/time/time-1.1.0": Time, +} + + DATAMODELS_MANIFEST_PATH = importlib.resources.files(rad.resources) / "manifests" / "datamodels-1.0.yaml" DATAMODELS_MANIFEST = yaml.safe_load(DATAMODELS_MANIFEST_PATH.read_bytes()) +BASE_SCHEMA_PATH = importlib.resources.files(rad.resources) / "schemas" + + +def _load_schema_from_uri(schema_uri): + filename = f"{schema_uri.split('/')[-1]}.yaml" + + if "reference_files" in schema_uri: + schema_path = BASE_SCHEMA_PATH / "reference_files" / 
filename + elif "tagged_scalars" in schema_uri: + schema_path = BASE_SCHEMA_PATH / "tagged_scalars" / filename + else: + schema_path = BASE_SCHEMA_PATH / filename + + return yaml.safe_load(schema_path.read_bytes()) + def _class_name_from_tag_uri(tag_uri): tag_name = name_from_tag_uri(tag_uri) @@ -27,21 +49,47 @@ def _class_name_from_tag_uri(tag_uri): return class_name +def _scalar_class(tag, class_name, docstring): + schema = _load_schema_from_uri(tag["schema_uri"]) + + if "type" in schema: + type_ = schema["type"] + elif "allOf" in schema: + type_ = schema["allOf"][0]["$ref"] + else: + raise RuntimeError(f"Unknown schema type: {schema}") + + return type( + class_name, + (SCALAR_TYPE_MAP[type_], TaggedScalarNode), + {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring}, + ) + + +def _node_class(tag, class_name, docstring): + schema = _load_schema_from_uri(tag["schema_uri"]) + + if schema["type"] == "object": + class_type = TaggedObjectNode + elif schema["type"] == "array": + class_type = TaggedListNode + else: + raise RuntimeError(f"Unknown schema type: {schema['type']}") + + return type( + class_name, + (class_type,), + {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring}, + ) + + def _class_from_tag(tag, docstring): class_name = _class_name_from_tag_uri(tag["tag_uri"]) if "tagged_scalar" in tag["schema_uri"]: - cls = type( - class_name, - (str, TaggedScalarNode), - {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring}, - ) + cls = _scalar_class(tag, class_name, docstring) else: - cls = type( - class_name, - (TaggedObjectNode,), - {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring}, - ) + cls = _node_class(tag, class_name, docstring) globals()[class_name] = cls __all__.append(class_name) From 3843bcb75ab8cbd7207c0804c08831f3c870ef8d Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 16:45:50 -0400 Subject: 
[PATCH 09/14] Turn custom subclassing into mixin --- src/roman_datamodels/stnode/__init__.py | 1 - src/roman_datamodels/stnode/_converters.py | 1 - src/roman_datamodels/stnode/{_fixed.py => _mixins.py} | 9 +-------- src/roman_datamodels/stnode/_stnode.py | 8 +++++++- tests/test_stnode.py | 6 ++++++ 5 files changed, 14 insertions(+), 11 deletions(-) rename src/roman_datamodels/stnode/{_fixed.py => _mixins.py} (69%) diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py index 2a186aa4..1df3cbe3 100644 --- a/src/roman_datamodels/stnode/__init__.py +++ b/src/roman_datamodels/stnode/__init__.py @@ -1,5 +1,4 @@ from ._converters import * # noqa: F403 -from ._fixed import * # noqa: F403 from ._node import * # noqa: F403 from ._stnode import * # noqa: F403 from ._tagged import * # noqa: F403 diff --git a/src/roman_datamodels/stnode/_converters.py b/src/roman_datamodels/stnode/_converters.py index 0acf1965..f44fd8be 100644 --- a/src/roman_datamodels/stnode/_converters.py +++ b/src/roman_datamodels/stnode/_converters.py @@ -1,7 +1,6 @@ from asdf.extension import Converter from astropy.time import Time -from ._fixed import WfiMode # noqa: F401 from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG from ._stnode import FileDate diff --git a/src/roman_datamodels/stnode/_fixed.py b/src/roman_datamodels/stnode/_mixins.py similarity index 69% rename from src/roman_datamodels/stnode/_fixed.py rename to src/roman_datamodels/stnode/_mixins.py index d9a12258..dff257a7 100644 --- a/src/roman_datamodels/stnode/_fixed.py +++ b/src/roman_datamodels/stnode/_mixins.py @@ -1,13 +1,6 @@ -from ._tagged import TaggedObjectNode - - -class WfiMode(TaggedObjectNode): - _tag = "asdf://stsci.edu/datamodels/roman/tags/wfi_mode-1.0.0" - +class WfiModeMixin: _GRATING_OPTICAL_ELEMENTS = {"GRISM", "PRISM"} - __module__ = "roman_datamodels.stnode" - @property def filter(self): if self.optical_element in 
self._GRATING_OPTICAL_ELEMENTS: diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index acd7bcc0..9b2e440f 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -8,6 +8,7 @@ import yaml from astropy.time import Time +from . import _mixins from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri @@ -76,9 +77,14 @@ def _node_class(tag, class_name, docstring): else: raise RuntimeError(f"Unknown schema type: {schema['type']}") + if hasattr(_mixins, mixin := f"{class_name}Mixin"): + class_type = (class_type, getattr(_mixins, mixin)) + else: + class_type = (class_type,) + return type( class_name, - (class_type,), + class_type, {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring}, ) diff --git a/tests/test_stnode.py b/tests/test_stnode.py index 9f3621d6..031e3db0 100644 --- a/tests/test_stnode.py +++ b/tests/test_stnode.py @@ -66,16 +66,22 @@ def test_wfi_mode(): assert node.optical_element == "GRISM" assert node.grating == "GRISM" assert node.filter is None + assert isinstance(node, stnode.DNode) + assert isinstance(node, stnode._mixins.WfiModeMixin) node = stnode.WfiMode({"optical_element": "PRISM"}) assert node.optical_element == "PRISM" assert node.grating == "PRISM" assert node.filter is None + assert isinstance(node, stnode.DNode) + assert isinstance(node, stnode._mixins.WfiModeMixin) node = stnode.WfiMode({"optical_element": "F129"}) assert node.optical_element == "F129" assert node.grating is None assert node.filter == "F129" + assert isinstance(node, stnode.DNode) + assert isinstance(node, stnode._mixins.WfiModeMixin) @pytest.mark.parametrize("node_class", stnode.NODE_CLASSES) From 722f4c1a069690da98083f64b354d2b53684ebac Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 18:46:10 -0400 
Subject: [PATCH 10/14] Pull out class factories --- src/roman_datamodels/stnode/_factories.py | 187 ++++++++++++++++++++++ src/roman_datamodels/stnode/_stnode.py | 123 +++----------- tests/test_stnode.py | 2 +- 3 files changed, 213 insertions(+), 99 deletions(-) create mode 100644 src/roman_datamodels/stnode/_factories.py diff --git a/src/roman_datamodels/stnode/_factories.py b/src/roman_datamodels/stnode/_factories.py new file mode 100644 index 00000000..297f95ba --- /dev/null +++ b/src/roman_datamodels/stnode/_factories.py @@ -0,0 +1,187 @@ +""" +Factories for creating Tagged STNode classes from tag_uris. + These are used to dynamically create classes from the RAD manifest. +""" + +import importlib.resources + +import yaml +from astropy.time import Time +from rad import resources + +from . import _mixins +from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri + +# Map of scalar types in the schemas to the python types +SCALAR_TYPE_MAP = { + "string": str, + "http://stsci.edu/schemas/asdf/time/time-1.1.0": Time, +} + +BASE_SCHEMA_PATH = importlib.resources.files(resources) / "schemas" + + +def load_schema_from_uri(schema_uri): + """ + Load the actual schema from the rad resources directly (outside ASDF) + Outside ASDF because this has to occur before the ASDF extensions are + registered. 
+ + Parameters + ---------- + schema_uri : str + The schema_uri found in the RAD manifest + + Returns + ------- + yaml library dictionary from the schema + """ + filename = f"{schema_uri.split('/')[-1]}.yaml" + + if "reference_files" in schema_uri: + schema_path = BASE_SCHEMA_PATH / "reference_files" / filename + elif "tagged_scalars" in schema_uri: + schema_path = BASE_SCHEMA_PATH / "tagged_scalars" / filename + else: + schema_path = BASE_SCHEMA_PATH / filename + + return yaml.safe_load(schema_path.read_bytes()) + + +def class_name_from_tag_uri(tag_uri): + """ + Construct the class name for the STNode class from the tag_uri + + Parameters + ---------- + tag_uri : str + The tag_uri found in the RAD manifest + + Returns + ------- + string name for the class + """ + tag_name = name_from_tag_uri(tag_uri) + class_name = "".join([p.capitalize() for p in tag_name.split("_")]) + if tag_uri.startswith("asdf://stsci.edu/datamodels/roman/tags/reference_files/"): + class_name += "Ref" + + return class_name + + +def docstring_from_tag(tag): + """ + Read the docstring (if it exists) from the RAD manifest and generate a docstring + for the dynamically generated class. + + Parameters + ---------- + tag: dict + A tag entry from the RAD manifest + + Returns + ------- + A docstring for the class based on the tag + """ + docstring = f"{tag['description']}\n\n" if "description" in tag else "" + + return docstring + f"Class generated from tag '{tag['tag_uri']}'" + + +def scalar_factory(tag): + """ + Factory to create a TaggedScalarNode class from a tag + + Parameters + ---------- + tag: dict + A tag entry from the RAD manifest + + Returns + ------- + A dynamically generated TaggedScalarNode subclass + """ + class_name = class_name_from_tag_uri(tag["tag_uri"]) + schema = load_schema_from_uri(tag["schema_uri"]) + + # TaggedScalarNode subclasses are really subclasses of the type of the scalar, + # with the TaggedScalarNode as a mixin. 
This is because the TaggedScalarNode + # is supposed to be the scalar, but it needs to be serializable under a specific + # ASDF tag. + # SCALAR_TYPE_MAP will need to be updated as new wrappers of scalar types are added + # to the RAD manifest. + if "type" in schema: + type_ = schema["type"] + elif "allOf" in schema: + type_ = schema["allOf"][0]["$ref"] + else: + raise RuntimeError(f"Unknown schema type: {schema}") + + return type( + class_name, + (SCALAR_TYPE_MAP[type_], TaggedScalarNode), + {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)}, + ) + + +def node_factory(tag): + """ + Factory to create a TaggedObjectNode or TaggedListNode class from a tag + + Parameters + ---------- + tag: dict + A tag entry from the RAD manifest + + Returns + ------- + A dynamically generated TaggedObjectNode or TaggedListNode subclass + """ + class_name = class_name_from_tag_uri(tag["tag_uri"]) + schema = load_schema_from_uri(tag["schema_uri"]) + + # Determine if the class is a TaggedObjectNode or TaggedListNode based on the + # type defined in the schema: + # - TaggedObjectNode if type is "object" + # - TaggedListNode if type is "array" (array in jsonschema represents Python list) + if schema["type"] == "object": + class_type = TaggedObjectNode + elif schema["type"] == "array": + class_type = TaggedListNode + else: + raise RuntimeError(f"Unknown schema type: {schema['type']}") + + # In special cases one may need to add additional features to a tagged node class. + # This is done by creating a mixin class with the name {class_name}Mixin in _mixins.py + # Here we mixin the mixin class if it exists.
+ if hasattr(_mixins, mixin := f"{class_name}Mixin"): + class_type = (class_type, getattr(_mixins, mixin)) + else: + class_type = (class_type,) + + return type( + class_name, + class_type, + {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)}, + ) + + +def stnode_factory(tag): + """ + Construct a tagged STNode class from a tag + + Parameters + ---------- + tag: dict + A tag entry from the RAD manifest + + Returns + ------- + A dynamically generated TaggedScalarNode, TaggedObjectNode, or TaggedListNode subclass + """ + # TaggedScalarNodes are a special case because they are not a subclass of a + # _node class, but rather a subclass of the type of the scalar. + if "tagged_scalar" in tag["schema_uri"]: + return scalar_factory(tag) + else: + return node_factory(tag) diff --git a/src/roman_datamodels/stnode/_stnode.py b/src/roman_datamodels/stnode/_stnode.py index 9b2e440f..f829a60c 100644 --- a/src/roman_datamodels/stnode/_stnode.py +++ b/src/roman_datamodels/stnode/_stnode.py @@ -1,124 +1,51 @@ """ -Proof of concept of using tags with the data model framework +Dynamic creation of STNode classes from the RAD manifest. + This module will create all the STNode based classes used by roman_datamodels. + Unfortunately, this is a dynamic process which occurs at first import time because + roman_datamodels cannot predict what STNode objects will be in the version of RAD + used by the user. """ import importlib.resources -import rad.resources import yaml -from astropy.time import Time +from rad import resources -from . 
import _mixins +from ._factories import stnode_factory from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG -from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri __all__ = [ "NODE_CLASSES", ] -SCALAR_TYPE_MAP = { - "string": str, - "http://stsci.edu/schemas/asdf/time/time-1.1.0": Time, -} - - -DATAMODELS_MANIFEST_PATH = importlib.resources.files(rad.resources) / "manifests" / "datamodels-1.0.yaml" +# Load the manifest directly from the rad resources and not from ASDF. +# This is because the ASDF extensions have to be created before they can be registered +# and this module creates the classes used by the ASDF extension. +DATAMODELS_MANIFEST_PATH = importlib.resources.files(resources) / "manifests" / "datamodels-1.0.yaml" DATAMODELS_MANIFEST = yaml.safe_load(DATAMODELS_MANIFEST_PATH.read_bytes()) -BASE_SCHEMA_PATH = importlib.resources.files(rad.resources) / "schemas" - - -def _load_schema_from_uri(schema_uri): - filename = f"{schema_uri.split('/')[-1]}.yaml" - - if "reference_files" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "reference_files" / filename - elif "tagged_scalars" in schema_uri: - schema_path = BASE_SCHEMA_PATH / "tagged_scalars" / filename - else: - schema_path = BASE_SCHEMA_PATH / filename - - return yaml.safe_load(schema_path.read_bytes()) - - -def _class_name_from_tag_uri(tag_uri): - tag_name = name_from_tag_uri(tag_uri) - class_name = "".join([p.capitalize() for p in tag_name.split("_")]) - if tag_uri.startswith("asdf://stsci.edu/datamodels/roman/tags/reference_files/"): - class_name += "Ref" - return class_name - - -def _scalar_class(tag, class_name, docstring): - schema = _load_schema_from_uri(tag["schema_uri"]) - if "type" in schema: - type_ = schema["type"] - elif "allOf" in schema: - type_ = schema["allOf"][0]["$ref"] - else: - raise RuntimeError(f"Unknown schema type: {schema}") +def _factory(tag): + """ + Wrap the __all__ append and class 
creation in a function to avoid the linter + getting upset + """ + cls = stnode_factory(tag) - return type( - class_name, - (SCALAR_TYPE_MAP[type_], TaggedScalarNode), - {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring}, - ) - - -def _node_class(tag, class_name, docstring): - schema = _load_schema_from_uri(tag["schema_uri"]) - - if schema["type"] == "object": - class_type = TaggedObjectNode - elif schema["type"] == "array": - class_type = TaggedListNode - else: - raise RuntimeError(f"Unknown schema type: {schema['type']}") - - if hasattr(_mixins, mixin := f"{class_name}Mixin"): - class_type = (class_type, getattr(_mixins, mixin)) - else: - class_type = (class_type,) - - return type( - class_name, - class_type, - {"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring}, - ) - - -def _class_from_tag(tag, docstring): - class_name = _class_name_from_tag_uri(tag["tag_uri"]) - - if "tagged_scalar" in tag["schema_uri"]: - cls = _scalar_class(tag, class_name, docstring) - else: - cls = _node_class(tag, class_name, docstring) - - globals()[class_name] = cls - __all__.append(class_name) + class_name = cls.__name__ + globals()[class_name] = cls # Add to namespace of module + __all__.append(class_name) # add to __all__ so it's imported with `from . 
import *` +# Main dynamic class creation loop +# Reads each tag entry from the manifest and creates a class for it for tag in DATAMODELS_MANIFEST["tags"]: - docstring = "" - if "description" in tag: - docstring = tag["description"] + "\n\n" - docstring = docstring + f"Class generated from tag '{tag['tag_uri']}'" - - if tag["tag_uri"] in OBJECT_NODE_CLASSES_BY_TAG: - OBJECT_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring - elif tag["tag_uri"] in LIST_NODE_CLASSES_BY_TAG: - LIST_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring - elif tag["tag_uri"] in SCALAR_NODE_CLASSES_BY_TAG: - SCALAR_NODE_CLASSES_BY_TAG[tag["tag_uri"]].__doc__ = docstring - else: - _class_from_tag(tag, docstring) + _factory(tag) -# List of node classes made available by this library. This is part -# of the public API. +# List of node classes made available by this library. +# This is part of the public API. NODE_CLASSES = ( list(OBJECT_NODE_CLASSES_BY_TAG.values()) + list(LIST_NODE_CLASSES_BY_TAG.values()) diff --git a/tests/test_stnode.py b/tests/test_stnode.py index 031e3db0..07a055db 100644 --- a/tests/test_stnode.py +++ b/tests/test_stnode.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize("tag", MANIFEST["tags"]) def test_generated_node_classes(tag): - class_name = stnode._stnode._class_name_from_tag_uri(tag["tag_uri"]) + class_name = stnode._factories.class_name_from_tag_uri(tag["tag_uri"]) node_class = getattr(stnode, class_name) assert issubclass(node_class, (stnode.TaggedObjectNode, stnode.TaggedListNode, stnode.TaggedScalarNode)) From 651af18ddc0d8e1dbdd3022aa6d6e769624c3bf7 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Wed, 21 Jun 2023 18:41:35 -0400 Subject: [PATCH 11/14] Add better documentation for stnode. 
--- src/roman_datamodels/stnode/_converters.py | 8 +++++++- src/roman_datamodels/stnode/_mixins.py | 5 +++++ src/roman_datamodels/stnode/_node.py | 12 ++++++++++++ src/roman_datamodels/stnode/_registry.py | 5 +++++ src/roman_datamodels/stnode/_tagged.py | 17 +++++++++++++++++ 5 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/roman_datamodels/stnode/_converters.py b/src/roman_datamodels/stnode/_converters.py index f44fd8be..43176a4d 100644 --- a/src/roman_datamodels/stnode/_converters.py +++ b/src/roman_datamodels/stnode/_converters.py @@ -1,8 +1,10 @@ +""" +The ASDF Converters to handle the serialization/deserialization of the STNode classes to ASDF. +""" from asdf.extension import Converter from astropy.time import Time from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG -from ._stnode import FileDate class TaggedObjectNodeConverter(Converter): @@ -68,6 +70,8 @@ def select_tag(self, obj, tags, ctx): return obj.tag def to_yaml_tree(self, obj, tag, ctx): + from ._stnode import FileDate + node = obj.__class__.__bases__[0](obj) if tag == FileDate._tag: @@ -77,6 +81,8 @@ def to_yaml_tree(self, obj, tag, ctx): return node def from_yaml_tree(self, node, tag, ctx): + from ._stnode import FileDate + if tag == FileDate._tag: converter = ctx.extension_manager.get_converter_for_type(Time) node = converter.from_yaml_tree(node, tag, ctx) diff --git a/src/roman_datamodels/stnode/_mixins.py b/src/roman_datamodels/stnode/_mixins.py index dff257a7..461a81af 100644 --- a/src/roman_datamodels/stnode/_mixins.py +++ b/src/roman_datamodels/stnode/_mixins.py @@ -1,3 +1,8 @@ +""" +Mixin classes for additional functionality for STNode classes +""" + + class WfiModeMixin: _GRATING_OPTICAL_ELEMENTS = {"GRISM", "PRISM"} diff --git a/src/roman_datamodels/stnode/_node.py b/src/roman_datamodels/stnode/_node.py index 2eb0b81c..a9d9bb95 100644 --- a/src/roman_datamodels/stnode/_node.py +++ b/src/roman_datamodels/stnode/_node.py @@
-1,3 +1,7 @@ +""" +Base node classes for all STNode classes. + These are the base classes for the data objects used by the datamodels package. +""" import datetime import re import warnings @@ -83,6 +87,10 @@ def _get_schema_for_property(schema, attr): class DNode(MutableMapping): + """ + Base class describing all "object" (dict-like) data nodes for STNode classes. + """ + _tag = None _ctx = None @@ -220,6 +228,10 @@ def copy(self): class LNode(UserList): + """ + Base class describing all "array" (list-like) data nodes for STNode classes. + """ + _tag = None def __init__(self, node=None): diff --git a/src/roman_datamodels/stnode/_registry.py b/src/roman_datamodels/stnode/_registry.py index cf025ee2..e6827d05 100644 --- a/src/roman_datamodels/stnode/_registry.py +++ b/src/roman_datamodels/stnode/_registry.py @@ -1,3 +1,8 @@ +""" +Hold all the registry information for the STNode classes. + These will be dynamically populated at import time by the subclasses + whenever they are generated. +""" OBJECT_NODE_CLASSES_BY_TAG = {} LIST_NODE_CLASSES_BY_TAG = {} SCALAR_NODE_CLASSES_BY_TAG = {} diff --git a/src/roman_datamodels/stnode/_tagged.py b/src/roman_datamodels/stnode/_tagged.py index a715159f..d742eab1 100644 --- a/src/roman_datamodels/stnode/_tagged.py +++ b/src/roman_datamodels/stnode/_tagged.py @@ -1,3 +1,8 @@ +""" +Base classes for all the tagged objects defined by RAD. + Each tagged object will be dynamically created at runtime by _stnode.py + from RAD's manifest. +""" import copy import asdf @@ -47,6 +52,10 @@ class TaggedObjectNode(DNode): """ def __init_subclass__(cls, **kwargs) -> None: + """ + Register any subclasses of this class in the OBJECT_NODE_CLASSES_BY_TAG + registry.
+ """ super().__init_subclass__(**kwargs) if cls.__name__ != "TaggedObjectNode": if cls._tag in OBJECT_NODE_CLASSES_BY_TAG: @@ -75,6 +84,10 @@ class TaggedListNode(LNode): """ def __init_subclass__(cls, **kwargs) -> None: + """ + Register any subclasses of this class in the LIST_NODE_CLASSES_BY_TAG + registry. + """ super().__init_subclass__(**kwargs) if cls.__name__ != "TaggedListNode": if cls._tag in LIST_NODE_CLASSES_BY_TAG: @@ -98,6 +111,10 @@ class TaggedScalarNode: _ctx = None def __init_subclass__(cls, **kwargs) -> None: + """ + Register any subclasses of this class in the SCALAR_NODE_CLASSES_BY_TAG + and SCALAR_NODE_CLASSES_BY_KEY registry. + """ super().__init_subclass__(**kwargs) if cls.__name__ != "TaggedScalarNode": if cls._tag in SCALAR_NODE_CLASSES_BY_TAG: From 38da1096f0c38a4c9e8ced5f156f98a21ed72c02 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 19:41:24 -0400 Subject: [PATCH 12/14] Fix the __all__ to conform to the old API --- src/roman_datamodels/stnode/__init__.py | 7 +++++-- src/roman_datamodels/stnode/_converters.py | 6 ++++++ src/roman_datamodels/stnode/_factories.py | 2 ++ src/roman_datamodels/stnode/_mixins.py | 1 + src/roman_datamodels/stnode/_node.py | 2 ++ src/roman_datamodels/stnode/_tagged.py | 6 ++++++ 6 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/roman_datamodels/stnode/__init__.py b/src/roman_datamodels/stnode/__init__.py index 1df3cbe3..1b2c522b 100644 --- a/src/roman_datamodels/stnode/__init__.py +++ b/src/roman_datamodels/stnode/__init__.py @@ -1,6 +1,9 @@ +""" +The STNode classes and supporting objects generated dynamically at import time + from RAD's manifest. 
+""" from ._converters import * # noqa: F403 +from ._mixins import * # noqa: F403 from ._node import * # noqa: F403 from ._stnode import * # noqa: F403 from ._tagged import * # noqa: F403 - -__all__ = [v.__name__ for v in globals().values() if hasattr(v, "__name__")] diff --git a/src/roman_datamodels/stnode/_converters.py b/src/roman_datamodels/stnode/_converters.py index 43176a4d..38094f03 100644 --- a/src/roman_datamodels/stnode/_converters.py +++ b/src/roman_datamodels/stnode/_converters.py @@ -6,6 +6,12 @@ from ._registry import LIST_NODE_CLASSES_BY_TAG, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG +__all__ = [ + "TaggedObjectNodeConverter", + "TaggedListNodeConverter", + "TaggedScalarNodeConverter", +] + class TaggedObjectNodeConverter(Converter): """ diff --git a/src/roman_datamodels/stnode/_factories.py b/src/roman_datamodels/stnode/_factories.py index 297f95ba..4bee5790 100644 --- a/src/roman_datamodels/stnode/_factories.py +++ b/src/roman_datamodels/stnode/_factories.py @@ -12,6 +12,8 @@ from . 
import _mixins from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri +__all__ = ["stnode_factory"] + # Map of scalar types in the schemas to the python types SCALAR_TYPE_MAP = { "string": str, diff --git a/src/roman_datamodels/stnode/_mixins.py b/src/roman_datamodels/stnode/_mixins.py index 461a81af..acd97fe0 100644 --- a/src/roman_datamodels/stnode/_mixins.py +++ b/src/roman_datamodels/stnode/_mixins.py @@ -1,6 +1,7 @@ """ Mixin classes for additional functionality for STNode classes """ +__all__ = ["WfiModeMixin"] class WfiModeMixin: diff --git a/src/roman_datamodels/stnode/_node.py b/src/roman_datamodels/stnode/_node.py index a9d9bb95..728937f7 100644 --- a/src/roman_datamodels/stnode/_node.py +++ b/src/roman_datamodels/stnode/_node.py @@ -21,6 +21,8 @@ from ._registry import SCALAR_NODE_CLASSES_BY_KEY +__all__ = ["DNode", "LNode"] + validator_callbacks = HashableDict(asdfschema.YAML_VALIDATORS) validator_callbacks.update({"type": _check_type}) diff --git a/src/roman_datamodels/stnode/_tagged.py b/src/roman_datamodels/stnode/_tagged.py index d742eab1..f4bc0186 100644 --- a/src/roman_datamodels/stnode/_tagged.py +++ b/src/roman_datamodels/stnode/_tagged.py @@ -15,6 +15,12 @@ SCALAR_NODE_CLASSES_BY_TAG, ) +__all__ = [ + "TaggedObjectNode", + "TaggedListNode", + "TaggedScalarNode", +] + def get_schema_from_tag(ctx, tag): """ From b5ce8e6e73951f7239f8839947525f702a764258 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 20:48:15 -0400 Subject: [PATCH 13/14] Fix docs. Need to ignore some objects due to ASDF bug --- docs/roman_datamodels/datamodels/developer_api.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/roman_datamodels/datamodels/developer_api.rst b/docs/roman_datamodels/datamodels/developer_api.rst index 6ee92f9d..1a5f5b34 100644 --- a/docs/roman_datamodels/datamodels/developer_api.rst +++ b/docs/roman_datamodels/datamodels/developer_api.rst @@ -8,7 +8,13 @@ Developer API .. 
automodapi:: roman_datamodels.integration +.. + The converters are not documented because of https://github.com/asdf-format/asdf/issues/1565 + .. automodapi:: roman_datamodels.stnode + :skip: TaggedObjectNodeConverter + :skip: TaggedListNodeConverter + :skip: TaggedScalarNodeConverter .. automodapi:: roman_datamodels.table_definitions From fabcc6f43ea5ec544fd0c5acdbec66b10f3f82ae Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Tue, 20 Jun 2023 20:50:48 -0400 Subject: [PATCH 14/14] Update changes --- CHANGES.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 70a7956f..6e01922c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -33,6 +33,9 @@ - Update ``roman_datamodels`` to support the new ``msos_stack-1.0.0`` schema. [#206] +- Refactor ``stnode`` to be easier to maintain and test by turning it into a + sub-package and splitting the module apart. [#213] + 0.15.0 (2023-05-15) ===================