Skip to content

Commit

Permalink
Add file-like object support to Streaming API (pytorch#2400)
Browse files Browse the repository at this point in the history
Summary:
This commit adds file-like object support to Streaming API.

## Features
- File-like objects are expected to implement `read(self, n)`.
- Additionally `seek(self, offset, whence)` is used if available.
- Without `seek` method, some formats cannot be decoded properly.
  - To work around this, one can use the existing `decoder` option to tell what decoder it should use.
  - The set of `decoder` and `decoder_option` arguments were added to `add_basic_[audio|video]_stream` method, similar to `add_[audio|video]_stream`.
  - So as to have the arguments common to both audio and video in from of the rest of the arguments, the order of the arguments are changed.
  - Also `dtype` and `format` arguments were changed to make them consistent across audio/video methods.

## Code structure

The approach is very similar to how file-like object is supported in sox-based I/O.
In Streaming API if the input src is string, it is passed to the implementation bound with TorchBind,
if the src has `read` attribute, it is passed to the same implementation bound via PyBind 11.

![Untitled drawing](https://user-images.githubusercontent.com/855818/169098391-6116afee-7b29-460d-b50d-1037bb8a359d.png)

## Refactoring involved
- Extracted to pytorch#2402
  - Some implementation in the original TorchBind surface layer is converted to Wrapper class so that they can be re-used from PyBind11 bindings. The wrapper class serves to simplify the binding.
  - `add_basic_[audio|video]_stream` methods were removed from C++ layer as it was just constructing string and passing it to `add_[audio|video]_stream` method, which is simpler to do in Python.
  - The original core Streamer implementation kept the use of types in `c10` namespace minimum. All the `c10::optional` and `c10::Dict` were converted to the equivalents of `std` at binding layer. But since they work fine with PyBind11, Streamer core methods deal them directly.
- On Python side, the switch of binding happens in the constructor of `StreamReader` class. Since all the methods have to be delegated to the same set of binding, a backend was introduced, which is abstracted away from user code.

## TODO:
- [x] Check if it is possible to stream MP4 (yuv420p) from S3 and directly decode (with/without HW decoding).

Pull Request resolved: pytorch#2400

Differential Revision: D36520073

Pulled By: mthrok

fbshipit-source-id: 3f79875e7635386283893a7c08cd19d4d0f8efa5
  • Loading branch information
mthrok authored and facebook-github-bot committed May 19, 2022
1 parent 38cf5b7 commit 752cbea
Show file tree
Hide file tree
Showing 13 changed files with 539 additions and 223 deletions.
19 changes: 11 additions & 8 deletions examples/tutorials/streaming_api_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,21 +250,24 @@
# When the StreamReader buffered this number of chunks and is asked to pull
# more frames, StreamReader drops the old frames/chunks.
# - ``stream_index``: The index of the source stream.
# - ``decoder``: If provided, override the decoder. Useful if it fails to detect
# the codec.
# - ``decoder_option``: The option for the decoder.
#
# For audio output stream, you can provide the following additional
# parameters to change the audio properties.
#
# - ``sample_rate``: When provided, StreamReader resamples the audio on-the-fly.
# - ``dtype``: By default the StreamReader returns tensor of `float32` dtype,
# with sample values ranging `[-1, 1]`. By providing ``dtype`` argument
# - ``format``: By default the StreamReader returns tensor of `float32` dtype,
# with sample values ranging `[-1, 1]`. By providing ``format`` argument
# the resulting dtype and value range is changed.
# - ``sample_rate``: When provided, StreamReader resamples the audio on-the-fly.
#
# For video output stream, the following parameters are available.
#
# - ``format``: Change the image format.
# - ``frame_rate``: Change the frame rate by dropping or duplicating
# frames. No interpolation is performed.
# - ``width``, ``height``: Change the image size.
# - ``format``: Change the image format.
#

######################################################################
Expand Down Expand Up @@ -298,7 +301,7 @@
# streamer.add_basic_video_stream(
# frames_per_chunk=10,
# frame_rate=30,
# format="RGB"
# format="rgb24"
# )
#
# # Stream video from source stream `j`,
Expand All @@ -310,7 +313,7 @@
# frame_rate=30,
# width=128,
# height=128,
# format="BGR"
# format="bgr24"
# )
#

Expand Down Expand Up @@ -428,7 +431,7 @@
frame_rate=1,
width=960,
height=540,
format="RGB",
format="rgb24",
)

# Video stream with 320x320 (stretched) at 3 FPS, grayscale
Expand All @@ -437,7 +440,7 @@
frame_rate=3,
width=320,
height=320,
format="GRAY",
format="gray",
)
# fmt: on

