Skip to content

Commit

Permalink
device consistency (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
deepchatterjeeligo authored Sep 12, 2024
1 parent e13eaab commit 9f68c74
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
1 change: 1 addition & 0 deletions projects/train/train/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,6 @@ def forward(self, X: Tensor) -> Tuple[Tensor, Tensor]:
if X.ndim == 3 and X.size(0) == 2:
background = background[0]

self.spectral_density.to(device=background.device)
psds = self.spectral_density(background.double())
return X, psds
3 changes: 3 additions & 0 deletions projects/train/train/data/datasets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class FlowDataset(AmplfiDataset):
"""

def inject(self, X, cross, plus, parameters):
self.projector.to(self.device)
self.whitener.to(self.device)

X, psds = self.psd_estimator(X)
dec, psi, phi = self.waveform_sampler.sample_extrinsic(X)
waveforms = self.projector(dec, psi, phi, cross=cross, plus=plus)
Expand Down
16 changes: 12 additions & 4 deletions projects/train/train/data/waveforms/generator/cbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
):

super().__init__(*args, **kwargs)
waveform_arguments = waveform_arguments or {}
self.waveform_arguments = waveform_arguments or {}

# set approximant (possibly torch.nn.Module) as an attribute
# so that it will get moved to the proper device when `.to` is called
Expand Down Expand Up @@ -122,8 +122,13 @@ def time_domain_strain(self, **parameters):
A dictionary of parameters to pass to the waveform model
"""

# TODO: support time domain waveforms
hc, hp = self.waveform(self.frequencies[self.freq_mask], **parameters)
device = parameters["chirp_mass"].device
freqs = torch.clone(self.frequencies).to(device)
self.approximant.to(device)

parameters.update(self.waveform_arguments)

hc, hp = self.approximant(freqs[self.freq_mask], **parameters)

# fourier transform
hc, hp = torch.fft.irfft(hc), torch.fft.irfft(hp)
Expand All @@ -148,7 +153,10 @@ def time_domain_strain(self, **parameters):
return hc, hp

def frequency_domain_strain(self, **parameters):
return self.waveform(self.frequencies[self.freq_mask], **parameters)
device = parameters["chirp_mass"].device
freqs = torch.clone(self.frequencies).to(device)
self.approximant.to(device)
return self.waveform(freqs[self.freq_mask], **parameters)

def slice_waveforms(self, waveforms: torch.Tensor):
# for cbc waveforms, the padding (see above)
Expand Down

0 comments on commit 9f68c74

Please sign in to comment.