Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in translating 3D Spectrum1D with only spectral axis defined #61

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 40 additions & 93 deletions glue_astronomy/translators/spectrum1d.py
Original file line number Diff line number Diff line change
@@ -4,13 +4,16 @@
from glue.core import Data, Subset

from gwcs import WCS as GWCS
from gwcs.coordinate_frames import CoordinateFrame

from astropy.wcs import WCS
from astropy import units as u
from astropy.wcs import WCSSUB_SPECTRAL
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty
from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper
from astropy.wcs.wcsapi import HighLevelWCSMixin, BaseHighLevelWCS
from astropy.wcs.wcsapi import BaseHighLevelWCS
from astropy.modeling import models

from ndcube.wcs.wrappers import CompoundLowLevelWCS

from glue_astronomy.spectral_coordinates import SpectralCoordinates

@@ -32,89 +35,23 @@
'custom:spect.doplerVeloc.beta': 'Beta'}


class PaddedSpectrumWCS(BaseWCSWrapper, HighLevelWCSMixin):

# Spectrum1D can use a 1D spectral WCS even for n-dimensional
# datasets while glue always needs the dimensionality to match,
# so this class pads the WCS so that it is n-dimensional.

# NOTE: for now this only handles padding the WCS into 2D WCS. Rather than
# generalize this we can just remove this class and use CompoundLowLevelWCS
# from NDCube once it is in a released version.

def __init__(self, wcs):
self.spectral_wcs = wcs

@property
def pixel_n_dim(self):
return 2

@property
def world_n_dim(self):
return 2

@property
def world_axis_physical_types(self):
return [self.spectral_wcs.world_axis_physical_types[0], None]

@property
def world_axis_units(self):
return (self.spectral_wcs.world_axis_units[0], None)

def pixel_to_world_values(self, *pixel_arrays):
# The ravel and reshape are needed because of
# https://github.com/astropy/astropy/issues/12154
px = np.array(pixel_arrays[0])
world_arrays = [self.spectral_wcs.pixel_to_world_values(px.ravel()).reshape(px.shape),
pixel_arrays[1]]
return tuple(world_arrays)

def world_to_pixel_values(self, *world_arrays):
# The ravel and reshape are needed because of
# https://github.com/astropy/astropy/issues/12154
wx = np.array(world_arrays[0])
pixel_arrays = [self.spectral_wcs.world_to_pixel_values(wx.ravel()).reshape(wx.shape),
world_arrays[1]]
return tuple(pixel_arrays)

@property
def world_axis_object_components(self):
return [
self.spectral_wcs.world_axis_object_components[0],
('spatial', 'value', 'value')
]
class PaddedSpectrumWCS(CompoundLowLevelWCS):

@property
def world_axis_object_classes(self):
spectral_key = self.spectral_wcs.world_axis_object_components[0][0]
return {
spectral_key: self.spectral_wcs.world_axis_object_classes[spectral_key],
'spatial': (u.Quantity, (), {'unit': u.pixel})
}
def __init__(self, spectral_wcs, n_extra_axes):
self.spectral_wcs = spectral_wcs
frame1 = CoordinateFrame(n_extra_axes, ['SPATIAL']*n_extra_axes,
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
np.arange(n_extra_axes), unit=[u.dimensionless_unscaled]*n_extra_axes,

name="Dummy1")
frame2 = CoordinateFrame(n_extra_axes, ['SPATIAL']*n_extra_axes,
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes,
np.arange(n_extra_axes), unit=[u.dimensionless_unscaled]*n_extra_axes,

name="Dummy2")
frame2frame = models.Multiply(1)
if n_extra_axes > 1:
for i in range(n_extra_axes-1):
frame2frame = frame2frame & models.Multiply(1)

@property
def pixel_shape(self):
return None

@property
def pixel_bounds(self):
return None

@property
def pixel_axis_names(self):
return tuple([self.spectral_wcs.pixel_axis_names[0], 'spatial'])

@property
def world_axis_names(self):
return (UCD_TO_SPECTRAL_NAME[self.spectral_wcs.world_axis_physical_types[0]], 'Offset')
Comment on lines -108 to -109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is changing the label of the returned components in the 2d example to 'World 0','World 1', breaking test_spectrum1d_2d_data further down. Maybe the new naming is more logical, but that should be documented somewhere (and at least World 1 does not sound like an improvement over Frequency).


@property
def axis_correlation_matrix(self):
return np.identity(2).astype('bool')

@property
def serialized_classes(self):
return False
pad_wcs = GWCS([(frame1, frame2frame), (frame2, None)])
super().__init__(pad_wcs, spectral_wcs)


@data_translator(Spectrum1D)
@@ -125,15 +62,20 @@ def to_data(self, obj):
# Glue expects spectral axis first for cubes (opposite of specutils).
# Swap the spectral axis to first here. to_object doesn't need this because
# Spectrum1D does it automatically on initialization.
if len(obj.flux.shape) == 3:
data = Data(coords=obj.wcs.swapaxes(-1, 0))
if obj.flux.ndim > 1:
# It's possible to have a 3D Spectrum1D with only a spectral axis defined
# rather than a full WCS, in which case we need to pad the WCS to match
# the dimensionality of the flux array.
if obj.wcs.world_n_dim == obj.flux.ndim:
data = Data(coords=obj.wcs.swapaxes(-1, 0))
else:
n_extra = obj.flux.ndim - obj.wcs.world_n_dim
data = Data(coords=PaddedSpectrumWCS(obj.wcs, n_extra))
data['flux'] = np.swapaxes(obj.flux, -1, 0)
data.get_component('flux').units = str(obj.unit)
else:
if obj.flux.ndim == 1 and obj.wcs.world_n_dim == 1 and isinstance(obj.wcs, GWCS):
data = Data(coords=SpectralCoordinates(obj.spectral_axis))
elif obj.flux.ndim == 2 and obj.wcs.world_n_dim == 1:
data = Data(coords=PaddedSpectrumWCS(obj.wcs))
else:
data = Data(coords=obj.wcs)
data['flux'] = obj.flux
@@ -150,7 +92,7 @@ def to_data(self, obj):

