-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug in translating 3D Spectrum1D with only spectral axis defined #61
Changes from 4 commits
6304448
4883dac
5481719
d772700
312a840
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -4,13 +4,16 @@ | |||||
from glue.core import Data, Subset | ||||||
|
||||||
from gwcs import WCS as GWCS | ||||||
from gwcs.coordinate_frames import CoordinateFrame | ||||||
|
||||||
from astropy.wcs import WCS | ||||||
from astropy import units as u | ||||||
from astropy.wcs import WCSSUB_SPECTRAL | ||||||
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty | ||||||
from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper | ||||||
from astropy.wcs.wcsapi import HighLevelWCSMixin, BaseHighLevelWCS | ||||||
from astropy.wcs.wcsapi import BaseHighLevelWCS | ||||||
from astropy.modeling import models | ||||||
|
||||||
from ndcube.wcs.wrappers import CompoundLowLevelWCS | ||||||
|
||||||
from glue_astronomy.spectral_coordinates import SpectralCoordinates | ||||||
|
||||||
|
@@ -32,89 +35,23 @@ | |||||
'custom:spect.doplerVeloc.beta': 'Beta'} | ||||||
|
||||||
|
||||||
class PaddedSpectrumWCS(BaseWCSWrapper, HighLevelWCSMixin): | ||||||
|
||||||
# Spectrum1D can use a 1D spectral WCS even for n-dimensional | ||||||
# datasets while glue always needs the dimensionality to match, | ||||||
# so this class pads the WCS so that it is n-dimensional. | ||||||
|
||||||
# NOTE: for now this only handles padding the WCS into 2D WCS. Rather than | ||||||
# generalize this we can just remove this class and use CompoundLowLevelWCS | ||||||
# from NDCube once it is in a released version. | ||||||
|
||||||
def __init__(self, wcs): | ||||||
self.spectral_wcs = wcs | ||||||
|
||||||
@property | ||||||
def pixel_n_dim(self): | ||||||
return 2 | ||||||
|
||||||
@property | ||||||
def world_n_dim(self): | ||||||
return 2 | ||||||
|
||||||
@property | ||||||
def world_axis_physical_types(self): | ||||||
return [self.spectral_wcs.world_axis_physical_types[0], None] | ||||||
|
||||||
@property | ||||||
def world_axis_units(self): | ||||||
return (self.spectral_wcs.world_axis_units[0], None) | ||||||
|
||||||
def pixel_to_world_values(self, *pixel_arrays): | ||||||
# The ravel and reshape are needed because of | ||||||
# https://github.com/astropy/astropy/issues/12154 | ||||||
px = np.array(pixel_arrays[0]) | ||||||
world_arrays = [self.spectral_wcs.pixel_to_world_values(px.ravel()).reshape(px.shape), | ||||||
pixel_arrays[1]] | ||||||
return tuple(world_arrays) | ||||||
|
||||||
def world_to_pixel_values(self, *world_arrays): | ||||||
# The ravel and reshape are needed because of | ||||||
# https://github.com/astropy/astropy/issues/12154 | ||||||
wx = np.array(world_arrays[0]) | ||||||
pixel_arrays = [self.spectral_wcs.world_to_pixel_values(wx.ravel()).reshape(wx.shape), | ||||||
world_arrays[1]] | ||||||
return tuple(pixel_arrays) | ||||||
|
||||||
@property | ||||||
def world_axis_object_components(self): | ||||||
return [ | ||||||
self.spectral_wcs.world_axis_object_components[0], | ||||||
('spatial', 'value', 'value') | ||||||
] | ||||||
class PaddedSpectrumWCS(CompoundLowLevelWCS): | ||||||
|
||||||
@property | ||||||
def world_axis_object_classes(self): | ||||||
spectral_key = self.spectral_wcs.world_axis_object_components[0][0] | ||||||
return { | ||||||
spectral_key: self.spectral_wcs.world_axis_object_classes[spectral_key], | ||||||
'spatial': (u.Quantity, (), {'unit': u.pixel}) | ||||||
} | ||||||
def __init__(self, spectral_wcs, n_extra_axes): | ||||||
self.spectral_wcs = spectral_wcs | ||||||
frame1 = CoordinateFrame(n_extra_axes, ['SPATIAL']*n_extra_axes, | ||||||
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes, | ||||||
name="Dummy1") | ||||||
frame2 = CoordinateFrame(n_extra_axes, ['SPATIAL']*n_extra_axes, | ||||||
np.arange(n_extra_axes), unit=[u.pix]*n_extra_axes, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
name="Dummy2") | ||||||
frame2frame = models.Multiply(1) | ||||||
if n_extra_axes > 1: | ||||||
for i in range(n_extra_axes-1): | ||||||
frame2frame = frame2frame & models.Multiply(1) | ||||||
|
||||||
@property | ||||||
def pixel_shape(self): | ||||||
return None | ||||||
|
||||||
@property | ||||||
def pixel_bounds(self): | ||||||
return None | ||||||
|
||||||
@property | ||||||
def pixel_axis_names(self): | ||||||
return tuple([self.spectral_wcs.pixel_axis_names[0], 'spatial']) | ||||||
|
||||||
@property | ||||||
def world_axis_names(self): | ||||||
return (UCD_TO_SPECTRAL_NAME[self.spectral_wcs.world_axis_physical_types[0]], 'Offset') | ||||||
Comment on lines
-108
to
-109
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is changing the |
||||||
|
||||||
@property | ||||||
def axis_correlation_matrix(self): | ||||||
return np.identity(2).astype('bool') | ||||||
|
||||||
@property | ||||||
def serialized_classes(self): | ||||||
return False | ||||||
pad_wcs = GWCS([(frame1, frame2frame), (frame2, None)]) | ||||||
super().__init__(pad_wcs, spectral_wcs) | ||||||
|
||||||
|
||||||
@data_translator(Spectrum1D) | ||||||
|
@@ -125,15 +62,20 @@ def to_data(self, obj): | |||||
# Glue expects spectral axis first for cubes (opposite of specutils). | ||||||
# Swap the spectral axis to first here. to_object doesn't need this because | ||||||
# Spectrum1D does it automatically on initialization. | ||||||
if len(obj.flux.shape) == 3: | ||||||
data = Data(coords=obj.wcs.swapaxes(-1, 0)) | ||||||
if obj.flux.ndim > 1: | ||||||
# It's possible to have a 3D Spectrum1D with only a spectral axis defined | ||||||
# rather than a full WCS, in which case we need to pad the WCS to match | ||||||
# the dimensionality of the flux array. | ||||||
if obj.wcs.world_n_dim == obj.flux.ndim: | ||||||
data = Data(coords=obj.wcs.swapaxes(-1, 0)) | ||||||
else: | ||||||
n_extra = obj.flux.ndim - obj.wcs.world_n_dim | ||||||
data = Data(coords=PaddedSpectrumWCS(obj.wcs, n_extra)) | ||||||
data['flux'] = np.swapaxes(obj.flux, -1, 0) | ||||||
data.get_component('flux').units = str(obj.unit) | ||||||
else: | ||||||
if obj.flux.ndim == 1 and obj.wcs.world_n_dim == 1 and isinstance(obj.wcs, GWCS): | ||||||
data = Data(coords=SpectralCoordinates(obj.spectral_axis)) | ||||||
elif obj.flux.ndim == 2 and obj.wcs.world_n_dim == 1: | ||||||
data = Data(coords=PaddedSpectrumWCS(obj.wcs)) | ||||||
else: | ||||||
data = Data(coords=obj.wcs) | ||||||
data['flux'] = obj.flux | ||||||
|
@@ -150,7 +92,7 @@ def to_data(self, obj): | |||||
|
||||||
# Include mask if it exists | ||||||
if obj.mask is not None: | ||||||
if len(obj.flux.shape) == 3: | ||||||
if len(obj.flux.shape) > 1: | ||||||
data['mask'] = np.swapaxes(obj.mask, -1, 0) | ||||||
else: | ||||||
data['mask'] = obj.mask | ||||||
|
@@ -183,12 +125,14 @@ def to_object(self, data_or_subset, attribute=None, statistic='mean'): | |||||
if data.ndim < 2 and statistic is not None: | ||||||
statistic = None | ||||||
|
||||||
if statistic is None and isinstance(data.coords, BaseHighLevelWCS): | ||||||
manual_swap = None | ||||||
|
||||||
if isinstance(data.coords, PaddedSpectrumWCS): | ||||||
kwargs = {'wcs': data.coords.spectral_wcs} | ||||||
else: | ||||||
kwargs = {'wcs': data.coords} | ||||||
if statistic is None and isinstance(data.coords, PaddedSpectrumWCS): | ||||||
kwargs = {'wcs': data.coords.spectral_wcs} | ||||||
if data.ndim > 1: | ||||||
manual_swap = True | ||||||
elif statistic is None and isinstance(data.coords, BaseHighLevelWCS): | ||||||
kwargs = {'wcs': data.coords} | ||||||
|
||||||
elif statistic is not None: | ||||||
|
||||||
|
@@ -262,6 +206,9 @@ def parse_attributes(attributes): | |||||
mask = np.all(mask, collapse_axes) | ||||||
else: | ||||||
values = data.get_data(attribute) | ||||||
if manual_swap: | ||||||
# In this case we need to move the spectral axis back to last | ||||||
values = np.swapaxes(values, -1, 0) | ||||||
|
||||||
attribute_label = attribute.label | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.