Expand Down
116 changes: 79 additions & 37 deletions test/torchaudio_unittest/io/stream_reader_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import (
get_asset_path,
get_image,
Expand All @@ -22,12 +22,49 @@
)


def get_video_asset(file="nasa_13013.mp4"):
return get_asset_path(file)
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
_TEST_FILEOBJ = "src_is_fileobj"


def _class_name(cls, _, params):
return f'{cls.__name__}{"_fileobj" if params[_TEST_FILEOBJ] else "_path"}'


_media_source = parameterized_class((_TEST_FILEOBJ,), [(False,), (True,)], class_name_func=_class_name)


class _MediaSourceMixin:
def setUp(self):
super().setUp()
self.src = None

@property
def test_fileobj(self):
return getattr(self, _TEST_FILEOBJ)

def get_video_asset(self, file="nasa_13013.mp4"):
if self.src is not None:
raise ValueError("get_video_asset can be called only once.")

path = get_asset_path(file)
if self.test_fileobj:
self.src = open(path, "rb")
return self.src
return path

def tearDown(self):
if self.src is not None:
self.src.close()
super().tearDown()


################################################################################


@skipIfNoFFmpeg
class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
@_media_source
class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
"""Test suite for interface behaviors around StreamReader"""

def test_streamer_invalid_input(self):
Expand All @@ -48,14 +85,13 @@ def test_streamer_invalid_input(self):
def test_streamer_invalide_option(self, invalid_keys, options):
"""When invalid options are given, StreamReader raises an exception with these keys"""
options.update({k: k for k in invalid_keys})
src = get_video_asset()
with self.assertRaises(RuntimeError) as ctx:
StreamReader(src, option=options)
StreamReader(self.get_video_asset(), option=options)
assert all(f'"{k}"' in str(ctx.exception) for k in invalid_keys)

