-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor benchmark dataset format and add big ann benchmark format #265
Changes from all commits
d0fc9d9
ac918ec
e34b444
1c854ee
28a2624
5993692
a10a08a
aea9017
77911c5
c697ef5
d9bf04b
ba90f00
14a517d
dfd09d9
2413506
4bc3b8f
15fc192
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,30 +6,24 @@ | |
|
||
"""Utility functions for parsing""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Union, cast | ||
import h5py | ||
|
||
from okpt.io.config.parsers.base import ConfigurationError | ||
from okpt.io.dataset import HDF5DataSet, BigANNNeighborDataSet, \ | ||
BigANNVectorDataSet, DataSet, Context | ||
|
||
|
||
@dataclass | ||
class Dataset: | ||
train: h5py.Dataset | ||
test: h5py.Dataset | ||
neighbors: h5py.Dataset | ||
distances: h5py.Dataset | ||
def parse_dataset(dataset_format: str, dataset_path: str, | ||
context: Context) -> DataSet: | ||
if dataset_format == 'hdf5': | ||
return HDF5DataSet(dataset_path, context) | ||
|
||
if dataset_format == 'bigann' and context == Context.NEIGHBORS: | ||
return BigANNNeighborDataSet(dataset_path) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: BigANNNeighborDataset There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I felt like that looked better as well but PyCharm was giving me spelling errors so I went with DataSet. Here is a post on it: https://english.stackexchange.com/questions/2120/which-is-correct-dataset-or-data-set. I dont really have a preference so I went with the one that wasnt giving me spelling errors. |
||
|
||
def parse_dataset(dataset_path: str, dataset_format: str) -> Union[Dataset]: | ||
if dataset_format == 'hdf5': | ||
file = h5py.File(dataset_path) | ||
return Dataset(train=cast(h5py.Dataset, file['train']), | ||
test=cast(h5py.Dataset, file['test']), | ||
neighbors=cast(h5py.Dataset, file['neighbors']), | ||
distances=cast(h5py.Dataset, file['distances'])) | ||
else: | ||
raise Exception() | ||
if dataset_format == 'bigann': | ||
return BigANNVectorDataSet(dataset_path) | ||
|
||
raise Exception("Unsupported data-set format") | ||
|
||
|
||
def parse_string_param(key: str, first_map, second_map, default) -> str: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# The OpenSearch Contributors require contributions made to | ||
# this file be licensed under the Apache-2.0 license or a | ||
# compatible open source license. | ||
|
||
"""Defines DataSet interface and implements particular formats | ||
|
||
A DataSet is the basic functionality that it can be read in chunks, or | ||
read completely and reset to the start. | ||
|
||
Currently, we support HDF5 formats from ann-benchmarks and big-ann-benchmarks | ||
datasets. | ||
|
||
Classes: | ||
HDF5DataSet: Format used in ann-benchmarks | ||
BigANNNeighborDataSet: Neighbor format for big-ann-benchmarks | ||
BigANNVectorDataSet: Vector format for big-ann-benchmarks | ||
""" | ||
import os | ||
from abc import ABC, ABCMeta, abstractmethod | ||
from enum import Enum | ||
from typing import cast | ||
import h5py | ||
import numpy as np | ||
|
||
import struct | ||
|
||
|
||
class Context(Enum): | ||
"""DataSet context enum. Can be used to add additional context for how a | ||
data-set should be interpreted. | ||
""" | ||
INDEX = 1 | ||
QUERY = 2 | ||
NEIGHBORS = 3 | ||
|
||
|
||
class DataSet(ABC): | ||
"""DataSet interface. Used for reading data-sets from files. | ||
|
||
Methods: | ||
read: Read a chunk of data from the data-set | ||
size: Gets the number of items in the data-set | ||
reset: Resets internal state of data-set to beginning | ||
""" | ||
__metaclass__ = ABCMeta | ||
|
||
@abstractmethod | ||
def read(self, chunk_size: int): | ||
pass | ||
|
||
@abstractmethod | ||
def size(self): | ||
pass | ||
|
||
@abstractmethod | ||
def reset(self): | ||
pass | ||
|
||
|
||
class HDF5DataSet(DataSet): | ||
""" Data-set format corresponding to `ANN Benchmarks | ||
<https://github.com/erikbern/ann-benchmarks#data-sets>`_ | ||
""" | ||
|
||
def __init__(self, dataset_path: str, context: Context): | ||
file = h5py.File(dataset_path) | ||
self.data = cast(h5py.Dataset, file[self._parse_context(context)]) | ||
self.current = 0 | ||
|
||
def read(self, chunk_size: int): | ||
if self.current >= self.size(): | ||
return None | ||
|
||
end_i = self.current + chunk_size | ||
if end_i > self.size(): | ||
end_i = self.size() | ||
|
||
v = cast(np.ndarray, self.data[self.current:end_i]) | ||
self.current = end_i | ||
return v | ||
|
||
def size(self): | ||
return self.data.len() | ||
|
||
def reset(self): | ||
self.current = 0 | ||
|
||
@staticmethod | ||
def _parse_context(context: Context) -> str: | ||
if context == Context.NEIGHBORS: | ||
return "neighbors" | ||
|
||
if context == Context.INDEX: | ||
return "train" | ||
|
||
if context == Context.QUERY: | ||
return "test" | ||
|
||
raise Exception("Unsupported context") | ||
|
||
|
||
class BigANNNeighborDataSet(DataSet): | ||
""" Data-set format for neighbor data-sets for `Big ANN Benchmarks | ||
<https://big-ann-benchmarks.com/index.html#bench-datasets>`_""" | ||
|
||
def __init__(self, dataset_path: str): | ||
self.file = open(dataset_path, 'rb') | ||
self.file.seek(0, os.SEEK_END) | ||
num_bytes = self.file.tell() | ||
self.file.seek(0) | ||
|
||
if num_bytes < 8: | ||
raise Exception("File is invalid") | ||
|
||
self.num_queries = int.from_bytes(self.file.read(4), "little") | ||
self.k = int.from_bytes(self.file.read(4), "little") | ||
|
||
# According to the website, the number of bytes that will follow will | ||
# be: num_queries X K x sizeof(uint32_t) bytes + num_queries X K x | ||
# sizeof(float) | ||
if (num_bytes - 8) != 2 * (self.num_queries * self.k * 4): | ||
raise Exception("File is invalid") | ||
|
||
self.current = 0 | ||
|
||
def read(self, chunk_size: int): | ||
if self.current >= self.size(): | ||
return None | ||
|
||
end_i = self.current + chunk_size | ||
if end_i > self.size(): | ||
end_i = self.size() | ||
|
||
v = [[int.from_bytes(self.file.read(4), "little") for _ in | ||
range(self.k)] for _ in range(end_i - self.current)] | ||
|
||
self.current = end_i | ||
return v | ||
|
||
def size(self): | ||
return self.num_queries | ||
|
||
def reset(self): | ||
self.file.seek(8) | ||
self.current = 0 | ||
|
||
|
||
class BigANNVectorDataSet(DataSet): | ||
""" Data-set format for vector data-sets for `Big ANN Benchmarks | ||
<https://big-ann-benchmarks.com/index.html#bench-datasets>`_ | ||
""" | ||
|
||
def __init__(self, dataset_path: str): | ||
self.file = open(dataset_path, 'rb') | ||
self.file.seek(0, os.SEEK_END) | ||
num_bytes = self.file.tell() | ||
self.file.seek(0) | ||
|
||
if num_bytes < 8: | ||
raise Exception("File is invalid") | ||
|
||
self.num_points = int.from_bytes(self.file.read(4), "little") | ||
self.dimension = int.from_bytes(self.file.read(4), "little") | ||
bytes_per_num = self._get_data_size(dataset_path) | ||
|
||
if (num_bytes - 8) != self.num_points * self.dimension * bytes_per_num: | ||
raise Exception("File is invalid") | ||
|
||
self.reader = self._value_reader(dataset_path) | ||
self.current = 0 | ||
|
||
def read(self, chunk_size: int): | ||
if self.current >= self.size(): | ||
return None | ||
|
||
end_i = self.current + chunk_size | ||
if end_i > self.size(): | ||
end_i = self.size() | ||
|
||
v = np.asarray([self._read_vector() for _ in | ||
range(end_i - self.current)]) | ||
self.current = end_i | ||
return v | ||
|
||
def _read_vector(self): | ||
return np.asarray([self.reader(self.file) for _ in | ||
range(self.dimension)]) | ||
|
||
def size(self): | ||
return self.num_points | ||
|
||
def reset(self): | ||
self.file.seek(8) # Seek to 8 bytes to skip re-reading metadata | ||
self.current = 0 | ||
|
||
@staticmethod | ||
def _get_data_size(file_name): | ||
ext = file_name.split('.')[-1] | ||
if ext == "u8bin": | ||
return 1 | ||
|
||
if ext == "fbin": | ||
return 4 | ||
|
||
raise Exception("Unknown extension") | ||
|
||
@staticmethod | ||
def _value_reader(file_name): | ||
ext = file_name.split('.')[-1] | ||
if ext == "u8bin": | ||
return lambda file: float(int.from_bytes(file.read(1), "little")) | ||
|
||
if ext == "fbin": | ||
return lambda file: struct.unpack('<f', file.read(4)) | ||
|
||
raise Exception("Unknown extension") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the behavior in case dataset is in some unsupported format or in a wrong context, do we need to validate and maybe throw some error in such case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Behavior is undefined. I will add some validation.