Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in translating 3D Spectrum1D with only spectral axis defined #61

Closed
wants to merge 5 commits into from
Closed
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 40 additions & 93 deletions glue_astronomy/translators/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from glue.core import Data, Subset

from gwcs import WCS as GWCS
from gwcs.coordinate_frames import CoordinateFrame

from astropy.wcs import WCS
from astropy import units as u
from astropy.wcs import WCSSUB_SPECTRAL
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty
from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper
from astropy.wcs.wcsapi import HighLevelWCSMixin, BaseHighLevelWCS
from astropy.wcs.wcsapi import BaseHighLevelWCS
from astropy.modeling import models

from ndcube.wcs.wrappers import CompoundLowLevelWCS

from glue_astronomy.spectral_coordinates import SpectralCoordinates

Expand All @@ -32,89 +35,23 @@
'custom:spect.doplerVeloc.beta': 'Beta'}


class PaddedSpectrumWCS(BaseWCSWrapper, HighLevelWCSMixin):

# Spectrum1D can use a 1D spectral WCS even for n-dimensional
# datasets while glue always needs the dimensionality to match,
# so this class pads the WCS so that it is n-dimensional.

# NOTE: for now this only handles padding the WCS into 2D WCS. Rather than
# generalize this we can just remove this class and use CompoundLowLevelWCS
# from NDCube once it is in a released version.

def __init__(self, wcs):
self.spectral_wcs = wcs

@property
def pixel_n_dim(self):
return 2

@property
def world_n_dim(self):
return 2

@property
def world_axis_physical_types(self):
return [self.spectral_wcs.world_axis_physical_types[0], None]

@property
def world_axis_units(self):
return (self.spectral_wcs.world_axis_units[0], None)

def pixel_to_world_values(self, *pixel_arrays):
# The ravel and reshape are needed because of
# https://github.com/astropy/astropy/issues/12154
px = np.array(pixel_arrays[0])
world_arrays = [self.spectral_wcs.pixel_to_world_values(px.ravel()).reshape(px.shape),
pixel_arrays[1]]
return tuple(world_arrays)

def world_to_pixel_values(self, *world_arrays):
# The ravel and reshape are needed because of
# https://github.com/astropy/astropy/issues/12154
wx = np.array(world_arrays[0])
pixel_arrays = [self.spectral_wcs.world_to_pixel_values(wx.ravel()).reshape(wx.shape),
world_arrays[1]]
return tuple(pixel_arrays)

@property
def world_axis_object_components(self):
return [
self.spectral_wcs.world_axis_object_components[0],
('spatial', 'value', 'value')
]
class PaddedSpectrumWCS(CompoundLowLevelWCS):

@property
def world_axis_object_classes(self):
spectral_key = self.spectral_wcs.world_axis_object_components[0][0]
return {
spectral_key: self.spectral_wcs.world_axis_object_classes[spectral_key],
'spatial': (u.Quantity, (), {'unit': u.pixel})
}
def __init__(self, spectral_wcs, n_extra_axes):
self.spectral_wcs = spectral_wcs
frame1 = CoordinateFrame(n_extra_axes, ['SPATIAL']*n_extra_axes,
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
np.arange(n_extra_axes), unit=[u.dimensionless_unscaled]*n_extra_axes,

name="Dummy1")
frame2 = CoordinateFrame(n_extra_axes, ['SPATIAL']*n_extra_axes,
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
np.arange(n_extra_axes), unit=[u.dimensionless_unscaled]*n_extra_axes,

name="Dummy2")
frame2frame = models.Multiply(1)
if n_extra_axes > 1:
for i in range(n_extra_axes-1):
frame2frame = frame2frame & models.Multiply(1)

@property
def pixel_shape(self):
return None

@property
def pixel_bounds(self):
return None

@property
def pixel_axis_names(self):
return tuple([self.spectral_wcs.pixel_axis_names[0], 'spatial'])

@property
def world_axis_names(self):
return (UCD_TO_SPECTRAL_NAME[self.spectral_wcs.world_axis_physical_types[0]], 'Offset')
Comment on lines -108 to -109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is changing the label of the returned components in the 2d example to 'World 0','World 1', breaking test_spectrum1d_2d_data further down. Maybe the new naming is more logical, but that should be documented somewhere (and at least World 1 does not sound like an improvement over Frequency).


@property
def axis_correlation_matrix(self):
return np.identity(2).astype('bool')

@property
def serialized_classes(self):
return False
pad_wcs = GWCS([(frame1, frame2frame), (frame2, None)])
super().__init__(pad_wcs, spectral_wcs)


@data_translator(Spectrum1D)
Expand All @@ -125,15 +62,20 @@ def to_data(self, obj):
# Glue expects spectral axis first for cubes (opposite of specutils).
# Swap the spectral axis to first here. to_object doesn't need this because
# Spectrum1D does it automatically on initialization.
if len(obj.flux.shape) == 3:
data = Data(coords=obj.wcs.swapaxes(-1, 0))
if obj.flux.ndim > 1:
# It's possible to have a 3D Spectrum1D with only a spectral axis defined
# rather than a full WCS, in which case we need to pad the WCS to match
# the dimensionality of the flux array.
if obj.wcs.world_n_dim == obj.flux.ndim:
data = Data(coords=obj.wcs.swapaxes(-1, 0))
else:
n_extra = obj.flux.ndim - obj.wcs.world_n_dim
data = Data(coords=PaddedSpectrumWCS(obj.wcs, n_extra))
data['flux'] = np.swapaxes(obj.flux, -1, 0)
data.get_component('flux').units = str(obj.unit)
else:
if obj.flux.ndim == 1 and obj.wcs.world_n_dim == 1 and isinstance(obj.wcs, GWCS):
data = Data(coords=SpectralCoordinates(obj.spectral_axis))
elif obj.flux.ndim == 2 and obj.wcs.world_n_dim == 1:
data = Data(coords=PaddedSpectrumWCS(obj.wcs))
else:
data = Data(coords=obj.wcs)
data['flux'] = obj.flux
Expand All @@ -150,7 +92,7 @@ def to_data(self, obj):

# Include mask if it exists
if obj.mask is not None:
if len(obj.flux.shape) == 3:
if len(obj.flux.shape) > 1:
data['mask'] = np.swapaxes(obj.mask, -1, 0)
else:
data['mask'] = obj.mask
Expand Down Expand Up @@ -183,12 +125,14 @@ def to_object(self, data_or_subset, attribute=None, statistic='mean'):
if data.ndim < 2 and statistic is not None:
statistic = None

if statistic is None and isinstance(data.coords, BaseHighLevelWCS):
manual_swap = None

if isinstance(data.coords, PaddedSpectrumWCS):
kwargs = {'wcs': data.coords.spectral_wcs}
else:
kwargs = {'wcs': data.coords}
if statistic is None and isinstance(data.coords, PaddedSpectrumWCS):
kwargs = {'wcs': data.coords.spectral_wcs}
if data.ndim > 1:
manual_swap = True
elif statistic is None and isinstance(data.coords, BaseHighLevelWCS):
kwargs = {'wcs': data.coords}

elif statistic is not None:

Expand Down Expand Up @@ -262,6 +206,9 @@ def parse_attributes(attributes):
mask = np.all(mask, collapse_axes)
else:
values = data.get_data(attribute)
if manual_swap:
# In this case we need to move the spectral axis back to last
values = np.swapaxes(values, -1, 0)

attribute_label = attribute.label

Expand Down