Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
datagen changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafia Omer committed Dec 3, 2023
1 parent 1ed207a commit 6e7cc95
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 56 deletions.
55 changes: 0 additions & 55 deletions projects/sandbox/train/train/augmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.transforms as T
import logging
from train.augmentations import (
ChannelMuter,
ChannelSwapper,
Expand Down Expand Up @@ -101,7 +98,6 @@ def __init__(
rescaler: Optional["SnrRescaler"] = None,
invert_prob: float = 0.5,
reverse_prob: float = 0.5,
signal_type: str = "bbh",
**polarizations: np.ndarray,
):
super().__init__()
Expand All @@ -119,7 +115,6 @@ def __init__(
self.signal_prob = signal_prob
self.trigger_offset = int(trigger_distance * sample_rate)
self.sample_rate = sample_rate
self.signal_type = signal_type

self.muter = ChannelMuter(frac=mute_frac)
self.swapper = ChannelSwapper(frac=swap_frac)
Expand All @@ -139,13 +134,6 @@ def __init__(
self.register_buffer("tensors", tensors)
self.register_buffer("vertices", vertices)

if self.signal_type == "bns":
# Instantiate a spectrogram transform object
self.spectrogram = BNSSpectrogram(n_fft=512)

# set trigger_offset to None
self.trigger_offset = None

# make sure we have the same number of waveforms
# for all the different polarizations
num_waveforms = None
Expand Down Expand Up @@ -198,31 +186,6 @@ def sample_responses(self, N: int, kernel_size: int, psds: torch.Tensor):
target_snrs = self.snr(N).to(responses.device)
responses, _ = self.rescaler(responses, psds**0.5, target_snrs)

if self.signal_type == "bns":
# chop the waveform from the left
waveform_duration = responses.shape[-1] / self.sample_rate

# on the left side, chop enough so the leftmost kernel still has
# coalscence in it
chop_length = ( waveform_duration * self.sample_rate) - kernel_size

# on the right side pad enough zeroes so the rightmost kernel still
# has atleast 2 sec of waveform in it that includes the coalescence
pad_length = kernel_size - (2 * self.sample_rate)


logging.info(f"chop lenght: {chop_length}")
logging.info(f"pad length: {pad_length}")
logging.info(f"kernel size: {kernel_size}")

# a 16 sec waveform with kernel length of 9 sec
# is padded with 7 secs of zeros to the right
# and chopped off by 7 sec on the left
responses = F.pad(responses, (0,int(pad_length)), "constant", 0)
responses = responses[:,:, int(chop_length): ]

logging.info(f"responses dimentions: {responses.shape}")

kernels = sample_kernels(
responses,
kernel_size=kernel_size,
Expand Down Expand Up @@ -276,12 +239,6 @@ def forward(self, X):
if self.snr is not None:
self.snr.step()

if self.signal_type == "bns":
# contruct Spectrogram for BNS
# Move the computation graph to CUDA
self.spectrogram.to(device=X.device, dtype=torch.float32)
X = self.spectrogram(X)

return X, y


Expand All @@ -297,15 +254,3 @@ def __len__(self):
def __iter__(self):
for X in self.dataloader:
yield self.fn(X[0].to(self.device))


class BNSSpectrogram(torch.nn.Module):
def __init__(self,n_fft=102,):
super().__init__()
self.spec = T.Spectrogram(n_fft=n_fft)

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
# Convert to power spectrogram
spec = self.spec(waveform)
return spec

3 changes: 2 additions & 1 deletion projects/sandbox/train/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def main(
# make output dirs and configure logging file
outdir.mkdir(exist_ok=True, parents=True)
logdir.mkdir(exist_ok=True, parents=True)
configure_logging(logdir / "train.log", verbose)
configure_logging(logdir / "train.log", verbose)

if seed is not None:
logging.info(f"Setting global seed to {seed}")
train_utils.seed_everything(seed)
Expand Down

0 comments on commit 6e7cc95

Please sign in to comment.