Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add bf16 inference support and fix seq_len stft issue #7338

Merged
Merged 2 commits on Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class TranscriptionConfig:
cuda: Optional[int] = None
allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU)
amp: bool = False
amp_dtype: str = "float16" # can be set to "float32" or "bfloat16" when using amp
audio_type: str = "wav"

# Recompute model transcription, even if the output folder exists with scores.
Expand Down Expand Up @@ -304,7 +305,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
else:

@contextlib.contextmanager
def autocast(dtype=None):
    """No-op fallback used when torch AMP autocast is unavailable/disabled.

    Args:
        dtype: Accepted only for signature compatibility with
            ``torch.cuda.amp.autocast(dtype=...)``; it is ignored here.

    Yields:
        None. Entering this context has no effect on computation.
    """
    yield

# Compute output filename
Expand All @@ -319,7 +320,10 @@ def autocast():
return cfg

# transcribe audio
with autocast():

amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16

with autocast(dtype=amp_dtype):
with torch.no_grad():
if partial_audio:
transcriptions = transcribe_partial_audio(
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,15 +390,15 @@ def log_zero_guard_value_fn(self, x):
def get_seq_len(self, seq_len):
    """Compute the number of STFT frames for each input length.

    Uses integer ``torch.floor_divide`` instead of float division +
    ``torch.floor`` so the result is exact regardless of the autocast
    dtype (e.g. bfloat16), and no float cast of ``seq_len`` is needed.

    Args:
        seq_len: Integer tensor of input sample counts.

    Returns:
        Tensor of the same shape with dtype ``torch.long`` holding the
        frame counts: (seq_len + pad_amount - n_fft) // hop_length + 1.
    """
    # NOTE(review): original comment read "Assuming that center is True is
    # stft_pad_amount = 0" — presumably means center=True when
    # stft_pad_amount is None, in which case STFT pads n_fft // 2 on each
    # side; confirm against the STFT call site.
    pad_amount = self.stft_pad_amount * 2 if self.stft_pad_amount is not None else self.n_fft // 2 * 2
    seq_len = torch.floor_divide((seq_len + pad_amount - self.n_fft), self.hop_length) + 1
    return seq_len.to(dtype=torch.long)

@property
def filter_banks(self):
    # Read-only accessor for the filterbank matrix stored in self.fb
    # (presumably mel filters built at init — confirm against __init__).
    return self.fb

def forward(self, x, seq_len, linear_spec=False):
seq_len = self.get_seq_len(seq_len.float())
seq_len = self.get_seq_len(seq_len)

if self.stft_pad_amount is not None:
x = torch.nn.functional.pad(
Expand Down