Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper around smart_open to cache transport_params for every call #335

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions lhotse/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,7 @@ def load_audio(
warnings.warn('You requested a subset of a recording that is read from URL. '
'Expect large I/O overhead if you are going to read many chunks like these, '
'since every time we will download the whole file rather than its subset.')
try:
from smart_open import smart_open
except ImportError:
raise ImportError("To use 'url' audio source type, please do 'pip install smart_open' - "
"if you are using S3/GCP/Azure/other cloud-specific URIs, do "
"'pip install smart_open[s3]' (or smart_open[gcp], etc.) instead.")
with smart_open(self.source) as f:
with SmartOpen.open(self.source, 'rb') as f:
source = BytesIO(f.read())
samples, sampling_rate = read_audio(source, offset=offset, duration=duration)

Expand Down
14 changes: 3 additions & 11 deletions lhotse/features/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lilcom
import numpy as np

from lhotse.utils import Pathlike, is_module_available
from lhotse.utils import Pathlike, is_module_available, SmartOpen


class FeaturesWriter(metaclass=ABCMeta):
Expand Down Expand Up @@ -471,7 +471,6 @@ class LilcomURLReader(FeaturesReader):
Downloads Lilcom-compressed files from a URL (S3, GCP, Azure, HTTP, etc.).
``storage_path`` corresponds to the root URL (e.g. "s3://my-data-bucket")
``storage_key`` will be concatenated to ``storage_path`` to form a full URL (e.g. "my-feature-file.llc")
``transport_params`` is an optional paramater that is passed through to ``smart_open``

.. caution::
Requires ``smart_open`` to be installed (``pip install smart_open``).
Expand All @@ -481,7 +480,6 @@ class LilcomURLReader(FeaturesReader):
def __init__(
self,
storage_path: Pathlike,
transport_params: Optional[dict] = None,
*args,
**kwargs
):
Expand All @@ -490,19 +488,17 @@ def __init__(
# We are manually adding the slash to join the base URL and the key.
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.transport_params = transport_params

def read(
self,
key: str,
left_offset_frames: int = 0,
right_offset_frames: Optional[int] = None
) -> np.ndarray:
import smart_open
# We are manually adding the slash to join the base URL and the key.
if key.startswith('/'):
key = key[1:]
with smart_open.open(f'{self.base_url}/{key}', 'rb', transport_params=self.transport_params) as f:
with SmartOpen.open(f'{self.base_url}/{key}', 'rb') as f:
arr = lilcom.decompress(f.read())
return arr[left_offset_frames: right_offset_frames]

Expand All @@ -513,7 +509,6 @@ class LilcomURLWriter(FeaturesWriter):
Writes Lilcom-compressed files to a URL (S3, GCP, Azure, HTTP, etc.).
``storage_path`` corresponds to the root URL (e.g. "s3://my-data-bucket")
``storage_key`` will be concatenated to ``storage_path`` to form a full URL (e.g. "my-feature-file.llc")
``transport_params`` is an optional paramater that is passed through to ``smart_open``

.. caution::
Requires ``smart_open`` to be installed (``pip install smart_open``).
Expand All @@ -524,7 +519,6 @@ def __init__(
self,
storage_path: Pathlike,
tick_power: int = -5,
transport_params: Optional[dict] = None,
*args,
**kwargs
):
Expand All @@ -534,14 +528,12 @@ def __init__(
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.tick_power = tick_power
self.transport_params = transport_params

@property
def storage_path(self) -> str:
return self.base_url

def write(self, key: str, value: np.ndarray) -> str:
import smart_open
# We are manually adding the slash to join the base URL and the key.
if key.startswith('/'):
key = key[1:]
Expand All @@ -550,7 +542,7 @@ def write(self, key: str, value: np.ndarray) -> str:
key = key + '.llc'
output_features_url = f'{self.base_url}/{key}'
serialized_feats = lilcom.compress(value, tick_power=self.tick_power)
with smart_open.open(output_features_url, 'wb', transport_params=self.transport_params) as f:
with SmartOpen.open(output_features_url, 'wb') as f:
f.write(serialized_feats)
return key

Expand Down
42 changes: 42 additions & 0 deletions lhotse/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import random
import uuid
import logging
from contextlib import AbstractContextManager, contextmanager
from dataclasses import asdict, dataclass
from decimal import Decimal, ROUND_HALF_DOWN, ROUND_HALF_UP
Expand Down Expand Up @@ -30,6 +31,47 @@
_lhotse_uuid: Optional[Callable] = None


class SmartOpen:
oplatek marked this conversation as resolved.
Show resolved Hide resolved
"""Wrapper around smart_open.open method

The class attributes are used to cache smart_open parameters
between different calls to smart open
"""
transport_params: Optional[Dict] = None
compression: Optional[str] = None
import_err_msg = ("Please do 'pip install smart_open' - "
"if you are using S3/GCP/Azure/other cloud-specific URIs, do "
"'pip install smart_open[s3]' (or smart_open[gcp], etc.) instead.")
smart_open: Optional[Callable] = None

@classmethod
def setup(
cls,
compression: Optional[str]=None,
transport_params: Optional[dict]= None):
try:
from smart_open import open as sm_open
except ImportError:
raise ImportError(cls.import_err_msg)
if cls.transport_params is not None and cls.transport_params != transport_params:
logging.warning(f'SmartOpen.setup second call overwrites existing transport_params with new version'
f'\t\n{cls.transport_params}\t\nvs\t\n{transport_params}')
if cls.compression is not None and cls.compression != compression:
logging.warning(f'SmartOpen.setup second call overwrites existing compression param with new version'
f'\t\n{cls.compression} vs {compression}')
cls.transport_params = transport_params
cls.compression = compression
cls.smart_open = sm_open

@classmethod
def open(cls, uri, mode='rb', compression=None, transport_params=None, **kwargs):
if cls.smart_open is None:
cls.setup(compression=compression, transport_params=transport_params)
compression = compression if compression else cls.compression
transport_params = transport_params if transport_params else cls.transport_params
return cls.smart_open(uri, mode=mode, compression=compression, transport_params=transport_params, **kwargs)


def fix_random_seed(random_seed: int):
"""
Set the same random seed for the libraries and modules that Lhotse interacts with.
Expand Down