Skip to content

Commit

Permalink
Merge pull request #847 from eslavich/eslavich-add-extension-config
Browse files Browse the repository at this point in the history
Allow installed extensions to be configured with AsdfConfig
  • Loading branch information
eslavich authored Aug 4, 2020
2 parents 16c8866 + c900f3f commit 00908dd
Show file tree
Hide file tree
Showing 18 changed files with 525 additions and 128 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
- Add new resource mapping API for extending asdf with additional
schemas. [#819, #828, #843, #846]

- Add global configuration mechanism. [#819, #839, #844]
- Add global configuration mechanism. [#819, #839, #844, #847]

- Drop support for automatic serialization of subclass
attributes. [#825]
Expand Down
124 changes: 79 additions & 45 deletions asdf/asdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from . import yamlutil
from . import _display as display
from .exceptions import AsdfDeprecationWarning, AsdfWarning, AsdfConversionWarning
from .extension import AsdfExtensionList, default_extensions
from .extension import AsdfExtensionList, AsdfExtension, ExtensionProxy
from .util import NotSet
from .search import AsdfSearchResult
from ._helpers import validate_version
Expand Down Expand Up @@ -64,9 +64,8 @@ def __init__(self, tree=None, uri=None, extensions=None, version=None,
automatically determined from the associated file object,
if possible and if created from `AsdfFile.open`.
extensions : list of AsdfExtension
A list of extensions to use when reading and writing ASDF files.
See `~asdf.types.AsdfExtension` for more information.
extensions : asdf.extension.AsdfExtension or asdf.extension.AsdfExtensionList or list of asdf.extension.AsdfExtension
Additional extensions to use when reading and writing the file.
version : str, optional
The ASDF Standard version. If not provided, defaults to the
Expand Down Expand Up @@ -105,16 +104,13 @@ def __init__(self, tree=None, uri=None, extensions=None, version=None,
validation pass. This can be used to ensure that particular ASDF
files follow custom conventions beyond those enforced by the
standard.
"""
if version is None:
self.version = get_config().default_version
else:
self.version = version

self._extensions = []
self._extension_metadata = {}
self._process_extensions(extensions)
self.extensions = extensions

if custom_schema is not None:
self._custom_schema = schema._load_schema_cached(custom_schema, self.resolver, True, False)
Expand Down Expand Up @@ -148,7 +144,7 @@ def __init__(self, tree=None, uri=None, extensions=None, version=None,
# an empty tree.
self._tree = AsdfObject()
elif isinstance(tree, AsdfFile):
if self._extensions != tree._extensions:
if self.extensions != tree.extensions:
raise ValueError(
"Can not copy AsdfFile and change active extensions")
self._uri = tree.uri
Expand Down Expand Up @@ -202,6 +198,43 @@ def version_string(self):
def version_map(self):
return versioning.get_version_map(self.version_string)

@property
def extensions(self):
"""
Get the list of extensions that are enabled for
use with this AsdfFile.
Returns
-------
list of asdf.extension.AsdfExtension
"""
return self._extensions

@extensions.setter
def extensions(self, value):
"""
Set the list of extensions that are enabled for
use with this AsdfFile.
Parameters
----------
value : list of asdf.extension.AsdfExtension
"""
self._extensions = self._process_extensions(value)
self._extension_list = None

@property
def extension_list(self):
"""
Get the AsdfExtensionList for this AsdfFile.
Returns
-------
asdf.extension.AsdfExtensionList
"""
if self._extension_list is None:
self._extension_list = AsdfExtensionList(self.extensions)
return self._extension_list

def __enter__(self):
return self

Expand All @@ -214,8 +247,10 @@ def _check_extensions(self, tree, strict=False):
return

for extension in tree['history']['extensions']:
installed = next((e for e in self.extensions if e.class_name == extension.extension_class), None)

filename = "'{}' ".format(self._fname) if self._fname else ''
if extension.extension_class not in self._extension_metadata:
if installed is None:
msg = "File {}was created with extension '{}', which is " \
"not currently installed"
if extension.software:
Expand All @@ -229,45 +264,47 @@ def _check_extensions(self, tree, strict=False):
warnings.warn(fmt_msg, AsdfWarning)

elif extension.software:
installed = self._extension_metadata[extension.extension_class]
# Local extensions may not have a real version
if not installed[1]:
if not installed.package_version:
continue
# Compare version in file metadata with installed version
if parse_version(installed[1]) < parse_version(extension.software['version']):
if parse_version(installed.package_version) < parse_version(extension.software['version']):
msg = "File {}was created with extension '{}' from " \
"package {}-{}, but older version {}-{} is installed"
fmt_msg = msg.format(
filename, extension.extension_class,
extension.software['name'],
extension.software['version'],
installed[0], installed[1])
installed.package_name, installed.package_version)
if strict:
raise RuntimeError(fmt_msg)
else:
warnings.warn(fmt_msg, AsdfWarning)

def _process_extensions(self, extensions):
if extensions is None or extensions == []:
self._extensions = default_extensions.extension_list
self._extension_metadata = default_extensions.package_metadata
return

if isinstance(extensions, AsdfExtensionList):
self._extensions = extensions
return
def _process_extensions(self, requested_extensions):
if requested_extensions is None:
requested_extensions = []
elif isinstance(requested_extensions, (AsdfExtension, ExtensionProxy)):
requested_extensions = [requested_extensions]
elif isinstance(requested_extensions, AsdfExtensionList):
requested_extensions = requested_extensions.extensions

if not isinstance(requested_extensions, list):
raise TypeError(
"The extensions parameter must be an AsdfExtension, AsdfExtensionList, "
"or list of AsdfExtension."
)

if not isinstance(extensions, list):
extensions = [extensions]
requested_extensions = [ExtensionProxy.maybe_wrap(e) for e in requested_extensions]

# Process metadata about custom extensions
for extension in extensions:
ext_name = util.get_class_name(extension)
self._extension_metadata[ext_name] = ('', '')
extensions = []
# Add requested extensions to the list first, so that they
# take precedence.
for extension in requested_extensions + get_config().extensions:
if extension not in extensions:
extensions.append(extension)

extensions = default_extensions.extensions + extensions
self._extensions = AsdfExtensionList(extensions)
self._extension_metadata.update(default_extensions.package_metadata)
return extensions

def _update_extension_history(self):
if self.version < versioning.NEW_HISTORY_FORMAT_MIN_VERSION:
Expand All @@ -287,11 +324,10 @@ def _update_extension_history(self):
self.tree['history']['extensions'] = []

for extension in self.type_index.get_extensions_used():
ext_name = util.get_class_name(extension)
ext_name = extension.class_name
ext_meta = ExtensionMetadata(extension_class=ext_name)
metadata = self._extension_metadata.get(ext_name)
if metadata is not None:
ext_meta['software'] = Software(name=metadata[0], version=metadata[1])
if extension.package_name is not None:
ext_meta['software'] = Software(name=extension.package_name, version=extension.package_version)

for i, entry in enumerate(self.tree['history']['extensions']):
# Update metadata about this extension if it already exists
Expand Down Expand Up @@ -352,23 +388,23 @@ def tag_to_schema_resolver(self):
"The 'tag_to_schema_resolver' property is deprecated. Use "
"'tag_mapping' instead.",
AsdfDeprecationWarning)
return self._extensions.tag_mapping
return self.extension_list.tag_mapping

@property
def tag_mapping(self):
return self._extensions.tag_mapping
return self.extension_list.tag_mapping

@property
def url_mapping(self):
return self._extensions.url_mapping
return self.extension_list.url_mapping

@property
def resolver(self):
return self._extensions.resolver
return self.extension_list.resolver

@property
def type_index(self):
return self._extensions.type_index
return self.extension_list.type_index

def resolve_uri(self, uri):
"""
Expand Down Expand Up @@ -769,7 +805,6 @@ def _open_impl(cls, self, fd, uri=None, mode='r',
strict_extension_check=strict_extension_check,
ignore_missing_extensions=ignore_missing_extensions,
ignore_unrecognized_tag=self._ignore_unrecognized_tag,
_extension_metadata=self._extension_metadata,
**kwargs)
except ValueError:
raise ValueError(
Expand Down Expand Up @@ -1452,9 +1487,8 @@ def open_asdf(fd, uri=None, mode=None, validate_checksums=False,
If `True`, validate the blocks against their checksums.
Requires reading the entire file, so disabled by default.
extensions : list of AsdfExtension
A list of extensions to use when reading and writing ASDF files.
See `~asdf.types.AsdfExtension` for more information.
extensions : asdf.extension.AsdfExtension or asdf.extension.AsdfExtensionList or list of asdf.extension.AsdfExtension
Additional extensions to use when reading and writing the file.
do_not_fill_defaults : bool, optional
When `True`, do not fill in missing default values.
Expand Down
2 changes: 1 addition & 1 deletion asdf/commands/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _qualified_name(_class):
def list_tags(display_classes=False, iostream=sys.stdout):
"""Function to list tags"""
af = AsdfFile()
type_by_tag = af._extensions._type_index._type_by_tag
type_by_tag = af.type_index._type_by_tag
tags = sorted(type_by_tag.keys())

for tag in tags:
Expand Down
4 changes: 2 additions & 2 deletions asdf/commands/tests/test_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_list_schemas():
obs_tags = _get_tags(False)

af = AsdfFile()
exp_tags = sorted(af._extensions._type_index._type_by_tag.keys())
exp_tags = sorted(af.type_index._type_by_tag.keys())

for exp, obs in zip(exp_tags, obs_tags):
assert exp == obs
Expand All @@ -25,7 +25,7 @@ def test_list_schemas_and_tags():
tag_lines = _get_tags(True)

af = AsdfFile()
type_by_tag = af._extensions._type_index._type_by_tag
type_by_tag = af.type_index._type_by_tag
exp_tags = sorted(type_by_tag.keys())

for exp_tag, line in zip(exp_tags, tag_lines):
Expand Down
57 changes: 57 additions & 0 deletions asdf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .resource import ResourceMappingProxy, ResourceManager
from . import versioning
from ._helpers import validate_version
from .extension import ExtensionProxy

__all__ = ["AsdfConfig", "get_config", "config_context"]

Expand All @@ -28,6 +29,7 @@ class AsdfConfig:
def __init__(self):
self._resource_mappings = None
self._resource_manager = None
self._extensions = None
self._validate_on_read = DEFAULT_VALIDATE_ON_READ
self._default_version = DEFAULT_DEFAULT_VERSION

Expand Down Expand Up @@ -118,6 +120,61 @@ def resource_manager(self):
self._resource_manager = ResourceManager(self.resource_mappings)
return self._resource_manager

@property
def extensions(self):
"""
Get the list of registered `AsdfExtension` instances.
Returns
-------
list of asdf.extension.AsdfExtension
"""
if self._extensions is None:
with self._lock:
if self._extensions is None:
self._extensions = entry_points.get_extensions()
return self._extensions

def add_extension(self, extension):
"""
Register a new extension. The new extension will
take precedence over all previously registered extensions.
Parameters
----------
extension : asdf.extension.AsdfExtension
"""
with self._lock:
extension = ExtensionProxy.maybe_wrap(extension)
self._extensions = [extension] + [e for e in self.extensions if e != extension]

def remove_extension(self, extension=None, *, package=None):
"""
Remove a registered extension.
Parameters
----------
extension : asdf.extension.AsdfExtension, optional
An extension instance to remove.
package : str, optional
A Python package name whose extensions will all be removed.
"""
with self._lock:
extensions = self.extensions
if extension is not None:
extension = ExtensionProxy.maybe_wrap(extension)
extensions = [e for e in extensions if e != extension]
if package is not None:
extensions = [e for e in extensions if e.package_name != package]
self._extensions = extensions

def reset_extensions(self):
"""
Reset extensions to the default list registered via entry points.
"""
with self._lock:
self._extensions = None

@property
def validate_on_read(self):
"""
Expand Down
8 changes: 8 additions & 0 deletions asdf/entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,23 @@

from .exceptions import AsdfWarning
from .resource import ResourceMappingProxy
from .extension import ExtensionProxy


RESOURCE_MAPPINGS_GROUP = "asdf.resource_mappings"
LEGACY_EXTENSIONS_GROUP = "asdf_extensions"


def get_resource_mappings():
return _list_entry_points(RESOURCE_MAPPINGS_GROUP, ResourceMappingProxy)


def get_extensions():
legacy_extensions = _list_entry_points(LEGACY_EXTENSIONS_GROUP, ExtensionProxy)

return legacy_extensions


def _list_entry_points(group, proxy_class):
results = []
for entry_point in iter_entry_points(group=group):
Expand Down
4 changes: 4 additions & 0 deletions asdf/extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Support for plugins that extend asdf to serialize
additional custom types.
"""
from ._extension import ExtensionProxy
from ._legacy import (
AsdfExtension,
AsdfExtensionList,
Expand All @@ -12,6 +13,9 @@


__all__ = [
# New API
"ExtensionProxy",
# Legacy API
"AsdfExtension",
"AsdfExtensionList",
"BuiltinExtension",
Expand Down
Loading

0 comments on commit 00908dd

Please sign in to comment.