Skip to content

Commit

Permalink
Merge pull request #1561 from astrofrog/fix-save-aggregate-slice
Browse files Browse the repository at this point in the history
Fix serialization of AggregateSlice
  • Loading branch information
astrofrog committed Mar 7, 2018
1 parent 09de1fd commit 8fe93c7
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ v0.12.4 (2017-01-09)

* Improve error message in PV slicer when _slice_index fails. [#1536]

* Fixed a bug that caused an error when trying to save a session that
included an image viewer with an aggregated slice. [#1561]

* Fixed a bug that caused an error in the terminal if creating a data
viewer failed properly (with a GUI error message).

Expand Down
78 changes: 76 additions & 2 deletions glue/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,16 @@ def load(rec, context)
from glue.utils import lookup_class


literals = tuple([type(None), float, int, bytes, bool, list, tuple])
literals = tuple([type(None), float, int, bytes, bool])

if six.PY2:
literals += (long,)


literals += np.ScalarType

builtin_iterables = (tuple, list, set)

JSON_ENCODER = json.JSONEncoder()

# We need to make sure that we don't break backward-compatibility when we move
Expand Down Expand Up @@ -220,6 +222,25 @@ def __setitem__(self, key, value):
self._data[item][version] = value


def as_nested_lists(obj):
items = []
for item in obj:
if type(item) in builtin_iterables:
item = as_nested_lists(item)
items.append(item)
return items


def flattened(obj):
items = []
for item in obj:
if type(item) in builtin_iterables:
items += as_nested_lists(item)
else:
items.append(item)
return items


class GlueSerializer(object):

"""
Expand Down Expand Up @@ -261,6 +282,12 @@ def id(self, obj):
if type(obj) in literals:
return obj

# Now check for list, set, and tuple, and skip if they don't contain
# any non-literals.
if type(obj) in builtin_iterables:
if all(isinstance(x, literals) for x in flattened(obj)):
return as_nested_lists(obj)

oid = id(obj)

if oid in self._names:
Expand Down Expand Up @@ -299,6 +326,12 @@ def do(self, obj):
if type(obj) in literals:
return obj

# Now check for list, set, and tuple, and skip if they don't contain
# any non-literals
if type(obj) in builtin_iterables:
if all(isinstance(x, literals) for x in flattened(obj)):
return as_nested_lists(obj)

oid = id(obj)
if oid in self._working:
raise GlueSerializeError("Circular reference detected")
Expand All @@ -322,6 +355,7 @@ def do(self, obj):
return result

def _dispatch(self, obj):

if hasattr(obj, '__gluestate__'):
return type(obj).__gluestate__, 1

Expand Down Expand Up @@ -452,7 +486,7 @@ def object(self, obj_id):
self._working.add(obj_id)
rec = self._rec[obj_id]

elif isinstance(obj_id, literals):
elif isinstance(obj_id, literals) or isinstance(obj_id, (tuple, list)):
return obj_id
else:
rec = obj_id
Expand Down Expand Up @@ -493,6 +527,46 @@ def _load_dict(rec, context):
for key, value in rec['contents'].items())


@saver(tuple)
def _save_tuple(state, context):
return dict(contents=[context.do(item) for item in state])


@loader(tuple)
def _load_tuple(rec, context):
return tuple(_load_list(rec, context))


@saver(list)
def _save_list(state, context):
return dict(contents=[context.do(item) for item in state])


@loader(list)
def _load_list(rec, context):
return [context.object(item) for item in rec['contents']]


@saver(set)
def _save_set(state, context):
return dict(contents=[context.do(item) for item in state])


@loader(set)
def _load_set(rec, context):
return set(_load_list(rec, context))


@saver(slice)
def _save_slice(slc, context):
return dict(start=slc.start, stop=slc.stop, step=slc.step)


@loader(slice)
def _load_slice(rec, context):
return slice(rec['start'], rec['stop'], rec['step'])


@saver(CompositeSubsetState)
def _save_composite_subset_state(state, context):
return dict(state1=context.id(state.state1),
Expand Down
28 changes: 26 additions & 2 deletions glue/viewers/image/qt/tests/test_data_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from glue.core.state import GlueUnSerializer
from glue.app.qt.layer_tree_widget import LayerTreeWidget
from glue.viewers.scatter.state import ScatterLayerState
from glue.viewers.image.state import ImageLayerState, ImageSubsetLayerState
from glue.viewers.image.state import ImageLayerState, ImageSubsetLayerState, AggregateSlice
from glue.core.link_helpers import LinkSame
from glue.app.qt import GlueApplication

Expand Down Expand Up @@ -81,7 +81,7 @@ def setup_method(self, method):
self.data_collection.append(self.image1_wcs)
self.data_collection.append(self.hypercube_wcs)

self.viewer = ImageViewer(self.session)
self.viewer = self.application.new_data_viewer(ImageViewer)

self.data_collection.register_to_hub(self.hub)
self.viewer.register_to_hub(self.hub)
Expand Down Expand Up @@ -580,6 +580,30 @@ def test_linking_and_enabling(self):
assert not self.viewer.layers[2].enabled # image subset
assert self.viewer.layers[3].enabled # scatter subset

def test_save_aggregate_slice(self, tmpdir):

# Regressin test to make sure that image viewers that include
# aggregate slice objects in the slices can be saved/restored

self.viewer.add_data(self.hypercube)
self.viewer.state.slices = AggregateSlice(slice(1, 3), 10, np.sum), 3, 0, 0

filename = tmpdir.join('session.glu').strpath

self.application.save_session(filename)
self.application.close()

app2 = GlueApplication.restore_session(filename)
viewer_state = app2.viewers[0][0].state
slices = viewer_state.slices
assert isinstance(slices[0], AggregateSlice)
assert slices[0].slice == slice(1, 3)
assert slices[0].center == 10
assert slices[0].function is np.sum
assert slices[1:] == (3, 0, 0)

app2.close()


class TestSessions(object):

Expand Down
12 changes: 12 additions & 0 deletions glue/viewers/image/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ def __init__(self, slice=None, center=None, function=None):
self.center = center
self.function = function

def __gluestate__(self, context):
state = dict(slice=context.do(self.slice),
center=self.center,
function=context.do(self.function))
return state

@classmethod
def __setgluestate__(cls, rec, context):
return cls(slice=context.object(rec['slice']),
center=rec['center'],
function=context.object(rec['function']))


class ImageViewerState(MatplotlibDataViewerState):
"""
Expand Down

0 comments on commit 8fe93c7

Please sign in to comment.