diff --git a/CHANGELOG.md b/CHANGELOG.md index e0c70d8575..4f4c853d16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Fix hyperlink errors in the document (, ) +- Fix unbounded memory consumption in Arrow data format export/import + () ## 15/09/2023 - Release 1.5.0 ### New features diff --git a/docs/source/docs/data-formats/formats/arrow.md b/docs/source/docs/data-formats/formats/arrow.md index a614fc2ec0..57ad6fde23 100644 --- a/docs/source/docs/data-formats/formats/arrow.md +++ b/docs/source/docs/data-formats/formats/arrow.md @@ -178,13 +178,10 @@ Extra options for exporting to Arrow format: - `JPEG/95`: [JPEG](https://en.wikipedia.org/wiki/JPEG) with 95 quality - `JPEG/75`: [JPEG](https://en.wikipedia.org/wiki/JPEG) with 75 quality - `NONE`: skip saving image. -- `--max-chunk-size MAX_CHUNK_SIZE` allow to specify maximum chunk size (batch size) when saving into arrow format. +- `--max-shard-size MAX_SHARD_SIZE` allow to specify the maximum number of dataset items per shard file when saving into arrow format. (default: `1000`) - `--num-shards NUM_SHARDS` allow to specify the number of shards to generate. - `--num-shards` and `--max-shard-size` are mutually exclusive. - (default: `1`) -- `--max-shard-size MAX_SHARD_SIZE` allow to specify maximum size of each shard. (e.g. 7KB = 7 \* 2^10, 3MB = 3 \* 2^20, and 2GB = 2 \* 2^30) - `--num-shards` and `--max-shard-size` are mutually exclusive. + `--num-shards` and `--max-shard-size` are mutually exclusive. (default: `None`) - `--num-workers NUM_WORKERS` allow to multi-processing for the export. If num_workers = 0, do not use multiprocessing (default: `0`). diff --git a/src/datumaro/components/dataset_base.py b/src/datumaro/components/dataset_base.py index 32f2105848..701706166a 100644 --- a/src/datumaro/components/dataset_base.py +++ b/src/datumaro/components/dataset_base.py @@ -178,13 +178,13 @@ def media_type(_): return _DatasetFilter() - def infos(self): + def infos(self) -> DatasetInfo: return {} - def categories(self): + def categories(self) -> CategoriesInfo: return {} - def get(self, id, subset=None): + def get(self, id, subset=None) -> Optional[DatasetItem]: subset = subset or DEFAULT_SUBSET_NAME for item in self: if item.id == id and item.subset == subset: diff --git a/src/datumaro/components/format_detection.py b/src/datumaro/components/format_detection.py index 2b3795f34e..b77ade965d 100644 --- a/src/datumaro/components/format_detection.py +++ b/src/datumaro/components/format_detection.py @@ -319,7 +319,7 @@ def _require_files_iter( @contextlib.contextmanager def probe_text_file( self, path: str, requirement_desc: str, is_binary_file: bool = False - ) -> Union[BufferedReader, TextIO]: + ) -> Iterator[Union[BufferedReader, TextIO]]: """ Returns a context manager that can be used to place a requirement on the contents of the file referred to by `path`. 
To do so, you must diff --git a/src/datumaro/plugins/data_formats/arrow/arrow_dataset.py b/src/datumaro/plugins/data_formats/arrow/arrow_dataset.py deleted file mode 100644 index 4212e12804..0000000000 --- a/src/datumaro/plugins/data_formats/arrow/arrow_dataset.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# -# SPDX-License-Identifier: MIT - -import copy -import errno -import os -from collections.abc import Iterator -from typing import Dict, List, Optional, Tuple, TypeVar, Union, overload - -import numpy as np -import pyarrow as pa - -PathLike = TypeVar("PathLike", str, os.PathLike, None) - - -# TODO: consider replacing with `bisect` library -def _binary_search(arr, x, low=None, high=None, fn=None): - low = 0 if low is None else low - high = len(arr) - 1 if high is None else high - if fn is None: - - def fn(arr, x): - return arr[x] - - while low <= high: - mid = (high + low) // 2 - val = fn(arr, mid) - if val > x: - high = mid - 1 - else: - low = mid + 1 - return high - - -class ArrowDatasetIterator(Iterator): - def __init__(self, dataset: "ArrowDataset"): - self.dataset = dataset - self.index = 0 - - def __next__(self) -> Dict: - if self.index >= len(self.dataset): - raise StopIteration - item = self.dataset[self.index] - self.index += 1 - return item - - -class ArrowDataset: - def __init__( - self, - files: Union[PathLike, List[PathLike]], - keep_in_memory: bool = False, - ): - assert files - if not isinstance(files, (list, tuple)): - files = [files] - - self._files = files - self._keep_in_memory = keep_in_memory - - self.unload_arrows() - self.load_arrows() - - @property - def table(self) -> pa.Table: - self.load_arrows() - return self.__table - - @property - def column_names(self) -> List[str]: - return self.table.column_names - - @property - def num_columns(self) -> int: - return self.table.num_columns - - @property - def num_rows(self) -> int: - return self.table.num_rows - - @property - def shape(self) -> Tuple[int]: - return self.table.shape - - @property - def metadata(self) -> Optional[Dict[bytes, bytes]]: - return self.table.schema.metadata - - def __deepcopy__(self, memo) -> "ArrowDataset": - args = [self._files, self._keep_in_memory] - copied_args = copy.deepcopy(args, memo) - return self.__class__(*copied_args) - - def __len__(self) -> int: - return self.table.num_rows - - def __iter__(self): - return ArrowDatasetIterator(self) - - def unload_arrows(self) -> None: - self.__table: pa.Table = None - self.__batches: List[pa.RecordBatch] = [] - self.__offsets: List[int] = [] - - def load_arrows(self) -> None: - if self.__table is not None: - return - - table = None - for file in self._files: - if not os.path.exists(file): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file) - - if self._keep_in_memory: - stream = pa.input_stream(file) - else: - stream = pa.memory_map(file) - opened_stream = pa.ipc.open_stream(stream) - table_ = opened_stream.read_all() - if table: - table = pa.concat_tables((table, table_)) - else: - table = table_ - self._update_table(table) - - def _update_table(self, table: pa.Table) -> None: - self.__table = table - self.__batches = self.__table.to_batches() - self.__offsets = np.cumsum([0] + [len(i) for i in self.__batches]).tolist() - - def flatten(self) -> "ArrowDataset": - dataset = ArrowDataset(self._files, self._keep_in_memory) - dataset._update_table(dataset.table.flatten()) - return dataset - - def get_batches(self, offset: int, length: int) -> List[pa.RecordBatch]: - assert length > 0 - - s_idx = 
_binary_search(self.__offsets, offset) - e_idx = _binary_search(self.__offsets, offset + length - 1) - - batches = self.__batches[s_idx : e_idx + 1] - if self.__offsets[s_idx + 1] >= offset + length: - batches[0] = batches[0].slice(offset - self.__offsets[s_idx], length) - else: - batches[0] = batches[0].slice(offset - self.__offsets[s_idx]) - batches[-1] = batches[-1].slice(0, offset + length - self.__offsets[e_idx]) - - return batches - - def slice(self, offset: int, length: int) -> List[Dict]: - batches = self.get_batches(offset, length) - items = pa.Table.from_batches(batches).to_pylist() - return items - - @overload - def __getitem__(self, key: str) -> "ArrowDataset": - ... - - @overload - def __getitem__(self, key: int) -> Dict: - ... - - @overload - def __getitem__(self, key: Union[range, slice]) -> List[Dict]: - ... - - def __getitem__( - self, key: Union[int, range, slice, str] - ) -> Union[Dict, List[Dict], "ArrowDataset"]: - if isinstance(key, str): - if key not in self.column_names: - raise KeyError(key) - dataset = ArrowDataset(self._files, self._keep_in_memory) - dataset._update_table( - self.table.drop([name for name in self.column_names if name != key]) - ) - return dataset - elif isinstance(key, slice): - key = range(*key.indices(len(self))) - - if isinstance(key, range): - if key.step == 1 and key.stop >= key.start: - return self.slice(key.start, key.stop - key.start) - return [ - self.slice(key_, 1)[0] if key_ >= 0 else self.slice(key_ + len(self), 1)[0] - for key_ in key - ] - if isinstance(key, int): - if key < 0: - key += len(self) - return self.slice(key, 1)[0] - - raise KeyError(key) - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"num_rows={len(self)}, " - f"columns={self.column_names}, " - f"keep_in_memory={self._keep_in_memory}" - ")" - ) diff --git a/src/datumaro/plugins/data_formats/arrow/base.py b/src/datumaro/plugins/data_formats/arrow/base.py index 71ae9534c2..ce2c00f4df 100644 --- a/src/datumaro/plugins/data_formats/arrow/base.py +++ b/src/datumaro/plugins/data_formats/arrow/base.py @@ -2,75 +2,169 @@ # # SPDX-License-Identifier: MIT -import os.path as osp import struct -from typing import List, Optional +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Type import pyarrow as pa -from datumaro.components.dataset_base import SubsetBase -from datumaro.components.errors import MediaTypeError +from datumaro.components.annotation import AnnotationType, Categories +from datumaro.components.dataset_base import ( + CategoriesInfo, + DatasetBase, + DatasetInfo, + DatasetItem, + IDataset, + SubsetBase, +) from datumaro.components.importer import ImportContext -from datumaro.components.media import MediaType -from datumaro.components.merge import get_merger +from datumaro.components.media import Image, MediaElement, MediaType +from datumaro.components.merge.extractor_merger import check_identicalness +from datumaro.plugins.data_formats.arrow.format import DatumaroArrow from datumaro.plugins.data_formats.datumaro.base import JsonReader from datumaro.plugins.data_formats.datumaro_binary.mapper.common import DictMapper +from datumaro.util.definitions import DEFAULT_SUBSET_NAME -from .arrow_dataset import ArrowDataset from .mapper.dataset_item import DatasetItemMapper -class ArrowBase(SubsetBase): +class ArrowSubsetBase(SubsetBase): + __not_plugin__ = True + + def __init__( + self, + lookup: Dict[str, DatasetItem], + infos: Dict[str, Any], + categories: Dict[AnnotationType, Categories], + subset: str, + 
media_type: Type[MediaElement] = Image, + ): + super().__init__(length=len(lookup), subset=subset, media_type=media_type, ctx=None) + + self._lookup = lookup + self._infos = infos + self._categories = categories + + def __iter__(self) -> Iterator[DatasetItem]: + for item in self._lookup.values(): + yield item + + def __len__(self) -> int: + return len(self._lookup) + + def get(self, item_id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: + if subset != self._subset: + return None + + try: + return self._lookup[item_id] + except KeyError: + return None + + +@dataclass(frozen=True) +class Metadata: + infos: Dict + categories: Dict + media_type: Type[MediaElement] + + +class ArrowBase(DatasetBase): def __init__( self, - path: str, - additional_paths: Optional[List[str]] = None, + root_path: str, *, - subset: Optional[str] = None, + file_paths: List[str], ctx: Optional[ImportContext] = None, ): - super().__init__(subset=subset, ctx=ctx) - - self._paths = [path] - if additional_paths: - self._paths += additional_paths - - self._load() - - def _load(self): - infos = [] - categories = [] - media_types = set() - - for path in self._paths: - with pa.ipc.open_stream(path) as reader: - schema = reader.schema - - _infos, _ = DictMapper.backward(schema.metadata.get(b"infos", b"\x00\x00\x00\x00")) - infos.append(JsonReader._load_infos({"infos": _infos})) - - _categories, _ = DictMapper.backward( - schema.metadata.get(b"categories", b"\x00\x00\x00\x00") - ) - categories.append(JsonReader._load_categories({"categories": _categories})) - - (media_type,) = struct.unpack( - " 1: - raise MediaTypeError("Datasets have different media types") - merger = get_merger("exact") - self._infos = merger.merge_infos(infos) - self._categories = merger.merge_categories(categories) - self._media_type = list(media_types)[0] + self._root_path = root_path + tables = [pa.ipc.open_file(pa.memory_map(path, "r")).read_all() for path in file_paths] + metadatas = [self._load_schema_metadata(table) for table in tables] + + table = pa.concat_tables(tables) + subsets = table.column(DatumaroArrow.SUBSET_FIELD).unique().to_pylist() + media_type = check_identicalness([metadata.media_type for metadata in metadatas]) + + super().__init__(length=len(table), subsets=subsets, media_type=media_type, ctx=ctx) + + self._infos = check_identicalness([metadata.infos for metadata in metadatas]) + self._categories = check_identicalness([metadata.categories for metadata in metadatas]) + + self._init_cache(file_paths, subsets) + + @staticmethod + def _load_schema_metadata(table: pa.Table) -> Metadata: + schema = table.schema + + _infos, _ = DictMapper.backward(schema.metadata.get(b"infos", b"\x00\x00\x00\x00")) + infos = JsonReader._load_infos({"infos": _infos}) + + _categories, _ = DictMapper.backward( + schema.metadata.get(b"categories", b"\x00\x00\x00\x00") + ) + categories = JsonReader._load_categories({"categories": _categories}) + + (media_type,) = struct.unpack(" DatasetInfo: + return self._infos + + def categories(self) -> CategoriesInfo: + return self._categories + + def __iter__(self) -> Iterator[DatasetItem]: + for lookup in self._lookup.values(): + for item in lookup.values(): + yield item + + def _init_cache(self, file_paths: List[str], subsets: List[str]): + self._lookup: Dict[str, Dict[str, DatasetItem]] = {subset: {} for subset in subsets} + + total = len(self) + cnt = 0 + pbar = self._ctx.progress_reporter + pbar.start(total=total, desc="Importing") + + for table_path in file_paths: + with pa.OSFile(table_path, "r") as 
source: + with pa.ipc.open_file(source) as reader: + table = reader.read_all() + for idx in range(len(table)): + item = DatasetItemMapper.backward(idx, table, table_path) + self._lookup[item.subset][item.id] = item + pbar.report_status(cnt) + cnt += 1 + + self._subsets = { + subset: ArrowSubsetBase( + lookup=lookup, + infos=self._infos, + categories=self._categories, + subset=self._subsets, + media_type=self._media_type, + ) + for subset, lookup in self._lookup.items() + } + + pbar.finish() + + def get(self, item_id: str, subset: Optional[str] = None) -> Optional[DatasetItem]: + subset = subset or DEFAULT_SUBSET_NAME + + try: + return self._lookup[subset][item_id] + except KeyError: + return None + + @property + def lookup(self) -> Dict[str, Dict[str, int]]: + return self._lookup + + def subsets(self) -> Dict[str, IDataset]: + return self._subsets + + def get_subset(self, name: str) -> IDataset: + return self._subsets[name] diff --git a/src/datumaro/plugins/data_formats/arrow/exporter.py b/src/datumaro/plugins/data_formats/arrow/exporter.py index d0fb1e09aa..49b71cc41e 100644 --- a/src/datumaro/plugins/data_formats/arrow/exporter.py +++ b/src/datumaro/plugins/data_formats/arrow/exporter.py @@ -2,227 +2,25 @@ # # SPDX-License-Identifier: MIT -import datetime import os -import platform -import re -import struct import tempfile -from copy import deepcopy -from functools import partial -from multiprocessing.pool import ApplyResult, Pool +from multiprocessing.pool import Pool from shutil import move, rmtree -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Iterator, Optional, Union +import numpy as np import pyarrow as pa -import pytz -from datumaro.components.crypter import NULL_CRYPTER -from datumaro.components.dataset_base import DEFAULT_SUBSET_NAME, DatasetItem, IDataset +from datumaro.components.dataset_base import DatasetItem, IDataset from datumaro.components.errors import DatumaroError -from datumaro.components.exporter import ExportContext, ExportContextComponent, Exporter -from datumaro.plugins.data_formats.datumaro.exporter import _SubsetWriter as __SubsetWriter -from datumaro.plugins.data_formats.datumaro_binary.mapper.common import DictMapper -from datumaro.util.file_utils import to_bytes +from datumaro.components.exporter import ExportContext, Exporter +from datumaro.util.multi_procs_util import consumer_generator from .format import DatumaroArrow from .mapper.dataset_item import DatasetItemMapper from .mapper.media import ImageMapper -class PathNormalizer: - NORMALIZER = {r":[^/\\]": r"꞉"} - UNNORMALIZER = {r"꞉": r":"} - - @classmethod - def normalize(cls, path: str): - for s, t in cls.NORMALIZER.items(): - path = re.sub(s, t, path) - return path - - @classmethod - def unnormalize(cls, path: str): - for s, t in cls.UNNORMALIZER.items(): - path = re.sub(s, t, path) - return path - - -class _SubsetWriter(__SubsetWriter): - def __init__( - self, - context: Exporter, - export_context: ExportContextComponent, - subset: str, - ctx: ExportContext, - max_chunk_size: int = 1000, - num_shards: int = 1, - max_shard_size: Optional[int] = None, - ): - super().__init__( - context=context, - subset=subset, - ann_file="", - export_context=export_context, - ) - self._schema = deepcopy(DatumaroArrow.SCHEMA) - self._writers = [] - self._fnames = [] - self._max_chunk_size = max_chunk_size - self._num_shards = num_shards - self._max_shard_size = max_shard_size - self._ctx = ctx - - self._data = { - "items": [], - "infos": {}, - "categories": {}, - 
"media_type": None, - "built_time": str(datetime.datetime.now(pytz.utc)), - "source_path": self.export_context.source_path, - "encoding_scheme": self.export_context._image_ext, - "version": str(DatumaroArrow.VERSION), - "signature": DatumaroArrow.SIGNATURE, - } - - def add_infos(self, infos): - if self._writers: - raise ValueError("Writer has been initialized.") - super().add_infos(infos) - self._data["infos"] = DictMapper.forward(self.infos) - - def add_categories(self, categories): - if self._writers: - raise ValueError("Writer has been initialized.") - super().add_categories(categories) - self._data["categories"] = DictMapper.forward(self.categories) - - def add_media_type(self, media_type): - if self._writers: - raise ValueError("Writer has been initialized.") - self._data["media_type"] = struct.pack(" Dict[str, Any]: - item = DatasetItemMapper.forward(item, media={"encoder": context._image_ext}) - - if item["media"].get("bytes", None) is not None: - # truncate source path since the media is embeded in arrow - path = item["media"].get("path") - if path is not None: - item["media"]["path"] = path.replace(context.source_path, "") - return item - - def _write(self, batch, max_chunk_size): - if max_chunk_size == 0: - max_chunk_size = self._max_chunk_size - pa_table = pa.Table.from_arrays(batch, schema=self._schema) - idx = getattr(self, "__writer_idx", 0) - - if self._max_shard_size is not None: - nbytes = pa_table.nbytes - cur_nbytes = getattr(self, "__writer_nbytes", 0) - free_nbytes = self._max_shard_size - cur_nbytes - if free_nbytes < nbytes: - if cur_nbytes == 0: - raise DatumaroError( - "'max_chunk_size' exceeded 'max_shard_size'. " - "Please consider increasing 'max_shard_size' or deceasing 'max_chunk_size'." - ) - self._writers[idx].close() - idx += 1 - setattr(self, "__writer_idx", idx) - setattr(self, "__writer_nbytes", nbytes) - else: - setattr(self, "__writer_nbytes", cur_nbytes + nbytes) - else: - setattr(self, "__writer_idx", (idx + 1) % self._num_shards) - - if len(self._writers) <= idx: - writer, fname = self._init_writer(idx) - self._writers.append(writer) - self._fnames.append(fname) - self._writers[idx].write_table(pa_table, max_chunk_size) - - def write(self, max_chunk_size: Optional[int] = None, pool: Optional[Pool] = None): - if max_chunk_size is None: - max_chunk_size = self._max_chunk_size - - if len(self.items) < max_chunk_size: - return - - batch = [[] for _ in self._schema.names] - for item in self._ctx.progress_reporter.iter( - self.items, desc=f"Building arrow for {self._subset}" - ): - if isinstance(item, ApplyResult): - item = item.get(timeout=DatumaroArrow.MP_TIMEOUT) - if isinstance(item, partial): - item = item() - for j, name in enumerate(self._schema.names): - batch[j].append(item[name]) - - if len(batch[0]) >= max_chunk_size: - self._write(batch, max_chunk_size) - batch = [[] for _ in self._schema.names] - if len(batch[0]) > 0: - self._write(batch, max_chunk_size) - - self._data["items"] = [] - - def done(self): - total = len(self._fnames) - width = len(str(total)) - for idx, (fname, writer) in enumerate(zip(self._fnames, self._writers)): - writer.close() - if platform.system() != "Windows": - template = PathNormalizer.unnormalize(fname) - placeholders = {"width": width, "total": total, f"idx{idx}": idx} - new_fname = template.format(**placeholders) - if os.path.exists(new_fname): - os.remove(new_fname) - os.rename(fname, new_fname) - - class ArrowExporter(Exporter): AVAILABLE_IMAGE_EXTS = ImageMapper.AVAILABLE_SCHEMES DEFAULT_IMAGE_EXT = 
ImageMapper.AVAILABLE_SCHEMES[0] @@ -247,28 +45,24 @@ def build_cmdline_parser(cls, **kwargs): ) parser.add_argument( - "--max-chunk-size", + "--max-shard-size", type=int, default=1000, - help="The maximum chunk size. (default: %(default)s)", + help="The maximum number of dataset items that can be stored in each shard file. " + "'--max-shard-size' and '--num-shards' are mutually exclusive. " + "Therefore, if '--max-shard-size' is not None, the number of shard files will be determined by " + "(# of total dataset items) / (max shard size). " + "(default: %(default)s)", ) parser.add_argument( "--num-shards", type=int, - default=1, - help="The number of shards to export. " - "'--num-shards' and '--max-shard-size' are mutually exclusive. " - "(default: %(default)s)", - ) - - parser.add_argument( - "--max-shard-size", - type=str, default=None, - help="The maximum size of each shard. " - "(e.g. 7KB = 7 * 2^10, 3MB = 3 * 2^20, and 2GB = 2 * 2^30). " - "'--num-shards' and '--max-shard-size' are mutually exclusive. " + help="The number of shards to export. " + "'--max-shard-size' and '--num-shards' are mutually exclusive. " + "Therefore, if '--num-shards' is not None, the number of dataset items in each shard file " + "will be determined by (# of total dataset items) / (num shards). " "(default: %(default)s)", ) @@ -280,32 +74,16 @@ def build_cmdline_parser(cls, **kwargs): "If num_workers = 0, do not use multiprocessing. (default: %(default)s)", ) - return parser - - def create_writer(self, subset: str, ctx: ExportContext) -> _SubsetWriter: - export_context = ExportContextComponent( - save_dir=self._save_dir, - save_media=self._save_media, - images_dir="", - pcd_dir="", - video_dir="", - crypter=NULL_CRYPTER, - image_ext=self._image_ext, - default_image_ext=self._default_image_ext, - source_path=os.path.abspath(self._extractor._source_path) - if getattr(self._extractor, "_source_path") - else None, + parser.add_argument( + "--prefix", + type=str, + default="datum", + help="Prefix to be prepended to the shard file name. " + "Therefore, the generated file name will be `-.arrow`. 
" + "(default: %(default)s)", ) - return _SubsetWriter( - context=self, - subset=subset, - export_context=export_context, - num_shards=self._num_shards, - max_shard_size=self._max_shard_size, - max_chunk_size=self._max_chunk_size, - ctx=ctx, - ) + return parser def _apply_impl(self, *args, **kwargs): if self._num_workers == 0: @@ -317,31 +95,52 @@ def _apply_impl(self, *args, **kwargs): def _apply(self, pool: Optional[Pool] = None): os.makedirs(self._save_dir, exist_ok=True) - if self._split_by_subsets: - writers = { - subset_name: self.create_writer(subset_name, self._ctx) - for subset_name, subset in self._extractor.subsets().items() - if len(subset) - } + pbar = self._ctx.progress_reporter + + if pool is not None: + + def _producer_gen(): + for item in self._extractor: + future = pool.apply_async( + func=self._item_to_dict_record, + args=(item, self._image_ext, self._source_path), + ) + yield future + + with consumer_generator(producer_generator=_producer_gen()) as consumer_gen: + + def _gen_with_pbar(): + for item in pbar.iter( + consumer_gen, desc="Exporting", total=len(self._extractor) + ): + yield item.get() + + self._write_file(_gen_with_pbar()) + else: - writers = {DEFAULT_SUBSET_NAME: self.create_writer(DEFAULT_SUBSET_NAME, self._ctx)} - for writer in writers.values(): - writer.add_infos(self._extractor.infos()) - writer.add_categories(self._extractor.categories()) - writer.add_media_type(self._extractor.media_type()._type) - writer.init_schema() + def create_consumer_gen(): + for item in pbar.iter(self._extractor, desc="Exporting"): + yield self._item_to_dict_record(item, self._image_ext, self._source_path) + + self._write_file(create_consumer_gen()) + + def _write_file(self, consumer_gen: Iterator[Dict[str, Any]]) -> None: + for file_idx, size in enumerate(self._chunk_sizes): + suffix = str(file_idx).zfill(self._max_digits) + fname = f"{self._prefix}-{suffix}.arrow" + fpath = os.path.join(self._save_dir, fname) - for subset_name, subset in self._extractor.subsets().items(): - writer = ( - writers[subset_name] if self._split_by_subsets else writers[DEFAULT_SUBSET_NAME] + record_batch = pa.RecordBatch.from_pylist( + mapping=[next(consumer_gen) for _ in range(size)], + schema=self._schema, ) - for item in subset: - writer.add_item(item, pool=pool) - for writer in writers.values(): - writer.write(0, pool=pool) - writer.done() + with pa.OSFile(fpath, "wb") as sink: + with pa.ipc.new_file(sink, self._schema) as writer: + writer.write(record_batch) + + pass @classmethod def patch(cls, dataset, patch, save_dir, **kwargs): @@ -371,9 +170,9 @@ def __init__( save_dataset_meta: bool = False, ctx: Optional[ExportContext] = None, num_workers: int = 0, - num_shards: int = 1, - max_shard_size: Optional[int] = None, - max_chunk_size: int = 1000, + max_shard_size: Optional[int] = 1000, + num_shards: Optional[int] = None, + prefix: str = "datum", **kwargs, ): super().__init__( @@ -386,30 +185,29 @@ def __init__( ctx=ctx, ) - # TODO: Support a whole single file of arrow - self._split_by_subsets = True - if num_workers < 0: raise DatumaroError( f"num_workers should be non-negative but num_workers={num_workers}." ) self._num_workers = num_workers - if num_shards != 1 and max_shard_size is not None: + if num_shards is not None and max_shard_size is not None: raise DatumaroError( - "Either one of 'num_shards' or 'max_shard_size' should be provided, not both." + "Both 'num_shards' or 'max_shard_size' cannot be provided at the same time." 
) - - if num_shards < 0: + elif num_shards is not None and num_shards < 0: raise DatumaroError(f"num_shards should be non-negative but num_shards={num_shards}.") - self._num_shards = num_shards - self._max_shard_size = to_bytes(max_shard_size) if max_shard_size else max_shard_size - - if max_chunk_size < 0: + elif max_shard_size is not None and max_shard_size < 0: + raise DatumaroError( + f"max_shard_size should be non-negative but max_shard_size={max_shard_size}." + ) + elif num_shards is None and max_shard_size is None: raise DatumaroError( - f"max_chunk_size should be non-negative but max_chunk_size={max_chunk_size}." + "Either one of 'num_shards' or 'max_shard_size' should be provided." ) - self._max_chunk_size = max_chunk_size + + self._num_shards = num_shards + self._max_shard_size = max_shard_size if self._save_media: self._image_ext = ( @@ -421,3 +219,62 @@ def __init__( assert ( self._image_ext in self.AVAILABLE_IMAGE_EXTS ), f"{self._image_ext} is unkonwn ext. Available exts are {self.AVAILABLE_IMAGE_EXTS}" + + self._prefix = prefix + + self._source_path = ( + os.path.abspath(self._extractor._source_path) + if getattr(self._extractor, "_source_path") + else None + ) + + total_len = len(self._extractor) + + if self._num_shards is not None: + max_shard_size = int(total_len / self._num_shards) + 1 + + elif self._max_shard_size is not None: + max_shard_size = self._max_shard_size + else: + raise DatumaroError( + "Either one of 'num_shards' or 'max_shard_size' should be provided." + ) + + self._chunk_sizes = np.diff( + np.array([size for size in range(0, total_len, max_shard_size)] + [total_len]) + ) + assert ( + sum(self._chunk_sizes) == total_len + ), "Sum of chunk sizes should be the number of total items." + + num_shard_files = len(self._chunk_sizes) + self._max_digits = len(str(num_shard_files)) + + self._schema = DatumaroArrow.create_schema_with_metadata(self._extractor) + + self._subsets = { + subset_name: idx for idx, subset_name in enumerate(self._extractor.subsets()) + } + + @staticmethod + def _item_to_dict_record( + item: DatasetItem, + image_ext: Optional[str] = None, + source_path: Optional[str] = None, + ) -> Dict[str, Any]: + dict_item = DatasetItemMapper.forward(item, media={"encoder": image_ext}) + + def _change_path(parent: Dict) -> Dict: + for key, child in parent.items(): + if key == "path" and child is not None: + parent["path"] = child.replace(source_path, "") + + if isinstance(child, dict): + parent[key] = _change_path(child) + + return parent + + if source_path is not None: + return _change_path(dict_item) + + return dict_item diff --git a/src/datumaro/plugins/data_formats/arrow/format.py b/src/datumaro/plugins/data_formats/arrow/format.py index 8b03ded61e..dda6221550 100644 --- a/src/datumaro/plugins/data_formats/arrow/format.py +++ b/src/datumaro/plugins/data_formats/arrow/format.py @@ -3,32 +3,54 @@ # SPDX-License-Identifier: MIT +import struct + import pyarrow as pa +from datumaro.components.dataset_base import IDataset from datumaro.errors import DatasetImportError +from datumaro.plugins.data_formats.datumaro.exporter import JsonWriter +from datumaro.plugins.data_formats.datumaro_binary.mapper.common import DictMapper class DatumaroArrow: SIGNATURE = "signature:datumaro_arrow" - VERSION = "1.0" + VERSION = "2.0" MP_TIMEOUT = 300.0 # 5 minutes + ID_FIELD = "id" + SUBSET_FIELD = "subset" + MEDIA_FIELD = "media" + + IMAGE_FIELD = pa.struct( + [ + pa.field("has_bytes", pa.bool_()), + pa.field("bytes", pa.binary()), + pa.field("path", pa.string()), + 
pa.field("size", pa.list_(pa.uint16(), 2)), + ] + ) + POINT_CLOUD_FIELD = pa.struct( + [ + pa.field("has_bytes", pa.bool_()), + pa.field("bytes", pa.binary()), + pa.field("path", pa.string()), + pa.field("extra_images", pa.list_(pa.field("image", IMAGE_FIELD))), + ] + ) + MEDIA_FIELD = pa.struct( + [ + pa.field("type", pa.uint8()), + pa.field("image", IMAGE_FIELD), + pa.field("point_cloud", POINT_CLOUD_FIELD), + ] + ) SCHEMA = pa.schema( [ - pa.field("id", pa.string()), - pa.field("subset", pa.string()), - pa.field( - "media", - pa.struct( - [ - pa.field("type", pa.uint32()), - pa.field("path", pa.string()), - pa.field("bytes", pa.binary()), - pa.field("attributes", pa.binary()), - ] - ), - ), + pa.field(ID_FIELD, pa.string()), + pa.field(SUBSET_FIELD, pa.string()), + pa.field("media", MEDIA_FIELD), pa.field("annotations", pa.binary()), pa.field("attributes", pa.binary()), ] @@ -49,3 +71,25 @@ def check_schema(cls, schema: pa.lib.Schema): f"input schema: \n{schema.to_string()}\n" f"ground truth schema: \n{cls.SCHEMA.to_string()}\n" ) + + @classmethod + def check_version(cls, version: str): + if version != cls.VERSION: + raise DatasetImportError( + f"Input version={version} is not aligned with the current data format version={cls.VERSION}" + ) + + @classmethod + def create_schema_with_metadata(cls, extractor: IDataset): + media_type = extractor.media_type()._type + categories = JsonWriter.write_categories(extractor.categories()) + + return cls.SCHEMA.with_metadata( + { + "signature": cls.SIGNATURE, + "version": cls.VERSION, + "infos": DictMapper.forward(extractor.infos()), + "categories": DictMapper.forward(categories), + "media_type": struct.pack(" Optional[FormatDetectionConfidence]: - def verify_datumaro_arrow_format(file): - with pa.ipc.open_stream(file) as reader: - schema = reader.schema - DatumaroArrow.check_signature(schema.metadata.get(b"signature", b"").decode()) - DatumaroArrow.check_schema(schema) - if context.root_path.endswith(".arrow"): - verify_datumaro_arrow_format(context.root_path) + cls._verify_datumaro_arrow_format(context.root_path) else: for arrow_file in context.require_files("*.arrow"): with context.probe_text_file( @@ -35,31 +31,43 @@ def verify_datumaro_arrow_format(file): f"{arrow_file} is not Datumaro arrow format.", is_binary_file=True, ) as f: - verify_datumaro_arrow_format(f) + f.close() + cls._verify_datumaro_arrow_format(os.path.join(context.root_path, arrow_file)) @classmethod - def find_sources(cls, path) -> List[Dict]: - sources = cls._find_sources_recursive( - path, - ".arrow", - cls.NAME, + def find_sources(cls, path: str) -> List[Dict]: + def _filter(path: str) -> bool: + try: + cls._verify_datumaro_arrow_format(path) + return True + except DatasetImportError: + return False + + return cls._find_sources_recursive( + path=path, + ext=".arrow", + extractor_name=cls.NAME, + file_filter=_filter, + max_depth=0, ) - # handle sharded arrow files - _sources = [] - for source in sources: - match = re.match(r"(.*?)(-[0-9]+-of-[0-9]+)?\.arrow", source["url"]) - found = False - prefix = match.group(1) - for _source in _sources: - if _source["url"].startswith(prefix): - _source["options"]["additional_paths"].append(source["url"]) - found = True - break - if not found: - source["options"] = { - "additional_paths": [], - "subset": os.path.basename(prefix), - } - _sources.append(source) - return _sources + @classmethod + def find_sources_with_params(cls, path: str, **extra_params) -> List[Dict]: + sources = cls.find_sources(path) + # Merge sources into one config but 
multiple file_paths + return [ + { + "url": path, + "format": cls.NAME, + "options": {"file_paths": [source["url"] for source in sources]}, + } + ] + + @staticmethod + def _verify_datumaro_arrow_format(file: str) -> None: + with pa.memory_map(file, "r") as mm_file: + with pa.ipc.open_file(mm_file) as reader: + schema = reader.schema + DatumaroArrow.check_signature(schema.metadata.get(b"signature", b"").decode()) + DatumaroArrow.check_version(schema.metadata.get(b"version", b"").decode()) + DatumaroArrow.check_schema(schema) diff --git a/src/datumaro/plugins/data_formats/arrow/mapper/dataset_item.py b/src/datumaro/plugins/data_formats/arrow/mapper/dataset_item.py index bff2df2508..87bc5a082e 100644 --- a/src/datumaro/plugins/data_formats/arrow/mapper/dataset_item.py +++ b/src/datumaro/plugins/data_formats/arrow/mapper/dataset_item.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: MIT -from typing import Any, Dict, List, Optional +from typing import Any, Dict import pyarrow as pa @@ -11,7 +11,6 @@ from datumaro.plugins.data_formats.datumaro_binary.mapper.common import DictMapper, Mapper from .media import MediaMapper -from .utils import pa_batches_decoder class DatasetItemMapper(Mapper): @@ -26,47 +25,16 @@ def forward(obj: DatasetItem, **options) -> Dict[str, Any]: } @staticmethod - def backward(obj: Dict[str, Any]) -> DatasetItem: + def backward(idx: int, table: pa.Table, table_path: str) -> DatasetItem: return DatasetItem( - id=obj["id"], - subset=obj["subset"], - media=MediaMapper.backward(obj["media"]), - annotations=AnnotationListMapper.backward(obj["annotations"])[0], - attributes=DictMapper.backward(obj["attributes"])[0], + id=table.column("id")[idx].as_py(), + subset=table.column("subset")[idx].as_py(), + media=MediaMapper.backward( + media_struct=table.column("media")[idx], + idx=idx, + table=table, + table_path=table_path, + ), + annotations=AnnotationListMapper.backward(table.column("annotations")[idx].as_py())[0], + attributes=DictMapper.backward(table.column("attributes")[idx].as_py())[0], ) - - @staticmethod - def backward_from_batches( - batches: List[pa.lib.RecordBatch], - parent: Optional[str] = None, - ) -> List[DatasetItem]: - ids = pa_batches_decoder(batches, f"{parent}.id" if parent else "id") - subsets = pa_batches_decoder(batches, f"{parent}.subset" if parent else "subset") - medias = MediaMapper.backward_from_batches( - batches, f"{parent}.media" if parent else "media" - ) - annotations_ = pa_batches_decoder( - batches, f"{parent}.annotations" if parent else "annotations" - ) - annotations_ = [ - AnnotationListMapper.backward(annotations)[0] for annotations in annotations_ - ] - attributes_ = pa_batches_decoder( - batches, f"{parent}.attributes" if parent else "attributes" - ) - attributes_ = [DictMapper.backward(attributes)[0] for attributes in attributes_] - - items = [] - for id, subset, media, annotations, attributes in zip( - ids, subsets, medias, annotations_, attributes_ - ): - items.append( - DatasetItem( - id=id, - subset=subset, - media=media, - annotations=annotations, - attributes=attributes, - ) - ) - return items diff --git a/src/datumaro/plugins/data_formats/arrow/mapper/media.py b/src/datumaro/plugins/data_formats/arrow/mapper/media.py index 01cc1201c7..a1074cde4e 100644 --- a/src/datumaro/plugins/data_formats/arrow/mapper/media.py +++ b/src/datumaro/plugins/data_formats/arrow/mapper/media.py @@ -3,21 +3,16 @@ # SPDX-License-Identifier: MIT -import os -import struct -from functools import partial -from typing import Any, Callable, Dict, List, Optional, 
Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np import pyarrow as pa from datumaro.components.errors import DatumaroError from datumaro.components.media import Image, MediaElement, MediaType, PointCloud -from datumaro.plugins.data_formats.datumaro_binary.mapper.common import DictMapper, Mapper +from datumaro.plugins.data_formats.datumaro_binary.mapper.common import Mapper from datumaro.util.image import decode_image, encode_image, load_image -from .utils import b64decode, b64encode, pa_batches_decoder - class MediaMapper(Mapper): @classmethod @@ -33,38 +28,25 @@ def forward(cls, obj: Optional[MediaElement], **options) -> Dict[str, Any]: raise DatumaroError(f"{obj._type} is not allowed for MediaMapper.") @classmethod - def backward(cls, obj: Dict[str, Any]) -> Optional[MediaElement]: - media_type = obj["type"] + def backward( + cls, + media_struct: pa.StructScalar, + idx: int, + table: pa.Table, + table_path: str, + ) -> Optional[MediaElement]: + media_type = MediaType(media_struct.get("type").as_py()) if media_type == MediaType.NONE: return None if media_type == MediaType.IMAGE: - return ImageMapper.backward(obj) + return ImageMapper.backward(media_struct, idx, table, table_path) if media_type == MediaType.POINT_CLOUD: - return PointCloudMapper.backward(obj) + return PointCloudMapper.backward(media_struct, idx, table, table_path) if media_type == MediaType.MEDIA_ELEMENT: - return MediaElementMapper.backward(obj) + return MediaElementMapper.backward(media_struct, idx, table, table_path) raise DatumaroError(f"{media_type} is not allowed for MediaMapper.") - @classmethod - def backward_from_batches( - cls, - batches: List[pa.lib.RecordBatch], - parent: Optional[str] = None, - ) -> List[Optional[MediaElement]]: - types = pa_batches_decoder(batches, f"{parent}.type" if parent else "type") - assert len(set(types)) == 1, "The types in batch are not identical." - - if types[0] == MediaType.NONE: - return [None for _ in types] - if types[0] == MediaType.IMAGE: - return ImageMapper.backward_from_batches(batches, parent) - if types[0] == MediaType.POINT_CLOUD: - return PointCloudMapper.backward_from_batches(batches, parent) - if types[0] == MediaType.MEDIA_ELEMENT: - return MediaElementMapper.backward_from_batches(batches, parent) - raise NotImplementedError - class MediaElementMapper(Mapper): MAGIC_PATH = "/NOT/A/REAL/PATH" @@ -72,32 +54,17 @@ class MediaElementMapper(Mapper): @classmethod def forward(cls, obj: MediaElement) -> Dict[str, Any]: - return { - "type": int(obj.type), - "path": getattr(obj, "path", cls.MAGIC_PATH), - } - - @classmethod - def backward_dict(cls, obj: Dict[str, Any]) -> Dict[str, Any]: - obj = obj.copy() - media_type = obj["type"] - assert media_type == cls.MEDIA_TYPE, f"Expect {cls.MEDIA_TYPE} but actual is {media_type}." 
- obj["type"] = media_type - return obj + return {"type": int(obj.type)} @classmethod - def backward(cls, obj: Dict[str, Any]) -> MediaElement: - _ = cls.backward_dict(obj) - return MediaElement() - - @classmethod - def backward_from_batches( + def backward( cls, - batches: List[pa.lib.RecordBatch], - parent: Optional[str] = None, - ) -> List[MediaElement]: - types = pa_batches_decoder(batches, f"{parent}.type" if parent else "type") - return [MediaElement() for _ in types] + media_dict: Dict[str, Any], + idx: int, + table: pa.Table, + table_path: str, + ) -> MediaElement: + return MediaElement() class ImageMapper(MediaElementMapper): @@ -149,58 +116,62 @@ def forward( ) -> Dict[str, Any]: out = super().forward(obj) - _bytes = None - if isinstance(encoder, Callable): - _bytes = encoder(obj) - else: - _bytes = cls.encode(obj, scheme=encoder) - out["bytes"] = _bytes + _bytes = encoder(obj) if isinstance(encoder, Callable) else cls.encode(obj, scheme=encoder) - out["attributes"] = DictMapper.forward(dict(size=obj.size)) + path = None if _bytes is not None else getattr(obj, "path", None) + out["image"] = { + "has_bytes": _bytes is not None, + "bytes": _bytes, + "path": path, + "size": obj.size, + } return out @classmethod - def backward(cls, obj: Dict[str, Any]) -> Image: - media_dict = cls.backward_dict(obj) - - path = media_dict["path"] - attributes, _ = DictMapper.backward(media_dict["attributes"], 0) - - _bytes = None - if media_dict.get("bytes", None) is not None: - _bytes = media_dict["bytes"] + def backward( + cls, + media_struct: pa.StructScalar, + idx: int, + table: pa.Table, + table_path: str, + ) -> Image: + image_struct = media_struct.get("image") + + if path := image_struct.get("path").as_py(): + return Image.from_file( + path=path, + size=image_struct.get("size").as_py(), + ) - if _bytes: - return Image.from_bytes(data=_bytes, size=attributes["size"]) - return Image.from_file(path=path, size=attributes["size"]) + return Image.from_bytes( + data=lambda: pa.ipc.open_file(pa.memory_map(table_path, "r")) + .read_all() + .column("media")[idx] + .get("image") + .get("bytes") + .as_py(), + size=image_struct.get("size").as_py(), + ) @classmethod - def backward_from_batches( - cls, - batches: List[pa.lib.RecordBatch], - parent: Optional[str] = None, - ) -> List[Image]: - paths = pa_batches_decoder(batches, f"{parent}.path" if parent else "path") - attributes_ = pa_batches_decoder( - batches, f"{parent}.attributes" if parent else "attributes" + def backward_extra_image( + cls, image_struct: pa.StructScalar, idx: int, table: pa.Table, extra_image_idx: int + ) -> Image: + if path := image_struct.get("path").as_py(): + return Image.from_file( + path=path, + size=image_struct.get("size").as_py(), + ) + + return Image.from_bytes( + data=lambda: table.column("media")[idx] + .get("point_cloud") + .get("extra_images")[extra_image_idx] + .get("bytes") + .as_py(), + size=image_struct.get("size").as_py(), ) - attributes_ = [DictMapper.backward(attributes)[0] for attributes in attributes_] - - images = [] - for idx, (path, attributes) in enumerate(zip(paths, attributes_)): - if os.path.exists(path): - images.append(Image.from_file(path=path, size=attributes["size"])) - else: - images.append( - Image.from_bytes( - data=lambda: pa_batches_decoder( - batches, f"{parent}.bytes" if parent else "bytes" - )[idx], - size=attributes["size"], - ) - ) - return images # TODO: share binary for extra images @@ -214,85 +185,50 @@ def forward( ) -> Dict[str, Any]: out = super().forward(obj) - _bytes = None if 
isinstance(encoder, Callable): _bytes = encoder(obj) elif encoder != "NONE": _bytes = obj.data - out["bytes"] = _bytes - - bytes_arr = bytearray() - bytes_arr.extend(struct.pack(" PointCloud: - offset = 0 - media_dict = cls.backward_dict(obj) - - path = media_dict["path"] - (len_extra_images,) = struct.unpack_from(" List[PointCloud]: - paths = pa_batches_decoder(batches, f"{parent}.path" if parent else "path") - - def extra_images(idx): - offset = 0 - attributes = pa_batches_decoder( - batches, f"{parent}.attributes" if parent else "attributes" - )[idx] - (len_extra_images,) = struct.unpack_from(" PointCloud: + point_cloud_struct = media_struct.get("point_cloud") + + extra_images = [ + ImageMapper.backward_extra_image(image_struct, idx, table, extra_image_idx) + for extra_image_idx, image_struct in enumerate(point_cloud_struct.get("extra_images")) + ] + + if path := point_cloud_struct.get("path").as_py(): + return PointCloud.from_file(path=path, extra_images=extra_images) + + return PointCloud.from_bytes( + data=pa.ipc.open_file(pa.memory_map(table_path, "r")) + .read_all() + .column("media")[idx] + .get("point_cloud") + .get("bytes") + .as_py(), + extra_images=extra_images, + ) diff --git a/src/datumaro/plugins/data_formats/arrow/mapper/utils.py b/src/datumaro/plugins/data_formats/arrow/mapper/utils.py deleted file mode 100644 index 910ac427c5..0000000000 --- a/src/datumaro/plugins/data_formats/arrow/mapper/utils.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (C) 2023 Intel Corporation -# -# SPDX-License-Identifier: MIT - -import base64 - -import pyarrow as pa - - -def pa_batches_decoder(batches, column=None): - if not isinstance(batches, list): - batches = [batches] - table = pa.Table.from_batches(batches) - if column: - data = table.column(column).to_pylist() - else: - data = table.to_pylist() - return data - - -def b64encode(obj, prefix=None): - if isinstance(obj, dict): - for k, v in obj.items(): - obj[k] = b64encode(v, prefix) - elif isinstance(obj, (list, tuple)): - _obj = [] - for v in obj: - _obj.append(b64encode(v, prefix)) - if isinstance(obj, list): - _obj = list(_obj) - obj = _obj - elif isinstance(obj, bytes): - obj = base64.b64encode(obj).decode() - if prefix: - obj = prefix + obj - return obj - - -def b64decode(obj, prefix=None): - if isinstance(obj, dict): - for k, v in obj.items(): - obj[k] = b64decode(v, prefix) - elif isinstance(obj, (list, tuple)): - _obj = [] - for v in obj: - _obj.append(b64decode(v, prefix)) - if isinstance(obj, list): - _obj = list(_obj) - obj = _obj - elif isinstance(obj, str): - if prefix and obj.startswith(prefix): - obj = obj.replace(prefix, "", 1) - obj = base64.b64decode(obj) - else: - try: - _obj = base64.b64decode(obj) - if base64.b64encode(_obj).decode() == obj: - obj = _obj - except Exception: - pass - return obj diff --git a/src/datumaro/plugins/data_formats/datumaro/exporter.py b/src/datumaro/plugins/data_formats/datumaro/exporter.py index 80cebbed19..5800c3ca7a 100644 --- a/src/datumaro/plugins/data_formats/datumaro/exporter.py +++ b/src/datumaro/plugins/data_formats/datumaro/exporter.py @@ -44,6 +44,88 @@ from .format import DATUMARO_FORMAT_VERSION, DatumaroPath +class JsonWriter: + @classmethod + def _convert_attribute_categories(cls, attributes): + return sorted(attributes) + + @classmethod + def _convert_labels_label_groups(cls, labels): + return sorted(labels) + + @classmethod + def _convert_label_categories(cls, obj): + converted = { + "labels": [], + "label_groups": [], + "attributes": 
cls._convert_attribute_categories(obj.attributes), + } + for label in obj.items: + converted["labels"].append( + { + "name": cast(label.name, str), + "parent": cast(label.parent, str), + "attributes": cls._convert_attribute_categories(label.attributes), + } + ) + for label_group in obj.label_groups: + converted["label_groups"].append( + { + "name": cast(label_group.name, str), + "group_type": label_group.group_type.to_str(), + "labels": cls._convert_labels_label_groups(label_group.labels), + } + ) + return converted + + @classmethod + def _convert_mask_categories(cls, obj): + converted = { + "colormap": [], + } + for label_id, color in obj.colormap.items(): + converted["colormap"].append( + { + "label_id": int(label_id), + "r": int(color[0]), + "g": int(color[1]), + "b": int(color[2]), + } + ) + return converted + + @classmethod + def _convert_points_categories(cls, obj): + converted = { + "items": [], + } + for label_id, item in obj.items.items(): + converted["items"].append( + { + "label_id": int(label_id), + "labels": [cast(label, str) for label in item.labels], + "joints": [list(map(int, j)) for j in item.joints], + } + ) + return converted + + @classmethod + def write_categories(cls, categories) -> Dict[str, Dict]: + dict_cat = {} + for ann_type, desc in categories.items(): + if isinstance(desc, LabelCategories): + converted_desc = cls._convert_label_categories(desc) + elif isinstance(desc, MaskCategories): + converted_desc = cls._convert_mask_categories(desc) + elif isinstance(desc, PointsCategories): + converted_desc = cls._convert_points_categories(desc) + else: + raise NotImplementedError() + dict_cat[ann_type.name] = converted_desc + + return dict_cat + + class _SubsetWriter: def __init__( self, @@ -216,16 +298,7 @@ def add_infos(self, infos): self._data["infos"].update(infos) def add_categories(self, categories): - for ann_type, desc in categories.items(): - if isinstance(desc, LabelCategories): - converted_desc = self._convert_label_categories(desc) - elif isinstance(desc, MaskCategories): - converted_desc = self._convert_mask_categories(desc) - elif isinstance(desc, PointsCategories): - converted_desc = self._convert_points_categories(desc) - else: - raise NotImplementedError() - self.categories[ann_type.name] = converted_desc + self._data["categories"] = JsonWriter.write_categories(categories) def write(self, *args, **kwargs): dump_json_file(self.ann_file, self._data) @@ -339,65 +412,6 @@ def _convert_cuboid_3d_object(self, obj): def _convert_ellipse_object(self, obj: Ellipse): return self._convert_shape_object(obj) - def _convert_attribute_categories(self, attributes): - return sorted(attributes) - - def _convert_labels_label_groups(self, labels): - return sorted(labels) - - def _convert_label_categories(self, obj): - converted = { - "labels": [], - "label_groups": [], - "attributes": self._convert_attribute_categories(obj.attributes), - } - for label in obj.items: - converted["labels"].append( - { - "name": cast(label.name, str), - "parent": cast(label.parent, str), - "attributes": self._convert_attribute_categories(label.attributes), - } - ) - for label_group in obj.label_groups: - converted["label_groups"].append( - { - "name": cast(label_group.name, str), - "group_type": label_group.group_type.to_str(), - "labels": self._convert_labels_label_groups(label_group.labels), - } - ) - return converted - - def _convert_mask_categories(self, obj): - converted = { - "colormap": [], - } - for label_id, color in obj.colormap.items(): - converted["colormap"].append( - { - 
"label_id": int(label_id), - "r": int(color[0]), - "g": int(color[1]), - "b": int(color[2]), - } - ) - return converted - - def _convert_points_categories(self, obj): - converted = { - "items": [], - } - for label_id, item in obj.items.items(): - converted["items"].append( - { - "label_id": int(label_id), - "labels": [cast(label, str) for label in item.labels], - "joints": [list(map(int, j)) for j in item.joints], - } - ) - return converted - class _StreamSubsetWriter(_SubsetWriter): def __init__( diff --git a/tests/unit/data_formats/arrow/conftest.py b/tests/unit/data_formats/arrow/conftest.py index 69488615b0..d3fdec99bc 100644 --- a/tests/unit/data_formats/arrow/conftest.py +++ b/tests/unit/data_formats/arrow/conftest.py @@ -12,8 +12,6 @@ from datumaro.components.dataset_base import DatasetItem from datumaro.components.media import Image, PointCloud from datumaro.components.project import Dataset -from datumaro.plugins.data_formats.arrow import ArrowExporter -from datumaro.plugins.data_formats.arrow.arrow_dataset import ArrowDataset from datumaro.util.image import encode_image from ..datumaro.conftest import ( @@ -138,13 +136,3 @@ def fxt_point_cloud(test_dir, n=1000): ) yield source_dataset - - -@pytest.fixture -def fxt_arrow_dataset(fxt_image, test_dir): - ArrowExporter.convert(fxt_image, save_dir=test_dir, save_media=True) - files = [ - os.path.join(test_dir, file) for file in os.listdir(test_dir) if file.endswith(".arrow") - ] - dataset = ArrowDataset(files) - yield dataset diff --git a/tests/unit/data_formats/arrow/test_arrow_format.py b/tests/unit/data_formats/arrow/test_arrow_format.py index 458f64178f..e5a0744080 100644 --- a/tests/unit/data_formats/arrow/test_arrow_format.py +++ b/tests/unit/data_formats/arrow/test_arrow_format.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MIT -import os from functools import partial import numpy as np @@ -11,11 +10,9 @@ from datumaro.components.dataset_base import DatasetItem from datumaro.components.environment import Environment -from datumaro.components.importer import DatasetImportError from datumaro.components.media import FromFileMixin, Image from datumaro.components.project import Dataset from datumaro.plugins.data_formats.arrow import ArrowExporter, ArrowImporter -from datumaro.plugins.data_formats.arrow.arrow_dataset import ArrowDataset from datumaro.plugins.transforms import Sort from ....requirements import Requirements, mark_requirement @@ -23,99 +20,6 @@ from tests.utils.test_utils import check_save_and_load, compare_datasets, compare_datasets_strict -class ArrowDatasetTest: - @mark_requirement(Requirements.DATUM_GENERAL_REQ) - @pytest.mark.parametrize( - ["fxt_dataset"], - [ - pytest.param("fxt_arrow_dataset", id="test_arrow_format_with_keep_in_memory"), - ], - ) - def test_arrow_dataset_getitem(self, helper_tc, fxt_dataset, request): - fxt_dataset = request.getfixturevalue(fxt_dataset) - - helper_tc.assertTrue( - len(fxt_dataset.column_names) < len(fxt_dataset.flatten().column_names) - ) - - # column name - for column_name in fxt_dataset.column_names: - _dataset = fxt_dataset[column_name] - helper_tc.assertEqual(len(_dataset.column_names), 1) - helper_tc.assertEqual(column_name, _dataset.column_names[0]) - - with pytest.raises(KeyError): - fxt_dataset["invalid column name"] - - length = len(fxt_dataset) - - # positive integer - for i in range(length): - item = fxt_dataset[i] - helper_tc.assertTrue(isinstance(item, dict)) - helper_tc.assertEqual(item["id"], str(i)) - - # negative integer - for i in range(length): - item = 
fxt_dataset[i - length] - helper_tc.assertTrue(isinstance(item, dict)) - helper_tc.assertEqual(item["id"], str(i)) - - # positive slice - items = fxt_dataset[0:100] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual([i["id"] for i in items], [str(i) for i in range(0, 100)]) - - # positive range - items = fxt_dataset[range(0, 100)] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual([i["id"] for i in items], [str(i) for i in range(0, 100)]) - - # negative slice - items = fxt_dataset[-100:-1] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual( - [i["id"] for i in items], [str(i) for i in range(length - 100, length - 1)] - ) - - # negative range - items = fxt_dataset[range(length - 100, length - 1)] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual( - [i["id"] for i in items], [str(i) for i in range(length - 100, length - 1)] - ) - - # positive slice with interval - items = fxt_dataset[0:100:3] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual([i["id"] for i in items], [str(i) for i in range(0, 100, 3)]) - - # positive range with interval - items = fxt_dataset[range(0, 100, 3)] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual([i["id"] for i in items], [str(i) for i in range(0, 100, 3)]) - - # negative slice with interval - items = fxt_dataset[-100:-1:3] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual( - [i["id"] for i in items], [str(i) for i in range(length - 100, length - 1, 3)] - ) - - # negative range with interval - items = fxt_dataset[range(length - 100, length - 1, 3)] - helper_tc.assertTrue(isinstance(items, list)) - helper_tc.assertEqual( - [i["id"] for i in items], [str(i) for i in range(length - 100, length - 1, 3)] - ) - - with pytest.raises(KeyError): - fxt_dataset[0.1] - - with pytest.raises(IndexError): - fxt_dataset[length] - - class ArrowFormatTest: exporter = ArrowExporter importer = ArrowImporter @@ -269,7 +173,7 @@ def _test_save_and_load( compare_datasets_strict, True, True, - {"max_chunk_size": 20, "num_shards": 5}, + {"max_shard_size": None, "num_shards": 5}, lambda dataset: Sort(dataset, lambda item: int(item.id)), id="test_can_save_and_load_image_with_num_shards", ), @@ -278,7 +182,7 @@ def _test_save_and_load( compare_datasets_strict, True, True, - {"max_chunk_size": 20, "max_shard_size": "1M"}, + {"max_shard_size": 20, "num_shards": None}, lambda dataset: Sort(dataset, lambda item: int(item.id)), id="test_can_save_and_load_image_with_max_size", ), @@ -296,7 +200,7 @@ def _test_save_and_load( compare_datasets_strict, True, True, - {"max_chunk_size": 20, "num_shards": 5}, + {"max_shard_size": None, "num_shards": 5}, lambda dataset: Sort(dataset, lambda item: int(item.id)), id="test_can_save_and_load_point_cloud_with_num_shards", ), @@ -305,7 +209,7 @@ def _test_save_and_load( compare_datasets_strict, True, True, - {"max_chunk_size": 20, "max_shard_size": "1M"}, + {"max_shard_size": 20, "num_shards": None}, lambda dataset: Sort(dataset, lambda item: int(item.id)), id="test_can_save_and_load_point_cloud_with_max_size", ),
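
For reference, a minimal usage sketch of the sharded Arrow export/import reworked by this patch, assuming the standard `Dataset.import_from`/`Dataset.export` entry points and the option names introduced above (`max_shard_size`, `num_shards`); the source format name "coco" and all paths are placeholders, not part of the change:

from datumaro.components.project import Dataset

# Load any dataset (path and source format are placeholders).
dataset = Dataset.import_from("path/to/dataset", "coco")

# Cap the number of items per shard file; shards are written as
# datum-0.arrow, datum-1.arrow, ... under the save directory.
dataset.export("out/arrow_by_size", "arrow", save_media=True, max_shard_size=1000)

# Alternatively, fix the number of shard files. The two options are mutually
# exclusive, so max_shard_size must be disabled explicitly (as in the tests above).
dataset.export("out/arrow_by_count", "arrow", save_media=True, num_shards=5, max_shard_size=None)

# Re-import: the importer merges all *.arrow shards in the directory into one dataset.
restored = Dataset.import_from("out/arrow_by_count", "arrow")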