From 18187454b9a9c19216831d9345daa55d6582e9a1 Mon Sep 17 00:00:00 2001
From: Paul Elliott
Date: Mon, 11 Nov 2024 15:13:14 -0500
Subject: [PATCH] feat(dataset): add streaming option to HF datasets

---
 src/nrtk_explorer/library/dataset.py | 60 ++++++++++++++++++----------
 1 file changed, 40 insertions(+), 20 deletions(-)

diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py
index 3d9506a..761c7e6 100644
--- a/src/nrtk_explorer/library/dataset.py
+++ b/src/nrtk_explorer/library/dataset.py
@@ -82,17 +82,33 @@ def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str]):
     return expanded_identifiers
 
 
+HF_ROWS_MAX_TO_DOWNLOAD = 5000
+HF_ROWS_TO_TAKE_STREAMING = 1000
+
+
 class HuggingFaceDataset:
     """Interface to Hugging Face Dataset with the same API as JsonDataset."""
 
     def __init__(self, identifier: str):
         parts = identifier.split("@")
         if len(parts) == 3:
-            identifier, selected_config_name, selected_split_name = parts
+            repo_name, selected_config_name, selected_split_name = parts
         else:
             raise ValueError("Identifier must be in the format 'dataset@config@split'")
 
-        self._dataset = load_dataset(identifier, selected_config_name, split=selected_split_name)
+        infos = get_dataset_infos(repo_name)
+        selected_info = infos[selected_config_name]
+        num_examples = selected_info.splits[selected_split_name].num_examples
+        self._streaming = num_examples > HF_ROWS_MAX_TO_DOWNLOAD
+
+        dataset = load_dataset(
+            repo_name, selected_config_name, split=selected_split_name, streaming=self._streaming
+        )
+
+        if self._streaming:
+            self._dataset = dataset.take(HF_ROWS_TO_TAKE_STREAMING)
+        else:
+            self._dataset = dataset
 
         self._metadata = self._dataset.remove_columns(["image"])
         imgs, row_idx_to_id, id_to_row_idx = self._load_images()
@@ -104,6 +120,28 @@ def __init__(self, identifier: str):
 
         self.anns = self._load_annotations()
 
+    def _load_images(self):
+        images = {}
+        row_idx_to_id = {}
+        id_to_row_idx = {}
+        dataset = self._dataset if self._streaming else self._metadata
+        for idx, example in enumerate(dataset):
+            id = example.get("id", example.get("image_id", idx))
+            images[id] = {
+                "id": id,
+            }
+            row_idx_to_id[idx] = id
+            id_to_row_idx[id] = example["image"] if self._streaming else idx
+
+        return images, row_idx_to_id, id_to_row_idx
+
+    def get_image(self, id):
+        """Get the image given an image id."""
+        if self._streaming:
+            return self._id_to_row_idx[id]
+        row = self._id_to_row_idx[id]
+        return self._dataset[row]["image"]
+
     def _load_categories(self):
         labels = None
         if "labels" in self._metadata.features:
@@ -202,24 +240,6 @@ def make_id():
 
         return annotations
 
-    def _load_images(self):
-        images = {}
-        row_idx_to_id = {}
-        id_to_row_idx = {}
-        for idx, example in enumerate(self._metadata):
-            id = example.get("id", example.get("image_id", idx))
-            images[id] = {
-                "id": id,
-            }
-            row_idx_to_id[idx] = id
-            id_to_row_idx[id] = idx
-        return images, row_idx_to_id, id_to_row_idx
-
-    def get_image(self, id):
-        """Get the image given an image id."""
-        row = self._id_to_row_idx[id]
-        return self._dataset[row]["image"]
-
 
 @lru_cache
 def __load_dataset(identifier: str):
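
Note (not part of the patch): a minimal usage sketch of the behavior this change introduces, under stated assumptions. The diff confirms only get_image(), anns, and the two module constants; it also presumes get_dataset_infos is imported from the datasets package alongside load_dataset, which the shown hunks do not cover. The dataset identifier and the imgs attribute below are placeholders, not names confirmed by the diff.

    from nrtk_explorer.library.dataset import HuggingFaceDataset

    # Splits with more than HF_ROWS_MAX_TO_DOWNLOAD (5000) examples are opened
    # with streaming=True and truncated to the first HF_ROWS_TO_TAKE_STREAMING
    # (1000) rows via dataset.take(); smaller splits are downloaded in full.
    ds = HuggingFaceDataset("some-org/some-dataset@default@train")  # placeholder identifier

    # In streaming mode, _load_images() caches each decoded image because
    # streamed rows cannot be revisited by index, and get_image() returns that
    # cached object; in download mode it indexes the materialized dataset by row.
    some_id = next(iter(ds.imgs))  # assumes imgs is exposed the same way as anns
    image = ds.get_image(some_id)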