Skip to content

Commit

Permalink
Refactor arg mapping in ffmpeg save function (#3387)
Browse files Browse the repository at this point in the history
Summary:
The arguments of TorchAudio's save function ("format", "bits_per_sample" and "encoding")
do not map one-to-one to the arguments of FFmpeg encoding.

For example, to use vorbis codec, FFmpeg expects "ogg" container/extension with "vorbis"
encoder. It does not recognize "vorbis" extension like TorchAudio (libsox) does.

This commit refactors the logic to parse/map the arguments.

As a result, it now works properly with the vorbis and mp3 extensions.

Pull Request resolved: #3387

Reviewed By: hwangjeff

Differential Revision: D46328787

Pulled By: mthrok

fbshipit-source-id: 36f993952a062bfec58a8b51be6aa86297571f90
  • Loading branch information
mthrok authored and facebook-github-bot committed Jun 1, 2023
1 parent d6dd497 commit b99e5f4
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 60 deletions.
25 changes: 11 additions & 14 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from parameterized import parameterized
from torchaudio._backend.utils import get_load_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.io._compat import _get_encoder
from torchaudio.io._compat import _parse_save_args

from torchaudio_unittest.backend.dispatcher.sox.common import name_func
from torchaudio_unittest.common_utils import (
Expand Down Expand Up @@ -56,11 +56,10 @@ def assert_format(
|
| 1. Generate given format with Sox
|
v 3. Convert to wav with FFmpeg
given format ----------------------> wav
| |
| 2. Load with torchaudio | 4. Load with scipy
+ ----------------------------------+ 3. Convert to wav with FFmpeg
| |
| 2. Load the given format | 4. Load with scipy
| with torchaudio |
v v
tensor ----------> x <----------- tensor
5. Compare
Expand All @@ -72,7 +71,6 @@ def assert_format(
By combining i & ii, step 2. and 4. allow for loading reference given format
data without using torchaudio
"""

path = self.get_temp_path(f"1.original.{format}")
ref_path = self.get_temp_path("2.reference.wav")

Expand All @@ -91,15 +89,15 @@ def assert_format(

# 3. Convert to wav with ffmpeg
if normalize:
acodec = "pcm_f32le"
encoder = "pcm_f32le"
else:
encoding_map = {
"floating-point": "PCM_F",
"signed-integer": "PCM_S",
"unsigned-integer": "PCM_U",
}
acodec = _get_encoder(data.dtype, "wav", encoding_map.get(encoding), bit_depth)
_convert_audio_file(path, ref_path, acodec=acodec)
_, encoder, _ = _parse_save_args(format, format, encoding_map.get(encoding), bit_depth)
_convert_audio_file(path, ref_path, encoder=encoder)

# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
Expand Down Expand Up @@ -277,7 +275,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
"""`self._load` can load opus file correctly."""
ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
wav_path = self.get_temp_path(f"{bitrate}_{compression_level}_{num_channels}ch.opus.wav")
_convert_audio_file(ops_path, wav_path, acodec="pcm_f32le")
_convert_audio_file(ops_path, wav_path, encoder="pcm_f32le")

expected, sample_rate = load_wav(wav_path)
found, sr = self._load(ops_path)
Expand All @@ -301,15 +299,14 @@ def test_sphere(self, sample_rate, num_channels):
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
["int16"],
[3, 4, 16],
[False, True],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
def test_amb(self, dtype, num_channels, normalize, sample_rate=8000):
"""`self._load` can load amb format correctly."""
bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype)
Expand Down
30 changes: 17 additions & 13 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from parameterized import parameterized
from torchaudio._backend.utils import get_save_func
from torchaudio.io._compat import _get_encoder, _get_encoder_format
from torchaudio.io._compat import _parse_save_args

from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
Expand All @@ -24,12 +24,14 @@
)


def _convert_audio_file(src_path, dst_path, format=None, acodec=None):
command = ["ffmpeg", "-y", "-i", src_path, "-strict", "-2"]
if format:
command += ["-sample_fmt", format]
if acodec:
command += ["-acodec", acodec]
def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None):
command = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"]
if muxer:
command += ["-f", muxer]
if encoder:
command += ["-acodec", encoder]
if sample_fmt:
command += ["-sample_fmt", sample_fmt]
command += [dst_path]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
Expand Down Expand Up @@ -100,8 +102,10 @@ def assert_save_consistency(
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
elif test_mode == "fileobj":
ext = None
with open(tgt_path, "bw") as file_:
self._save(
file_,
Expand All @@ -113,6 +117,7 @@ def assert_save_consistency(
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
ext = None
self._save(
file_,
data,
Expand All @@ -127,16 +132,15 @@ def assert_save_consistency(
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with ffmpeg
_convert_audio_file(tgt_path, tst_path, acodec="pcm_f32le")
_convert_audio_file(tgt_path, tst_path, encoder="pcm_f32le")
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]

# 3.1. Convert the original wav to target format with ffmpeg
acodec = _get_encoder(data.dtype, format, encoding, bits_per_sample)
sample_fmt = _get_encoder_format(format, bits_per_sample)
_convert_audio_file(src_path, sox_path, acodec=acodec, format=sample_fmt)
muxer, encoder, sample_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
_convert_audio_file(src_path, sox_path, muxer=muxer, encoder=encoder, sample_fmt=sample_fmt)
# 3.2. Convert the target format to wav with ffmpeg
_convert_audio_file(sox_path, ref_path, acodec="pcm_f32le")
_convert_audio_file(sox_path, ref_path, encoder="pcm_f32le")
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]

Expand Down
113 changes: 80 additions & 33 deletions torchaudio/io/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def load_audio_fileobj(
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
demuxer = "ogg" if format == "vorbis" else format
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, demuxer, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio_fileobj(s, filter, channels_first)
Expand Down Expand Up @@ -131,7 +132,7 @@ def _native_endianness() -> str:
return "be"


def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int) -> str:
def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
endianness = _native_endianness()
Expand All @@ -148,49 +149,93 @@ def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int
if bits_per_sample == 8:
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
return f"pcm_s{bits_per_sample}{endianness}"
elif encoding == "PCM_U":
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "pcm_u8"
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
elif encoding == "PCM_F":
if encoding == "PCM_F":
if not bits_per_sample:
bits_per_sample = 32
if bits_per_sample in (32, 64):
return f"pcm_f{bits_per_sample}{endianness}"
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
elif encoding == "ULAW":
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "pcm_mulaw"
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
elif encoding == "ALAW":
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "pcm_alaw"
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
raise ValueError(f"WAV encoding {encoding} is not supported.")


def _get_encoder(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int) -> str:
    """Pick the FFmpeg encoder name for the requested save ``format``.

    Args:
        dtype: Forwarded unchanged to the WAV encoder lookup.
        format: Audio format name (e.g. ``"wav"``, ``"flac"``, ``"ogg"``).
        encoding: Encoding requested by the caller, if any.
        bits_per_sample: Bit depth requested by the caller, if any.

    Returns:
        The WAV-specific encoder for ``"wav"``, ``"vorbis"`` for
        ``"ogg"``/``"vorbis"``, and ``format`` itself for everything else
        (which includes ``"flac"``).

    Raises:
        ValueError: If ``encoding`` or ``bits_per_sample`` is given together
            with ``"ogg"``/``"vorbis"``.
    """
    if format == "wav":
        # WAV is the only format whose encoder depends on encoding/bit depth.
        return _get_encoder_for_wav(dtype, encoding, bits_per_sample)
    if format in ("ogg", "vorbis"):
        if not (encoding or bits_per_sample):
            return "vorbis"
        raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
    return "flac" if format == "flac" else format
def _get_flac_sample_fmt(bps):
    """Translate a ``bits_per_sample`` value into an encoder sample format for FLAC.

    Args:
        bps: Requested bits per sample, or ``None`` to use the default.

    Returns:
        ``"s16"`` when ``bps`` is ``None`` or 16, ``"s32"`` when ``bps`` is 24.

    Raises:
        ValueError: For any other ``bps`` value.
    """
    use_default = bps is None or bps == 16
    if use_default:
        return "s16"
    if bps != 24:
        raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
    # 24-bit audio is handed to the encoder in the "s32" sample format.
    return "s32"


def _get_encoder_format(format: str, bits_per_sample: Optional[int]) -> str:
    """Return the encoder sample format to request for ``format``.

    Only ``"flac"`` gets an explicit sample format; every other format
    yields ``None``.

    Args:
        format: Audio format name.
        bits_per_sample: Requested bit depth; a falsy value selects the default.

    Returns:
        ``"s16"`` for FLAC with a falsy or 16-bit depth, ``"s32"`` for FLAC
        with a 24-bit depth, ``None`` for non-FLAC formats.

    Raises:
        ValueError: If ``format`` is ``"flac"`` and ``bits_per_sample`` is
            neither falsy, 16 nor 24.
    """
    if format != "flac":
        return None
    if not bits_per_sample or bits_per_sample == 16:
        return "s16"
    if bits_per_sample == 24:
        return "s32"
    raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bits_per_sample} specified).")
def _parse_save_args(
    ext: Optional[str],
    format: Optional[str],
    encoding: Optional[str],
    bps: Optional[int],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Map torchaudio ``save`` arguments onto FFmpeg muxer/encoder settings.

    torchaudio's save function accepts the following, which do not map
    one-to-one to FFmpeg:

    - format: audio format
    - bits_per_sample: encoder sample format
    - encoding: such as PCM_U8.

    In FFmpeg, the output is specified with the following three (and more):

    - muxer: could be an audio format or a container format;
      the one passed to the constructor of StreamWriter
    - encoder: the audio encoder used to encode audio
    - encoder sample format: the format used by the encoder to encode audio.

    If the encoder sample format is different from the source sample format,
    StreamWriter will insert a filter automatically.

    Args:
        ext: File extension of the destination, or ``None`` if unknown.
        format: Explicit format override, or ``None`` to rely on ``ext``.
        encoding: Requested encoding; only consulted for wav/amb.
        bps: Requested bits per sample; consulted for wav/amb, flac and ogg.

    Returns:
        Tuple ``(muxer, encoder, sample_fmt)``; each element may be ``None``.

    Raises:
        ValueError: Propagated from the wav encoder lookup or the FLAC
            sample-format lookup when ``encoding``/``bps`` is not supported.
    """

    def _type(spec):
        # either format is exactly the specified one
        # or extension matches the spec AND there is no format override.
        return format == spec or (format is None and ext == spec)

    if _type("wav") or _type("amb"):
        # wav is special because it supports different encodings through encoders;
        # each encoder only supports one encoder format.
        #
        # amb format is a special case originating from libsox.
        # It is basically a WAV format, with slight modification.
        # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
        # It is a format so that decoders will recognize it as ambisonic.
        # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
        # FFmpeg does not recognize amb because it is basically a WAV format.
        return "wav", _get_encoder_for_wav(encoding, bps), None

    if _type("vorbis"):
        # FFmpeg does not recognize the vorbis extension, while libsox used to.
        # For the sake of backward compatibility (and simplicity),
        # we support the case where users want to do save("foo.vorbis").
        return "ogg", "vorbis", None

    # Everything else: pass the explicit format through as the muxer and do not
    # name an encoder.
    sample_fmt = None
    if _type("flac") or _type("ogg"):
        # NOTE(review): ogg reuses the FLAC bits_per_sample mapping, as in the
        # original code — confirm this is the intended sample format for ogg.
        sample_fmt = _get_flac_sample_fmt(bps)
    return format, None, sample_fmt


# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
Expand All @@ -204,25 +249,27 @@ def save_audio(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
) -> None:
ext = None
if hasattr(uri, "write"):
if format is None:
raise RuntimeError("'format' is required when saving to file object.")
else:
uri = os.path.normpath(uri)
s = StreamWriter(uri, format=format, buffer_size=buffer_size)
if format is None:
tokens = str(uri).split(".")
if len(tokens) > 1:
format = tokens[-1].lower()
if tokens := str(uri).split(".")[1:]:
ext = tokens[-1].lower()

muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)

if channels_first:
src = src.T

s = StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s.add_audio_stream(
sample_rate,
num_channels=src.size(-1),
format=_get_sample_format(src.dtype),
encoder=_get_encoder(src.dtype, format, encoding, bits_per_sample),
encoder_format=_get_encoder_format(format, bits_per_sample),
encoder=encoder,
encoder_format=enc_fmt,
)
with s.open():
s.write_audio_chunk(0, src)

0 comments on commit b99e5f4

Please sign in to comment.