From 9f68c7455f5109c63ea3e887bf010211449b0be9 Mon Sep 17 00:00:00 2001 From: Deep Chatterjee Date: Thu, 12 Sep 2024 15:25:08 -0400 Subject: [PATCH] device consistency (#125) --- projects/train/train/augmentations.py | 1 + projects/train/train/data/datasets/flow.py | 3 +++ .../train/train/data/waveforms/generator/cbc.py | 16 ++++++++++++---- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/projects/train/train/augmentations.py b/projects/train/train/augmentations.py index 5f80eb58..6cd493e7 100644 --- a/projects/train/train/augmentations.py +++ b/projects/train/train/augmentations.py @@ -98,5 +98,6 @@ def forward(self, X: Tensor) -> Tuple[Tensor, Tensor]: if X.ndim == 3 and X.size(0) == 2: background = background[0] + self.spectral_density.to(device=background.device) psds = self.spectral_density(background.double()) return X, psds diff --git a/projects/train/train/data/datasets/flow.py b/projects/train/train/data/datasets/flow.py index 458493c9..b1cd1fc1 100644 --- a/projects/train/train/data/datasets/flow.py +++ b/projects/train/train/data/datasets/flow.py @@ -9,6 +9,9 @@ class FlowDataset(AmplfiDataset): """ def inject(self, X, cross, plus, parameters): + self.projector.to(self.device) + self.whitener.to(self.device) + X, psds = self.psd_estimator(X) dec, psi, phi = self.waveform_sampler.sample_extrinsic(X) waveforms = self.projector(dec, psi, phi, cross=cross, plus=plus) diff --git a/projects/train/train/data/waveforms/generator/cbc.py b/projects/train/train/data/waveforms/generator/cbc.py index d993e047..26a28cdb 100644 --- a/projects/train/train/data/waveforms/generator/cbc.py +++ b/projects/train/train/data/waveforms/generator/cbc.py @@ -51,7 +51,7 @@ def __init__( ): super().__init__(*args, **kwargs) - waveform_arguments = waveform_arguments or {} + self.waveform_arguments = waveform_arguments or {} # set approximant (possibly torch.nn.Module) as an attribute # so that it will get moved to the proper device when `.to` is called @@ -122,8 +122,13 @@ def time_domain_strain(self, **parameters): A dictionary of parameters to pass to the waveform model """ - # TODO: support time domain waveforms - hc, hp = self.waveform(self.frequencies[self.freq_mask], **parameters) + device = parameters["chirp_mass"].device + freqs = torch.clone(self.frequencies).to(device) + self.approximant.to(device) + + parameters.update(self.waveform_arguments) + + hc, hp = self.approximant(freqs[self.freq_mask], **parameters) # fourier transform hc, hp = torch.fft.irfft(hc), torch.fft.irfft(hp) @@ -148,7 +153,10 @@ def time_domain_strain(self, **parameters): return hc, hp def frequency_domain_strain(self, **parameters): - return self.waveform(self.frequencies[self.freq_mask], **parameters) + device = parameters["chirp_mass"].device + freqs = torch.clone(self.frequencies).to(device) + self.approximant.to(device) + return self.waveform(freqs[self.freq_mask], **parameters) def slice_waveforms(self, waveforms: torch.Tensor): # for cbc waveforms, the padding (see above)