def test_src_info(self):
"""`get_src_stream_info` properly fetches information"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
assert s.num_src_streams == 6

expected = [
Expand Down Expand Up @@ -112,35 +148,35 @@ def test_src_info(self):
bit_rate=None,
),
]
for i, exp in enumerate(expected):
assert exp == s.get_src_stream_info(i)
output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output

def test_src_info_invalid_index(self):
"""`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
for i in [-1, 6, 7, 8]:
with self.assertRaises(IndexError):
with self.assertRaises(Exception):
s.get_src_stream_info(i)

def test_default_streams(self):
"""default stream is not None"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
assert s.default_audio_stream is not None
assert s.default_video_stream is not None

def test_default_audio_stream_none(self):
"""default audio stream is None for video without audio"""
s = StreamReader(get_video_asset("nasa_13013_no_audio.mp4"))
s = StreamReader(self.get_video_asset("nasa_13013_no_audio.mp4"))
assert s.default_audio_stream is None

def test_default_video_stream_none(self):
"""default video stream is None for video with only audio"""
s = StreamReader(get_video_asset("nasa_13013_no_video.mp4"))
s = StreamReader(self.get_video_asset("nasa_13013_no_video.mp4"))
assert s.default_video_stream is None

def test_num_out_stream(self):
"""num_out_streams gives the correct count of output streams"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
n, m = 6, 4
for i in range(n):
assert s.num_out_streams == i
Expand All @@ -158,10 +194,10 @@ def test_num_out_stream(self):

def test_basic_audio_stream(self):
"""`add_basic_audio_stream` constructs a correct filter."""
s = StreamReader(get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=None)
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1, format=None)
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=torch.int16)
s.add_basic_audio_stream(frames_per_chunk=-1, format="s16p")

sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_audio_stream
Expand All @@ -177,11 +213,11 @@ def test_basic_audio_stream(self):

def test_basic_video_stream(self):
"""`add_basic_video_stream` constructs a correct filter."""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_video_stream(frames_per_chunk=-1, format=None)
s.add_basic_video_stream(frames_per_chunk=-1, width=3, height=5)
s.add_basic_video_stream(frames_per_chunk=-1, frame_rate=7)
s.add_basic_video_stream(frames_per_chunk=-1, format="BGR")
s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")

sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_video_stream
Expand All @@ -201,7 +237,7 @@ def test_basic_video_stream(self):

def test_remove_streams(self):
"""`remove_stream` removes the correct output stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=24000)
s.add_basic_video_stream(frames_per_chunk=-1, width=16, height=16)
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
Expand All @@ -221,21 +257,21 @@ def test_remove_streams(self):

def test_remove_stream_invalid(self):
"""Attempt to remove invalid output streams raises IndexError"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
for i in range(-3, 3):
with self.assertRaises(IndexError):
with self.assertRaises(Exception):
s.remove_stream(i)

s.add_audio_stream(frames_per_chunk=-1)
for i in range(-3, 3):
if i == 0:
continue
with self.assertRaises(IndexError):
with self.assertRaises(Exception):
s.remove_stream(i)

def test_process_packet(self):
"""`process_packet` method returns 0 while there is a packet in source stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
# nasa_1013.mp3 contains 1023 packets.
for _ in range(1023):
code = s.process_packet()
Expand All @@ -246,19 +282,19 @@ def test_process_packet(self):

def test_pop_chunks_no_output_stream(self):
"""`pop_chunks` method returns empty list when there is no output stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
assert s.pop_chunks() == []

def test_pop_chunks_empty_buffer(self):
"""`pop_chunks` method returns None when a buffer is empty"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1)
s.add_basic_video_stream(frames_per_chunk=-1)
assert s.pop_chunks() == [None, None]

def test_pop_chunks_exhausted_stream(self):
"""`pop_chunks` method returns None when the source stream is exhausted"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
# video is 16.57 seconds.
# audio streams per 10 second chunk
# video streams per 20 second chunk
Expand All @@ -284,14 +320,14 @@ def test_pop_chunks_exhausted_stream(self):

def test_stream_empty(self):
"""`stream` fails when no output stream is configured"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
with self.assertRaises(RuntimeError):
next(s.stream())

def test_stream_smoke_test(self):
"""`stream` streams chunks fine"""
w, h = 256, 198
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=2000, sample_rate=8000)
s.add_basic_video_stream(frames_per_chunk=15, frame_rate=60, width=w, height=h)
for i, (achunk, vchunk) in enumerate(s.stream()):
Expand All @@ -302,7 +338,7 @@ def test_stream_smoke_test(self):

def test_seek(self):
"""Calling `seek` multiple times should not segfault"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
for i in range(10):
s.seek(i)
for _ in range(0):
Expand All @@ -312,8 +348,8 @@ def test_seek(self):

def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception"""
s = StreamReader(get_video_asset())
with self.assertRaises(ValueError):
s = StreamReader(self.get_video_asset())
with self.assertRaises(Exception):
s.seek(-1.0)


Expand All @@ -327,9 +363,9 @@ def _get_reference_wav(self, sample_rate, channels_first=False, **kwargs):
save_wav(path, data, sample_rate, channels_first=channels_first)
return path, data

def _test_wav(self, path, original, dtype):
def _test_wav(self, path, original, format):
s = StreamReader(path)
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=dtype)
s.add_basic_audio_stream(frames_per_chunk=-1, format=format)
s.process_all_packets()
(output,) = s.pop_chunks()
self.assertEqual(original, output)
Expand All @@ -342,10 +378,16 @@ def test_basic_audio_stream(self, dtype, num_channels):
"""`basic_audio_stream` can load WAV file properly."""
path, original = self._get_reference_wav(8000, dtype=dtype, num_channels=num_channels)

format = {
"uint8": "u8p",
"int16": "s16p",
"int32": "s32p",
}[dtype]

# provide the matching dtype
self._test_wav(path, original, getattr(torch, dtype))
self._test_wav(path, original, format=format)
# use the internal dtype ffmpeg picks
self._test_wav(path, original, None)
self._test_wav(path, original, format=None)

@nested_params(
["int16", "uint8", "int32"], # "float", "double", "int64"]
Expand Down
7 changes: 6 additions & 1 deletion tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ def get_ext_modules():
]
)
if _USE_FFMPEG:
modules.append(Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]))
modules.extend(
[
Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]),
Extension(name="torchaudio._torchaudio_ffmpeg", sources=[]),
]
)
return modules


Expand Down
Loading

0 comments on commit 752cbea

Please sign in to comment.