Skip to content

Commit

Permalink
Fail on Python if sox_io info/load does not succeed (pytorch#2423)
Browse files Browse the repository at this point in the history
Summary:
Extracted from pytorch#2419. Move the failure of sox_io from C++ to Python layer.

Pull Request resolved: pytorch#2423

Reviewed By: carolineechen

Differential Revision: D36766152

Pulled By: mthrok

fbshipit-source-id: 53f897a608e97b81ebe5df29577374d88ce178f3
  • Loading branch information
mthrok authored and facebook-github-bot committed May 31, 2022
1 parent c209b70 commit b56f60b
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 28 deletions.
4 changes: 2 additions & 2 deletions test/torchaudio_unittest/backend/sox_io/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def test_fileobj_large_header(self, ext, dtype):
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])

with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"):
with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)

with self._set_buffer_size(16384):
Expand Down Expand Up @@ -583,5 +583,5 @@ def test_info_fail(self):
When attempted to get info on a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.info(path)
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/backend/sox_io/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,5 +627,5 @@ def test_load_fail(self):
When attempted to load a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.load(path)
50 changes: 44 additions & 6 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,37 @@
from .common import AudioMetaData


# Note: need to comply TorchScript syntax -- need annotation and no f-string
def _fail_info(filepath: str, format: Optional[str]) -> AudioMetaData:
raise RuntimeError("Failed to fetch metadata from {}".format(filepath))


def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioMetaData:
raise RuntimeError("Failed to fetch metadata from {}".format(fileobj))


# Note: need to comply TorchScript syntax -- need annotation and no f-string
def _fail_load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
raise RuntimeError("Failed to load audio from {}".format(filepath))


def _fail_load_fileobj(fileobj, *args, **kwargs):
raise RuntimeError(f"Failed to load audio from {fileobj}")


_fallback_info = _fail_info
_fallback_info_fileobj = _fail_info_fileobj
_fallback_load = _fail_load
_fallback_load_fileobj = _fail_load_fileobj


@_mod_utils.requires_sox()
def info(
filepath: str,
Expand Down Expand Up @@ -46,11 +77,14 @@ def info(
if not torch.jit.is_scripting():
if hasattr(filepath, "read"):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format)
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
assert sinfo is not None # for TorchScript compatibility
return AudioMetaData(*sinfo)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info(filepath, format)


@_mod_utils.requires_sox()
Expand Down Expand Up @@ -145,15 +179,19 @@ def load(
"""
if not torch.jit.is_scripting():
if hasattr(filepath, "read"):
return torchaudio._torchaudio.load_audio_fileobj(
ret = torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
return ret
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath)
ret = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
assert ret is not None # for TorchScript compatibility
return ret
if ret is not None:
return ret
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)


@_mod_utils.requires_sox()
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/csrc/pybind/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ auto apply_effects_fileobj(
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));

// In case of streamed data, length can be 0
validate_input_memfile(sf);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}

// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
Expand Down
6 changes: 4 additions & 2 deletions torchaudio/csrc/pybind/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ auto get_info_fileobj(py::object fileobj, c10::optional<std::string> format)
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));

// In case of streamed data, length can be 0
validate_input_memfile(sf);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return c10::optional<MetaDataTuple>{};
}

return std::forward_as_tuple(
static_cast<int64_t>(sf->signal.rate),
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/csrc/sox/effects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ auto apply_effects_file(
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));

validate_input_file(sf, path);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}

const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);

Expand Down
6 changes: 5 additions & 1 deletion torchaudio/csrc/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ c10::optional<MetaDataTuple> get_info_file(
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));

validate_input_file(sf, path);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}

return std::forward_as_tuple(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
Expand Down
4 changes: 0 additions & 4 deletions torchaudio/csrc/sox/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ void validate_input_file(const SoxFormat& sf, const std::string& path) {
}
}

void validate_input_memfile(const SoxFormat& sf) {
return validate_input_file(sf, "<in memory buffer>");
}

void validate_input_tensor(const torch::Tensor tensor) {
if (!tensor.device().is_cpu()) {
throw std::runtime_error("Input tensor has to be on CPU.");
Expand Down
7 changes: 0 additions & 7 deletions torchaudio/csrc/sox/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ struct SoxFormat {
sox_format_t* fd_;
};

///
/// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf, const std::string& path);

/// Verify that input memory buffer has known encoding, and not empty
void validate_input_memfile(const SoxFormat& sf);

///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
void validate_input_tensor(const torch::Tensor);
Expand Down
10 changes: 7 additions & 3 deletions torchaudio/sox_effects/sox_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,12 @@ def apply_effects_file(
"""
if not torch.jit.is_scripting():
if hasattr(path, "read"):
return torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
ret = torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
if ret is None:
raise RuntimeError("Failed to load audio from {}".format(path))
return ret
path = os.fspath(path)
ret = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
assert ret is not None
return ret
if ret is not None:
return ret
raise RuntimeError("Failed to load audio from {}".format(path))

0 comments on commit b56f60b

Please sign in to comment.