diff --git a/format/nexus.py b/format/nexus.py index 5637c75b3..08b972dd0 100644 --- a/format/nexus.py +++ b/format/nexus.py @@ -19,6 +19,7 @@ import dxtbx.model from dxtbx.model import ( Beam, + SpectrumBeam, Crystal, Detector, Panel, @@ -547,28 +548,53 @@ def load_model(self, index=None): # Get the items from the NXbeam class wavelength = self.obj.handle["incident_wavelength"] - wavelength_weights = self.obj.handle.get("incident_wavelength_weights") - if wavelength.shape in (tuple(), (1,)): - wavelength_value = wavelength[()] - elif len(wavelength.shape) == 1: - if wavelength_weights is None: - if index is None: - index = 0 - wavelength_value = wavelength[index] + wavelength_value = None + wavelength_weights = self.obj.handle.get("incident_wavelength_weight") + wavelength_calculated = self.obj.handle.get("incident_wavelength_calculated") + + if index is None: + index = 0 + self.index = index + + def get_wavelength(wavelength): + if wavelength.shape in (tuple(), (1,)): + wavelength_value = wavelength[()] else: - raise NotImplementedError("Spectra not implemented") + wavelength_value = wavelength[index] + wavelength_units = wavelength.attrs["units"] + wavelength_value = float( + convert_units(wavelength_value, wavelength_units, "angstrom") + ) + return wavelength_value + + if wavelength_calculated is not None: + wavelength_value = get_wavelength(wavelength_calculated) + elif wavelength_weights is None: + wavelength_value = get_wavelength(wavelength) + + if wavelength_weights is None: + # Construct the beam model + self.model = Beam(direction=(0, 0, 1), wavelength=wavelength_value) else: - raise NotImplementedError("Spectra not implemented") - wavelength_units = wavelength.attrs["units"] + self.model = SpectrumBeam() + self.model.set_direction((0, 0, 1)) - # Convert wavelength to Angstroms - wavelength_value = float( - convert_units(wavelength_value, wavelength_units, "angstrom") - ) + wavelength_units = wavelength.attrs["units"] - # Construct the beam model - self.index = index - self.model = Beam(direction=(0, 0, 1), wavelength=wavelength_value) + if len(wavelength.shape) > 1: + wavelength = wavelength[index] + wavelength_weights = wavelength_weights[index] + + spectrum_wavelengths = convert_units( + wavelength, wavelength_units, "angstrom" + ) + spectrum_energies = 12398.4187 / spectrum_wavelengths + spectrum_weights = wavelength_weights + self.model.set_spectrum(spectrum_energies, spectrum_weights) + if wavelength_value: + self.model.set_wavelength(wavelength_value) + else: + self.model.set_wavelength(self.model.get_weighted_wavelength()) def get_change_of_basis(transformation):