# Include mask if it exists
if obj.mask is not None:
if len(obj.flux.shape) == 3:
if len(obj.flux.shape) > 1:
data['mask'] = np.swapaxes(obj.mask, -1, 0)
else:
data['mask'] = obj.mask
@@ -183,12 +125,14 @@ def to_object(self, data_or_subset, attribute=None, statistic='mean'):
if data.ndim < 2 and statistic is not None:
statistic = None

if statistic is None and isinstance(data.coords, BaseHighLevelWCS):
manual_swap = None

if isinstance(data.coords, PaddedSpectrumWCS):
kwargs = {'wcs': data.coords.spectral_wcs}
else:
kwargs = {'wcs': data.coords}
if statistic is None and isinstance(data.coords, PaddedSpectrumWCS):
kwargs = {'wcs': data.coords.spectral_wcs}
if data.ndim > 1:
manual_swap = True
elif statistic is None and isinstance(data.coords, BaseHighLevelWCS):
kwargs = {'wcs': data.coords}

elif statistic is not None:

@@ -262,6 +206,9 @@ def parse_attributes(attributes):
mask = np.all(mask, collapse_axes)
else:
values = data.get_data(attribute)
if manual_swap:
# In this case we need to move the spectral axis back to last
values = np.swapaxes(values, -1, 0)

attribute_label = attribute.label

23 changes: 16 additions & 7 deletions glue_astronomy/translators/tests/test_spectrum1d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import warnings
import numpy as np
from numpy.testing import assert_allclose, assert_equal

@@ -125,7 +126,7 @@ def test_to_spectrum1d_default_attribute():


@pytest.mark.parametrize('mode', ('wcs1d', 'wcs3d', 'lookup'))
def test_from_spectrum1d(mode):
def test_from_spectrum1d(mode, recwarn):

if mode == 'wcs3d':
# This test is intended to be run with the version of Spectrum1D based
@@ -195,10 +196,16 @@ def test_from_spectrum1d(mode):
print(uncertainty)
assert_quantity_allclose(spec_new.uncertainty.quantity,
np.ones((5, 4, 4))*0.01*u.Jy**2)

assert len(recwarn) == 3
for w in recwarn:
assert issubclass(w.category, UserWarning)
assert "Input WCS indicates that the spectral axis is not last." in str(w.message)
else:
assert_quantity_allclose(spec_new.flux, [2, 3, 4, 5] * u.Jy)
assert spec_new.uncertainty is not None
assert_quantity_allclose(spec_new.uncertainty.quantity, [0.1, 0.1, 0.1, 0.1] * u.Jy**2)
assert len(recwarn) == 0


def test_spectrum1d_2d_data():
@@ -207,6 +214,7 @@ def test_spectrum1d_2d_data():
# Note that Spectrum1D will typically have a 1D spectral WCS even if the
# data is N-dimensional, so we need to pad the WCS before passing it to
# glue and un-pad it when translating back.
# Also Spectrum1D.flux has the spectral axis along last dimension, not first.

# We test both the case where the WCS is 2D and the case where it is 1D

@@ -215,7 +223,7 @@ def test_spectrum1d_2d_data():
wcs.wcs.cdelt = [10]
wcs.wcs.set()

flux = np.ones((3, 2)) * u.Unit('Jy')
flux = np.arange(1, 7).reshape((3, 2)) * u.Unit('Jy')

spec = Spectrum1D(flux, wcs=wcs, meta={'instrument': 'spamcam'})

@@ -231,7 +239,7 @@ def test_spectrum1d_2d_data():
assert isinstance(data, Data)
assert len(data.main_components) == 1
assert data.main_components[0].label == 'flux'
assert_allclose(data['flux'], flux.value)
assert_allclose(data['flux'], flux.value.swapaxes(-1, 0))

assert data.coords.pixel_n_dim == 2
assert data.coords.world_n_dim == 2
@@ -240,11 +248,12 @@ def test_spectrum1d_2d_data():

assert data.coordinate_components[0].label == 'Pixel Axis 0 [y]'
assert data.coordinate_components[1].label == 'Pixel Axis 1 [x]'
assert data.coordinate_components[2].label == 'Offset'
assert data.coordinate_components[3].label == 'Frequency'
assert data.coordinate_components[2].label == 'World 0'
assert data.coordinate_components[3].label == 'World 1'

assert_equal(data['Offset'], [[0, 0], [1, 1], [2, 2]])
assert_equal(data['Frequency'], [[10, 20], [10, 20], [10, 20]])
assert_equal(data['World 0'], [[10, 10, 10], [10, 10, 10]])
assert data['World 1'].shape == (2, 3)
assert_equal(data['World 1'], [[20, 20, 20], [20, 20, 20]])

s, o = data.coords.pixel_to_world(1, 2)
assert isinstance(s, SpectralCoord)