Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper around smart_open to cache transport_params for every call #335

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions lhotse/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from lhotse.serialization import Serializable, extension_contains
from lhotse.utils import (Decibels, NonPositiveEnergyError, Pathlike, Seconds, SetContainingAnything, asdict_nonull,
compute_num_samples,
exactly_one_not_null, fastcopy,
exactly_one_not_null, fastcopy, SmartOpen,
ifnone, index_by_id_and_check, perturb_num_samples, split_sequence)

Channels = Union[int, List[int]]
Expand Down Expand Up @@ -80,13 +80,7 @@ def load_audio(
warnings.warn('You requested a subset of a recording that is read from URL. '
'Expect large I/O overhead if you are going to read many chunks like these, '
'since every time we will download the whole file rather than its subset.')
try:
from smart_open import smart_open
except ImportError:
raise ImportError("To use 'url' audio source type, please do 'pip install smart_open' - "
"if you are using S3/GCP/Azure/other cloud-specific URIs, do "
"'pip install smart_open[s3]' (or smart_open[gcp], etc.) instead.")
with smart_open(self.source) as f:
with SmartOpen.open(self.source, 'rb') as f:
source = BytesIO(f.read())
samples, sampling_rate = read_audio(source, offset=offset, duration=duration)

Expand Down
14 changes: 3 additions & 11 deletions lhotse/features/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lilcom
import numpy as np

from lhotse.utils import Pathlike, is_module_available
from lhotse.utils import Pathlike, is_module_available, SmartOpen


class FeaturesWriter(metaclass=ABCMeta):
Expand Down Expand Up @@ -471,7 +471,6 @@ class LilcomURLReader(FeaturesReader):
Downloads Lilcom-compressed files from a URL (S3, GCP, Azure, HTTP, etc.).
``storage_path`` corresponds to the root URL (e.g. "s3://my-data-bucket")
``storage_key`` will be concatenated to ``storage_path`` to form a full URL (e.g. "my-feature-file.llc")
``transport_params`` is an optional parameter that is passed through to ``smart_open``

.. caution::
Requires ``smart_open`` to be installed (``pip install smart_open``).
Expand All @@ -481,7 +480,6 @@ class LilcomURLReader(FeaturesReader):
def __init__(
self,
storage_path: Pathlike,
transport_params: Optional[dict] = None,
*args,
**kwargs
):
Expand All @@ -490,19 +488,17 @@ def __init__(
# We are manually adding the slash to join the base URL and the key.
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.transport_params = transport_params

def read(
self,
key: str,
left_offset_frames: int = 0,
right_offset_frames: Optional[int] = None
) -> np.ndarray:
import smart_open
# We are manually adding the slash to join the base URL and the key.
if key.startswith('/'):
key = key[1:]
with smart_open.open(f'{self.base_url}/{key}', 'rb', transport_params=self.transport_params) as f:
with SmartOpen.open(f'{self.base_url}/{key}', 'rb') as f:
arr = lilcom.decompress(f.read())
return arr[left_offset_frames: right_offset_frames]

Expand All @@ -513,7 +509,6 @@ class LilcomURLWriter(FeaturesWriter):
Writes Lilcom-compressed files to a URL (S3, GCP, Azure, HTTP, etc.).
``storage_path`` corresponds to the root URL (e.g. "s3://my-data-bucket")
``storage_key`` will be concatenated to ``storage_path`` to form a full URL (e.g. "my-feature-file.llc")
``transport_params`` is an optional parameter that is passed through to ``smart_open``

.. caution::
Requires ``smart_open`` to be installed (``pip install smart_open``).
Expand All @@ -524,7 +519,6 @@ def __init__(
self,
storage_path: Pathlike,
tick_power: int = -5,
transport_params: Optional[dict] = None,
*args,
**kwargs
):
Expand All @@ -534,14 +528,12 @@ def __init__(
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.tick_power = tick_power
self.transport_params = transport_params

@property
def storage_path(self) -> str:
return self.base_url

def write(self, key: str, value: np.ndarray) -> str:
import smart_open
# We are manually adding the slash to join the base URL and the key.
if key.startswith('/'):
key = key[1:]
Expand All @@ -550,7 +542,7 @@ def write(self, key: str, value: np.ndarray) -> str:
key = key + '.llc'
output_features_url = f'{self.base_url}/{key}'
serialized_feats = lilcom.compress(value, tick_power=self.tick_power)
with smart_open.open(output_features_url, 'wb', transport_params=self.transport_params) as f:
with SmartOpen.open(output_features_url, 'wb') as f:
f.write(serialized_feats)
return key

Expand Down
64 changes: 64 additions & 0 deletions lhotse/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import random
import uuid
import logging
from contextlib import AbstractContextManager, contextmanager
from dataclasses import asdict, dataclass
from decimal import Decimal, ROUND_HALF_DOWN, ROUND_HALF_UP
Expand Down Expand Up @@ -30,6 +31,69 @@
_lhotse_uuid: Optional[Callable] = None


class SmartOpen:
oplatek marked this conversation as resolved.
Show resolved Hide resolved
"""Wrapper class around smart_open.open method

The smart_open.open attributes are cached as class attributes - they play the role of the singleton pattern.

The SmartOpen.setup method is intended for initial setup.
It imports the `open` method from the optional `smart_open` Python package,
and sets the parameters which are shared between all calls of the `smart_open.open` method.

If you do not call the setup method it is called automatically in SmartOpen.open with the provided parameters.

The example demonstrates that instantiating S3 `session.client` once,
instead of using the defaults and leaving smart_open to create it every time,
has dramatic performance benefits of 44s vs 18.9s Wall time.

import boto3
session = boto3.Session()
client = session.client('s3')
from lhotse.utils import SmartOpen

if not slow: # switch between 44s vs 18.9 Wall time
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to be nit-picking -- this 44s vs 18.9s seems very specific to your use-case, could we remove it?

Copy link
Collaborator

@pzelasko pzelasko Jul 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also -- can you format it like:

Example::

    >>> import boto3
    >>> session = boto3.Session()
    ...

This will render well with sphinx.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you pls check that it is ok now?

Or just tell me how to generate (and check) the docs myself.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The auto generated docs for this PR are here https://lhotse--335.org.readthedocs.build/en/335/

RTD used to add it as a CI test result in GH but the integration broke when I moved Lhotse to lhotse-speech organization… and I couldn’t fix it.

Unfortunately it seems that the API reference doesn’t list lhotse.utils so you won’t see it there anyway.. I haven’t gotten to re-working this part of the docs yet.

SmartOpen.setup(transport_params=dict(client=client))

# Simulating SmartOpen usage as in Lhotse datastructures: AudioSource, Features, etc.
pzelasko marked this conversation as resolved.
Show resolved Hide resolved
for i in range(1000):
with SmartOpen.open(s3_url, 'rb') as f:
source = f.read()
"""
transport_params: Optional[Dict] = None
compression: Optional[str] = None
import_err_msg = ("Please do 'pip install smart_open' - "
"if you are using S3/GCP/Azure/other cloud-specific URIs, do "
"'pip install smart_open[s3]' (or smart_open[gcp], etc.) instead.")
smart_open: Optional[Callable] = None

@classmethod
def setup(
cls,
compression: Optional[str]=None,
transport_params: Optional[dict]= None):
try:
from smart_open import open as sm_open
except ImportError:
raise ImportError(cls.import_err_msg)
if cls.transport_params is not None and cls.transport_params != transport_params:
logging.warning(f'SmartOpen.setup second call overwrites existing transport_params with new version'
f'\t\n{cls.transport_params}\t\nvs\t\n{transport_params}')
if cls.compression is not None and cls.compression != compression:
logging.warning(f'SmartOpen.setup second call overwrites existing compression param with new version'
f'\t\n{cls.compression} vs {compression}')
cls.transport_params = transport_params
cls.compression = compression
cls.smart_open = sm_open

@classmethod
def open(cls, uri, mode='rb', compression=None, transport_params=None, **kwargs):
if cls.smart_open is None:
cls.setup(compression=compression, transport_params=transport_params)
compression = compression if compression else cls.compression
transport_params = transport_params if transport_params else cls.transport_params
return cls.smart_open(uri, mode=mode, compression=compression, transport_params=transport_params, **kwargs)


def fix_random_seed(random_seed: int):
"""
Set the same random seed for the libraries and modules that Lhotse interacts with.
Expand Down