Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor arg mapping in ffmpeg save function #3387

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from parameterized import parameterized
from torchaudio._backend.utils import get_load_func
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.io._compat import _get_encoder
from torchaudio.io._compat import _parse_save_args

from torchaudio_unittest.backend.dispatcher.sox.common import name_func
from torchaudio_unittest.common_utils import (
Expand Down Expand Up @@ -56,11 +56,10 @@ def assert_format(
|
| 1. Generate given format with Sox
|
v 3. Convert to wav with FFmpeg
given format ----------------------> wav
| |
| 2. Load with torchaudio | 4. Load with scipy
+ ----------------------------------+ 3. Convert to wav with FFmpeg
| |
| 2. Load the given format | 4. Load with scipy
| with torchaudio |
v v
tensor ----------> x <----------- tensor
5. Compare
Expand All @@ -72,7 +71,6 @@ def assert_format(
By combining i & ii, step 2. and 4. allow for loading reference given format
data without using torchaudio
"""

path = self.get_temp_path(f"1.original.{format}")
ref_path = self.get_temp_path("2.reference.wav")

Expand All @@ -91,15 +89,15 @@ def assert_format(

# 3. Convert to wav with ffmpeg
if normalize:
acodec = "pcm_f32le"
encoder = "pcm_f32le"
else:
encoding_map = {
"floating-point": "PCM_F",
"signed-integer": "PCM_S",
"unsigned-integer": "PCM_U",
}
acodec = _get_encoder(data.dtype, "wav", encoding_map.get(encoding), bit_depth)
_convert_audio_file(path, ref_path, acodec=acodec)
_, encoder, _ = _parse_save_args(format, format, encoding_map.get(encoding), bit_depth)
_convert_audio_file(path, ref_path, encoder=encoder)

# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
Expand Down Expand Up @@ -277,7 +275,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
"""`self._load` can load opus file correctly."""
ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
wav_path = self.get_temp_path(f"{bitrate}_{compression_level}_{num_channels}ch.opus.wav")
_convert_audio_file(ops_path, wav_path, acodec="pcm_f32le")
_convert_audio_file(ops_path, wav_path, encoder="pcm_f32le")

expected, sample_rate = load_wav(wav_path)
found, sr = self._load(ops_path)
Expand All @@ -301,15 +299,14 @@ def test_sphere(self, sample_rate, num_channels):
@parameterized.expand(
list(
itertools.product(
["float32", "int32", "int16"],
[8000, 16000],
[1, 2],
["int16"],
[3, 4, 16],
[False, True],
)
),
name_func=name_func,
)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
def test_amb(self, dtype, num_channels, normalize, sample_rate=8000):
"""`self._load` can load amb format correctly."""
bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype)
Expand Down
30 changes: 17 additions & 13 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from parameterized import parameterized
from torchaudio._backend.utils import get_save_func
from torchaudio.io._compat import _get_encoder, _get_encoder_format
from torchaudio.io._compat import _parse_save_args

from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
Expand All @@ -24,12 +24,14 @@
)


def _convert_audio_file(src_path, dst_path, format=None, acodec=None):
command = ["ffmpeg", "-y", "-i", src_path, "-strict", "-2"]
if format:
command += ["-sample_fmt", format]
if acodec:
command += ["-acodec", acodec]
def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None):
command = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"]
if muxer:
command += ["-f", muxer]
if encoder:
command += ["-acodec", encoder]
if sample_fmt:
command += ["-sample_fmt", sample_fmt]
command += [dst_path]
print(" ".join(command), file=sys.stderr)
subprocess.run(command, check=True)
Expand Down Expand Up @@ -100,8 +102,10 @@ def assert_save_consistency(
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
elif test_mode == "fileobj":
ext = None
with open(tgt_path, "bw") as file_:
self._save(
file_,
Expand All @@ -113,6 +117,7 @@ def assert_save_consistency(
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
ext = None
self._save(
file_,
data,
Expand All @@ -127,16 +132,15 @@ def assert_save_consistency(
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with ffmpeg
_convert_audio_file(tgt_path, tst_path, acodec="pcm_f32le")
_convert_audio_file(tgt_path, tst_path, encoder="pcm_f32le")
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]

# 3.1. Convert the original wav to target format with ffmpeg
acodec = _get_encoder(data.dtype, format, encoding, bits_per_sample)
sample_fmt = _get_encoder_format(format, bits_per_sample)
_convert_audio_file(src_path, sox_path, acodec=acodec, format=sample_fmt)
muxer, encoder, sample_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
_convert_audio_file(src_path, sox_path, muxer=muxer, encoder=encoder, sample_fmt=sample_fmt)
# 3.2. Convert the target format to wav with ffmpeg
_convert_audio_file(sox_path, ref_path, acodec="pcm_f32le")
_convert_audio_file(sox_path, ref_path, encoder="pcm_f32le")
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]

Expand Down
113 changes: 80 additions & 33 deletions torchaudio/io/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def load_audio_fileobj(
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
demuxer = "ogg" if format == "vorbis" else format
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, demuxer, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio_fileobj(s, filter, channels_first)
Expand Down Expand Up @@ -131,7 +132,7 @@ def _native_endianness() -> str:
return "be"


def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int) -> str:
def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
endianness = _native_endianness()
Expand All @@ -148,49 +149,93 @@ def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int
if bits_per_sample == 8:
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
return f"pcm_s{bits_per_sample}{endianness}"
elif encoding == "PCM_U":
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "pcm_u8"
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
elif encoding == "PCM_F":
if encoding == "PCM_F":
if not bits_per_sample:
bits_per_sample = 32
if bits_per_sample in (32, 64):
return f"pcm_f{bits_per_sample}{endianness}"
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
elif encoding == "ULAW":
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "pcm_mulaw"
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
elif encoding == "ALAW":
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "pcm_alaw"
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
raise ValueError(f"WAV encoding {encoding} is not supported.")


def _get_encoder(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int) -> str:
if format == "wav":
return _get_encoder_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
return "flac"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
return "vorbis"
return format
def _get_flac_sample_fmt(bps: Optional[int]) -> str:
    """Return the encoder sample format to use for a FLAC bit depth.

    Args:
        bps: Requested bits per sample. ``None`` selects the 16-bit default.

    Returns:
        ``"s16"`` for ``None``/16 and ``"s32"`` for 24.

    Raises:
        ValueError: If ``bps`` is anything other than ``None``, 16 or 24.
    """
    if bps is None or bps == 16:
        return "s16"
    if bps == 24:
        # 24-bit output is requested through the 32-bit sample format.
        return "s32"
    raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")


def _get_encoder_format(format: str, bits_per_sample: Optional[int]) -> str:
if format == "flac":
if not bits_per_sample:
return "s16"
if bits_per_sample == 24:
return "s32"
if bits_per_sample == 16:
return "s16"
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bits_per_sample} specified).")
return None
def _parse_save_args(
    ext: Optional[str],
    format: Optional[str],
    encoding: Optional[str],
    bps: Optional[int],
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Translate torchaudio ``save`` arguments into FFmpeg settings.

    Args:
        ext: Lower-cased file extension taken from the destination path,
            or ``None`` when there is none (e.g. file-like destinations).
        format: Explicit format override. When given, it takes precedence
            over ``ext``.
        encoding: torchaudio encoding name such as ``"PCM_S"``. Only consulted
            for the wav/amb case.
        bps: Requested bits per sample, or ``None`` for the default.

    Returns:
        A ``(muxer, encoder, sample_fmt)`` tuple. Any element may be ``None``,
        meaning the value is not set explicitly here.

    Raises:
        ValueError: Propagated from the wav encoder / FLAC sample-format
            helpers when ``encoding`` or ``bps`` is not supported.
    """
    # torchaudio's save function accepts the following arguments, which do not
    # map one-to-one to FFmpeg.
    #
    # - format: audio format
    # - bits_per_sample: encoder sample format
    # - encoding: such as PCM_U8.
    #
    # In FFmpeg, format is specified with the following three (and more)
    #
    # - muxer: could be audio format or container format.
    #   the one we passed to the constructor of StreamWriter
    # - encoder: the audio encoder used to encode audio
    # - encoder sample format: the format used by encoder to encode audio.
    #
    # If encoder sample format is different from source sample format, StreamWriter
    # will insert a filter automatically.
    #
    def _type(spec):
        # either format is exactly the specified one
        # or extension matches to the spec AND there is no format override.
        return format == spec or (format is None and ext == spec)

    if _type("wav") or _type("amb"):
        # wav is special because it supports different encoding through encoders
        # each encoder only supports one encoder format
        #
        # amb format is a special case originated from libsox.
        # It is basically a WAV format, with slight modification.
        # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
        # It is a format so that decoders will recognize it as ambisonic.
        # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
        # FFmpeg does not recognize amb because it is basically a WAV format.
        muxer = "wav"
        encoder = _get_encoder_for_wav(encoding, bps)
        sample_fmt = None
    elif _type("vorbis"):
        # FFmpeg does not recognize the vorbis extension, while libsox used to.
        # For the sake of backward compatibility (and simplicity),
        # we support the case where users want to do save("foo.vorbis")
        muxer = "ogg"
        encoder = "vorbis"
        sample_fmt = None
    else:
        # No explicit encoder is chosen here; only the muxer override (if any)
        # is forwarded.
        muxer = format
        encoder = None
        sample_fmt = None
        if _type("flac") or _type("ogg"):
            # NOTE(review): ogg reuses the FLAC bit-depth mapping, so an
            # unsupported ``bps`` for ogg raises an error that mentions FLAC.
            sample_fmt = _get_flac_sample_fmt(bps)
    return muxer, encoder, sample_fmt


# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
Expand All @@ -204,25 +249,27 @@ def save_audio(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
) -> None:
ext = None
if hasattr(uri, "write"):
if format is None:
raise RuntimeError("'format' is required when saving to file object.")
else:
uri = os.path.normpath(uri)
s = StreamWriter(uri, format=format, buffer_size=buffer_size)
if format is None:
tokens = str(uri).split(".")
if len(tokens) > 1:
format = tokens[-1].lower()
if tokens := str(uri).split(".")[1:]:
ext = tokens[-1].lower()

muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)

if channels_first:
src = src.T

s = StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s.add_audio_stream(
sample_rate,
num_channels=src.size(-1),
format=_get_sample_format(src.dtype),
encoder=_get_encoder(src.dtype, format, encoding, bits_per_sample),
encoder_format=_get_encoder_format(format, bits_per_sample),
encoder=encoder,
encoder_format=enc_fmt,
)
with s.open():
s.write_audio_chunk(0, src)