Skip to content

Commit

Permalink
feat(dataset): add streaming option to HF datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Nov 19, 2024
1 parent c613029 commit 1818745
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions src/nrtk_explorer/library/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,33 @@ def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str]):
return expanded_identifiers


# Splits reporting more examples than this are loaded with streaming=True
# instead of being downloaded in full.
HF_ROWS_MAX_TO_DOWNLOAD = 5_000
# Number of leading rows kept (via ``take``) when a split is streamed.
HF_ROWS_TO_TAKE_STREAMING = 1_000


class HuggingFaceDataset:
"""Interface to Hugging Face Dataset with the same API as JsonDataset."""

def __init__(self, identifier: str):
parts = identifier.split("@")
if len(parts) == 3:
identifier, selected_config_name, selected_split_name = parts
repo_name, selected_config_name, selected_split_name = parts
else:
raise ValueError("Identifier must be in the format 'dataset@config@split'")

self._dataset = load_dataset(identifier, selected_config_name, split=selected_split_name)
infos = get_dataset_infos(repo_name)
selected_info = infos[selected_config_name]
num_examples = selected_info.splits[selected_split_name].num_examples
self._streaming = num_examples > HF_ROWS_MAX_TO_DOWNLOAD

dataset = load_dataset(
repo_name, selected_config_name, split=selected_split_name, streaming=self._streaming
)

if self._streaming:
self._dataset = dataset.take(HF_ROWS_TO_TAKE_STREAMING)
else:
self._dataset = dataset
self._metadata = self._dataset.remove_columns(["image"])

imgs, row_idx_to_id, id_to_row_idx = self._load_images()
Expand All @@ -104,6 +120,28 @@ def __init__(self, identifier: str):

self.anns = self._load_annotations()

def _load_images(self):
    """Build the image table and the row-index/id lookup tables.

    Rows come from ``self._dataset`` when streaming, otherwise from
    ``self._metadata``. Each row's id is its ``"id"`` field, falling back
    to ``"image_id"`` and finally to the row's position.

    Returns:
        A tuple ``(images, row_idx_to_id, id_to_row_idx)`` where ``images``
        maps id -> ``{"id": id}`` and ``row_idx_to_id`` maps row position
        -> id. ``id_to_row_idx`` maps id -> row position, except in
        streaming mode where it maps id -> the row's ``"image"`` value.
    """
    streaming = self._streaming
    rows = self._dataset if streaming else self._metadata

    images, row_idx_to_id, id_to_row_idx = {}, {}, {}
    for row_idx, row in enumerate(rows):
        fallback_id = row.get("image_id", row_idx)
        image_id = row.get("id", fallback_id)

        images[image_id] = {"id": image_id}
        row_idx_to_id[row_idx] = image_id
        if streaming:
            # Streamed rows cannot be indexed later, so keep the image itself.
            id_to_row_idx[image_id] = row["image"]
        else:
            id_to_row_idx[image_id] = row_idx

    return images, row_idx_to_id, id_to_row_idx

def get_image(self, id):
    """Return the image stored for the given image id.

    When not streaming, the id is resolved to a row index and the
    ``"image"`` field of that row of ``self._dataset`` is returned. When
    streaming, ``self._id_to_row_idx`` already holds the image for each id,
    so that entry is returned directly.
    """
    if not self._streaming:
        return self._dataset[self._id_to_row_idx[id]]["image"]
    return self._id_to_row_idx[id]

def _load_categories(self):
labels = None
if "labels" in self._metadata.features:
Expand Down Expand Up @@ -202,24 +240,6 @@ def make_id():

return annotations

def _load_images(self):
    """Build the image table and the row-index/id lookup tables.

    Iterates ``self._metadata``; each row's id is its ``"id"`` field,
    falling back to ``"image_id"`` and finally to the row's position.

    Returns:
        A tuple ``(images, row_idx_to_id, id_to_row_idx)``: id ->
        ``{"id": id}``, row position -> id, and id -> row position.
    """
    images, row_idx_to_id, id_to_row_idx = {}, {}, {}
    for row_idx, row in enumerate(self._metadata):
        fallback_id = row.get("image_id", row_idx)
        image_id = row.get("id", fallback_id)

        images[image_id] = {"id": image_id}
        row_idx_to_id[row_idx] = image_id
        id_to_row_idx[image_id] = row_idx
    return images, row_idx_to_id, id_to_row_idx

def get_image(self, id):
    """Return the ``"image"`` field of the dataset row mapped to this image id."""
    return self._dataset[self._id_to_row_idx[id]]["image"]


@lru_cache
def __load_dataset(identifier: str):
Expand Down

0 comments on commit 1818745

Please sign in to comment.