diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py index 667be0276b..4a5689df4a 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py @@ -7,7 +7,7 @@ from parameterized import parameterized from torchaudio._backend.utils import get_load_func from torchaudio._internal import module_utils as _mod_utils -from torchaudio.io._compat import _get_encoder +from torchaudio.io._compat import _parse_save_args from torchaudio_unittest.backend.dispatcher.sox.common import name_func from torchaudio_unittest.common_utils import ( @@ -56,11 +56,10 @@ def assert_format( | | 1. Generate given format with Sox | - v 3. Convert to wav with FFmpeg - given format ----------------------> wav - | | - | 2. Load with torchaudio | 4. Load with scipy + + ----------------------------------+ 3. Convert to wav with FFmpeg | | + | 2. Load the given format | 4. Load with scipy + | with torchaudio | v v tensor ----------> x <----------- tensor 5. Compare @@ -72,7 +71,6 @@ def assert_format( By combining i & ii, step 2. and 4. allow for loading reference given format data without using torchaudio """ - path = self.get_temp_path(f"1.original.{format}") ref_path = self.get_temp_path("2.reference.wav") @@ -91,15 +89,15 @@ def assert_format( # 3. Convert to wav with ffmpeg if normalize: - acodec = "pcm_f32le" + encoder = "pcm_f32le" else: encoding_map = { "floating-point": "PCM_F", "signed-integer": "PCM_S", "unsigned-integer": "PCM_U", } - acodec = _get_encoder(data.dtype, "wav", encoding_map.get(encoding), bit_depth) - _convert_audio_file(path, ref_path, acodec=acodec) + _, encoder, _ = _parse_save_args(format, format, encoding_map.get(encoding), bit_depth) + _convert_audio_file(path, ref_path, encoder=encoder) # 4. Load wav with scipy data_ref = load_wav(ref_path, normalize=normalize)[0] @@ -277,7 +275,7 @@ def test_opus(self, bitrate, num_channels, compression_level): """`self._load` can load opus file correctly.""" ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus") wav_path = self.get_temp_path(f"{bitrate}_{compression_level}_{num_channels}ch.opus.wav") - _convert_audio_file(ops_path, wav_path, acodec="pcm_f32le") + _convert_audio_file(ops_path, wav_path, encoder="pcm_f32le") expected, sample_rate = load_wav(wav_path) found, sr = self._load(ops_path) @@ -301,15 +299,14 @@ def test_sphere(self, sample_rate, num_channels): @parameterized.expand( list( itertools.product( - ["float32", "int32", "int16"], - [8000, 16000], - [1, 2], + ["int16"], + [3, 4, 16], [False, True], ) ), name_func=name_func, ) - def test_amb(self, dtype, sample_rate, num_channels, normalize): + def test_amb(self, dtype, num_channels, normalize, sample_rate=8000): """`self._load` can load amb format correctly.""" bit_depth = sox_utils.get_bit_depth(dtype) encoding = sox_utils.get_encoding(dtype) diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py index ef0e56f0e5..98120f2f4f 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py @@ -8,7 +8,7 @@ import torch from parameterized import parameterized from torchaudio._backend.utils import get_save_func -from torchaudio.io._compat import _get_encoder, _get_encoder_format +from torchaudio.io._compat import _parse_save_args from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func from torchaudio_unittest.common_utils import ( @@ -24,12 +24,14 @@ ) -def _convert_audio_file(src_path, dst_path, format=None, acodec=None): - command = ["ffmpeg", "-y", "-i", src_path, "-strict", "-2"] - if format: - command += ["-sample_fmt", format] - if acodec: - command += ["-acodec", acodec] +def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None): + command = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"] + if muxer: + command += ["-f", muxer] + if encoder: + command += ["-acodec", encoder] + if sample_fmt: + command += ["-sample_fmt", sample_fmt] command += [dst_path] print(" ".join(command), file=sys.stderr) subprocess.run(command, check=True) @@ -100,8 +102,10 @@ def assert_save_consistency( # 2.1. Convert the original wav to target format with torchaudio data = load_wav(src_path, normalize=False)[0] if test_mode == "path": - self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample) + ext = format + self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample) elif test_mode == "fileobj": + ext = None with open(tgt_path, "bw") as file_: self._save( file_, @@ -113,6 +117,7 @@ def assert_save_consistency( ) elif test_mode == "bytesio": file_ = io.BytesIO() + ext = None self._save( file_, data, @@ -127,16 +132,15 @@ def assert_save_consistency( else: raise ValueError(f"Unexpected test mode: {test_mode}") # 2.2. Convert the target format to wav with ffmpeg - _convert_audio_file(tgt_path, tst_path, acodec="pcm_f32le") + _convert_audio_file(tgt_path, tst_path, encoder="pcm_f32le") # 2.3. Load with SciPy found = load_wav(tst_path, normalize=False)[0] # 3.1. Convert the original wav to target format with ffmpeg - acodec = _get_encoder(data.dtype, format, encoding, bits_per_sample) - sample_fmt = _get_encoder_format(format, bits_per_sample) - _convert_audio_file(src_path, sox_path, acodec=acodec, format=sample_fmt) + muxer, encoder, sample_fmt = _parse_save_args(ext, format, encoding, bits_per_sample) + _convert_audio_file(src_path, sox_path, muxer=muxer, encoder=encoder, sample_fmt=sample_fmt) # 3.2. Convert the target format to wav with ffmpeg - _convert_audio_file(sox_path, ref_path, acodec="pcm_f32le") + _convert_audio_file(sox_path, ref_path, encoder="pcm_f32le") # 3.3. Load with SciPy expected = load_wav(ref_path, normalize=False)[0] diff --git a/torchaudio/io/_compat.py b/torchaudio/io/_compat.py index 7b122cbcc8..723b7fcaeb 100644 --- a/torchaudio/io/_compat.py +++ b/torchaudio/io/_compat.py @@ -102,7 +102,8 @@ def load_audio_fileobj( format: Optional[str] = None, buffer_size: int = 4096, ) -> Tuple[torch.Tensor, int]: - s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size) + demuxer = "ogg" if format == "vorbis" else format + s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, demuxer, None, buffer_size) sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate) filter = _get_load_filter(frame_offset, num_frames, convert) waveform = _load_audio_fileobj(s, filter, channels_first) @@ -131,7 +132,7 @@ def _native_endianness() -> str: return "be" -def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int) -> str: +def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str: if bits_per_sample not in {None, 8, 16, 24, 32, 64}: raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.") endianness = _native_endianness() @@ -148,49 +149,93 @@ def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int if bits_per_sample == 8: raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.") return f"pcm_s{bits_per_sample}{endianness}" - elif encoding == "PCM_U": + if encoding == "PCM_U": if bits_per_sample in (None, 8): return "pcm_u8" raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.") - elif encoding == "PCM_F": + if encoding == "PCM_F": if not bits_per_sample: bits_per_sample = 32 if bits_per_sample in (32, 64): return f"pcm_f{bits_per_sample}{endianness}" raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.") - elif encoding == "ULAW": + if encoding == "ULAW": if bits_per_sample in (None, 8): return "pcm_mulaw" raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.") - elif encoding == "ALAW": + if encoding == "ALAW": if bits_per_sample in (None, 8): return "pcm_alaw" raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.") raise ValueError(f"WAV encoding {encoding} is not supported.") -def _get_encoder(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int) -> str: - if format == "wav": - return _get_encoder_for_wav(dtype, encoding, bits_per_sample) - if format == "flac": - return "flac" - if format in ("ogg", "vorbis"): - if encoding or bits_per_sample: - raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.") - return "vorbis" - return format +def _get_flac_sample_fmt(bps): + if bps is None or bps == 16: + return "s16" + if bps == 24: + return "s32" + raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).") -def _get_encoder_format(format: str, bits_per_sample: Optional[int]) -> str: - if format == "flac": - if not bits_per_sample: - return "s16" - if bits_per_sample == 24: - return "s32" - if bits_per_sample == 16: - return "s16" - raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bits_per_sample} specified).") - return None +def _parse_save_args( + ext: Optional[str], + format: Optional[str], + encoding: Optional[str], + bps: Optional[int], +): + # torchaudio's save function accepts the followings, which do not 1to1 map + # to FFmpeg. + # + # - format: audio format + # - bits_per_sample: encoder sample format + # - encoding: such as PCM_U8. + # + # In FFmpeg, format is specified with the following three (and more) + # + # - muxer: could be audio format or container format. + # the one we passed to the constructor of StreamWriter + # - encoder: the audio encoder used to encode audio + # - encoder sample format: the format used by encoder to encode audio. + # + # If encoder sample format is different from source sample format, StreamWriter + # will insert a filter automatically. + # + def _type(spec): + # either format is exactly the specified one + # or extension matches to the spec AND there is no format override. + return format == spec or (format is None and ext == spec) + + if _type("wav") or _type("amb"): + # wav is special because it supports different encoding through encoders + # each encoder only supports one encoder format + # + # amb format is a special case originated from libsox. + # It is basically a WAV format, with slight modification. + # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795 + # It is a format so that decoders will recognize it as ambisonic. + # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/ + # FFmpeg does not recognize amb because it is basically a WAV format. + muxer = "wav" + encoder = _get_encoder_for_wav(encoding, bps) + sample_fmt = None + elif _type("vorbis"): + # FFpmeg does not recognize vorbis extension, while libsox used to do. + # For the sake of bakward compatibility, (and the simplicity), + # we support the case where users want to do save("foo.vorbis") + muxer = "ogg" + encoder = "vorbis" + sample_fmt = None + else: + muxer = format + encoder = None + sample_fmt = None + if _type("flac"): + sample_fmt = _get_flac_sample_fmt(bps) + if _type("ogg"): + sample_fmt = _get_flac_sample_fmt(bps) + print(ext, format, encoding, bps, "===>", muxer, encoder, sample_fmt) + return muxer, encoder, sample_fmt # NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript. @@ -204,25 +249,27 @@ def save_audio( bits_per_sample: Optional[int] = None, buffer_size: int = 4096, ) -> None: + ext = None if hasattr(uri, "write"): if format is None: raise RuntimeError("'format' is required when saving to file object.") else: uri = os.path.normpath(uri) - s = StreamWriter(uri, format=format, buffer_size=buffer_size) - if format is None: - tokens = str(uri).split(".") - if len(tokens) > 1: - format = tokens[-1].lower() + if tokens := str(uri).split(".")[1:]: + ext = tokens[-1].lower() + + muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample) if channels_first: src = src.T + + s = StreamWriter(uri, format=muxer, buffer_size=buffer_size) s.add_audio_stream( sample_rate, num_channels=src.size(-1), format=_get_sample_format(src.dtype), - encoder=_get_encoder(src.dtype, format, encoding, bits_per_sample), - encoder_format=_get_encoder_format(format, bits_per_sample), + encoder=encoder, + encoder_format=enc_fmt, ) with s.open(): s.write_audio_chunk(0, src)