From 6e7cc95a8ef8608985eb277f44d3e307b5caad50 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Sun, 3 Dec 2023 09:48:56 -0800 Subject: [PATCH] datagen changes --- projects/sandbox/train/train/augmentor.py | 55 ----------------------- projects/sandbox/train/train/train.py | 3 +- 2 files changed, 2 insertions(+), 56 deletions(-) diff --git a/projects/sandbox/train/train/augmentor.py b/projects/sandbox/train/train/augmentor.py index 47ef2cd1..d268e8ee 100644 --- a/projects/sandbox/train/train/augmentor.py +++ b/projects/sandbox/train/train/augmentor.py @@ -2,9 +2,6 @@ import numpy as np import torch -import torch.nn.functional as F -import torchaudio.transforms as T -import logging from train.augmentations import ( ChannelMuter, ChannelSwapper, @@ -101,7 +98,6 @@ def __init__( rescaler: Optional["SnrRescaler"] = None, invert_prob: float = 0.5, reverse_prob: float = 0.5, - signal_type: str = "bbh", **polarizations: np.ndarray, ): super().__init__() @@ -119,7 +115,6 @@ def __init__( self.signal_prob = signal_prob self.trigger_offset = int(trigger_distance * sample_rate) self.sample_rate = sample_rate - self.signal_type = signal_type self.muter = ChannelMuter(frac=mute_frac) self.swapper = ChannelSwapper(frac=swap_frac) @@ -139,13 +134,6 @@ def __init__( self.register_buffer("tensors", tensors) self.register_buffer("vertices", vertices) - if self.signal_type == "bns": - # Instantiate a spectrogram transform object - self.spectrogram = BNSSpectrogram(n_fft=512) - - # set trigger_offset to None - self.trigger_offset = None - # make sure we have the same number of waveforms # for all the different polarizations num_waveforms = None @@ -198,31 +186,6 @@ def sample_responses(self, N: int, kernel_size: int, psds: torch.Tensor): target_snrs = self.snr(N).to(responses.device) responses, _ = self.rescaler(responses, psds**0.5, target_snrs) - if self.signal_type == "bns": - # chop the waveform from the left - waveform_duration = responses.shape[-1] / self.sample_rate - - # on the left side, chop enough so the leftmost kernel still has - # coalscence in it - chop_length = ( waveform_duration * self.sample_rate) - kernel_size - - # on the right side pad enough zeroes so the rightmost kernel still - # has atleast 2 sec of waveform in it that includes the coalescence - pad_length = kernel_size - (2 * self.sample_rate) - - - logging.info(f"chop lenght: {chop_length}") - logging.info(f"pad length: {pad_length}") - logging.info(f"kernel size: {kernel_size}") - - # a 16 sec waveform with kernel length of 9 sec - # is padded with 7 secs of zeros to the right - # and chopped off by 7 sec on the left - responses = F.pad(responses, (0,int(pad_length)), "constant", 0) - responses = responses[:,:, int(chop_length): ] - - logging.info(f"responses dimentions: {responses.shape}") - kernels = sample_kernels( responses, kernel_size=kernel_size, @@ -276,12 +239,6 @@ def forward(self, X): if self.snr is not None: self.snr.step() - if self.signal_type == "bns": - # contruct Spectrogram for BNS - # Move the computation graph to CUDA - self.spectrogram.to(device=X.device, dtype=torch.float32) - X = self.spectrogram(X) - return X, y @@ -297,15 +254,3 @@ def __len__(self): def __iter__(self): for X in self.dataloader: yield self.fn(X[0].to(self.device)) - - -class BNSSpectrogram(torch.nn.Module): - def __init__(self,n_fft=102,): - super().__init__() - self.spec = T.Spectrogram(n_fft=n_fft) - - def forward(self, waveform: torch.Tensor) -> torch.Tensor: - # Convert to power spectrogram - spec = self.spec(waveform) - return spec - diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index f78b299f..a5a204c8 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -216,7 +216,8 @@ def main( # make output dirs and configure logging file outdir.mkdir(exist_ok=True, parents=True) logdir.mkdir(exist_ok=True, parents=True) - configure_logging(logdir / "train.log", verbose) + configure_logging(logdir / "train.log", verbose) + if seed is not None: logging.info(f"Setting global seed to {seed}") train_utils.seed_everything(seed)