Add OpenSearch Benchmark index workload for k-NN (#364)
Initial commit for integration with the OpenSearch Benchmark code. Sets up a workload with two procedures that can be used to test indexing on an arbitrary k-NN configuration.
Signed-off-by: John Mazanec <jmazane@amazon.com>
1 parent ccedc8d · commit e5745de
Showing 21 changed files with 1,377 additions and 0 deletions.
@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.

import os
import struct
from abc import ABC, ABCMeta, abstractmethod
from enum import Enum
from typing import cast

import h5py
import numpy as np


class Context(Enum):
    """DataSet context enum. Can be used to add additional context for how a
    data-set should be interpreted.
    """
    INDEX = 1
    QUERY = 2
    NEIGHBORS = 3


class DataSet(ABC):
    """DataSet interface. Used for reading data-sets from files.

    Methods:
        read: Read a chunk of data from the data-set
        seek: Get to position in the data-set
        size: Gets the number of items in the data-set
        reset: Resets internal state of data-set to beginning
    """
    __metaclass__ = ABCMeta

    BEGINNING = 0

    @abstractmethod
    def read(self, chunk_size: int):
        pass

    @abstractmethod
    def seek(self, offset: int):
        pass

    @abstractmethod
    def size(self):
        pass

    @abstractmethod
    def reset(self):
        pass


class HDF5DataSet(DataSet):
    """ Data-set format corresponding to `ANN Benchmarks
    <https://github.com/erikbern/ann-benchmarks#data-sets>`_
    """

    FORMAT_NAME = "hdf5"

    def __init__(self, dataset_path: str, context: Context):
        # Open read-only; h5py's default mode is not consistent across versions.
        file = h5py.File(dataset_path, 'r')
        self.data = cast(h5py.Dataset, file[self._parse_context(context)])
        self.current = self.BEGINNING

    def read(self, chunk_size: int):
        if self.current >= self.size():
            return None

        end_offset = self.current + chunk_size
        if end_offset > self.size():
            end_offset = self.size()

        v = cast(np.ndarray, self.data[self.current:end_offset])
        self.current = end_offset
        return v

    def seek(self, offset: int):
        if offset < self.BEGINNING:
            raise Exception("Offset must be greater than or equal to 0")

        if offset >= self.size():
            raise Exception("Offset must be less than the data set size")

        self.current = offset

    def size(self):
        return self.data.len()

    def reset(self):
        self.current = self.BEGINNING

    @staticmethod
    def _parse_context(context: Context) -> str:
        if context == Context.NEIGHBORS:
            return "neighbors"

        if context == Context.INDEX:
            return "train"

        if context == Context.QUERY:
            return "test"

        raise Exception("Unsupported context")
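For reference, a minimal usage sketch of HDF5DataSet (the file path and chunk size below are placeholders, not part of the commit). ann-benchmarks HDF5 files store their vectors in "train", "test", and "neighbors" datasets, which is exactly what _parse_context maps the Context values onto:

    # Usage sketch only -- the path is a hypothetical local download.
    data_set = HDF5DataSet("/tmp/sift-128-euclidean.hdf5", Context.INDEX)

    chunk = data_set.read(500)      # ndarray of up to 500 train vectors
    while chunk is not None:
        print(chunk.shape)          # e.g. (500, 128) for SIFT
        chunk = data_set.read(500)

    data_set.reset()                # rewind to the beginning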
class BigANNVectorDataSet(DataSet):
    """ Data-set format for vector data-sets for `Big ANN Benchmarks
    <https://big-ann-benchmarks.com/index.html#bench-datasets>`_
    """

    DATA_SET_HEADER_LENGTH = 8
    U8BIN_EXTENSION = "u8bin"
    FBIN_EXTENSION = "fbin"
    FORMAT_NAME = "bigann"

    BYTES_PER_U8INT = 1
    BYTES_PER_FLOAT = 4

    def __init__(self, dataset_path: str):
        self.file = open(dataset_path, 'rb')
        # Determine the file size, then rewind to the header.
        self.file.seek(BigANNVectorDataSet.BEGINNING, os.SEEK_END)
        num_bytes = self.file.tell()
        self.file.seek(BigANNVectorDataSet.BEGINNING)

        if num_bytes < BigANNVectorDataSet.DATA_SET_HEADER_LENGTH:
            raise Exception("File is invalid")

        self.num_points = int.from_bytes(self.file.read(4), "little")
        self.dimension = int.from_bytes(self.file.read(4), "little")
        self.bytes_per_num = self._get_data_size(dataset_path)

        if (num_bytes - BigANNVectorDataSet.DATA_SET_HEADER_LENGTH) != self.num_points * \
                self.dimension * self.bytes_per_num:
            raise Exception("File is invalid")

        self.reader = self._value_reader(dataset_path)
        self.current = BigANNVectorDataSet.BEGINNING

    def read(self, chunk_size: int):
        if self.current >= self.size():
            return None

        end_offset = self.current + chunk_size
        if end_offset > self.size():
            end_offset = self.size()

        v = np.asarray([self._read_vector() for _ in
                        range(end_offset - self.current)])
        self.current = end_offset
        return v

    def seek(self, offset: int):
        if offset < self.BEGINNING:
            raise Exception("Offset must be greater than or equal to 0")

        if offset >= self.size():
            raise Exception("Offset must be less than the data set size")

        bytes_offset = BigANNVectorDataSet.DATA_SET_HEADER_LENGTH + \
            self.dimension * self.bytes_per_num * offset
        self.file.seek(bytes_offset)
        self.current = offset

    def _read_vector(self):
        return np.asarray([self.reader(self.file) for _ in
                           range(self.dimension)])

    def size(self):
        return self.num_points

    def reset(self):
        self.file.seek(BigANNVectorDataSet.DATA_SET_HEADER_LENGTH)
        self.current = BigANNVectorDataSet.BEGINNING

    @staticmethod
    def _get_data_size(file_name):
        ext = file_name.split('.')[-1]
        if ext == BigANNVectorDataSet.U8BIN_EXTENSION:
            return BigANNVectorDataSet.BYTES_PER_U8INT

        if ext == BigANNVectorDataSet.FBIN_EXTENSION:
            return BigANNVectorDataSet.BYTES_PER_FLOAT

        raise Exception("Unknown extension")

    @staticmethod
    def _value_reader(file_name):
        ext = file_name.split('.')[-1]
        if ext == BigANNVectorDataSet.U8BIN_EXTENSION:
            return lambda file: float(int.from_bytes(file.read(BigANNVectorDataSet.BYTES_PER_U8INT), "little"))

        if ext == BigANNVectorDataSet.FBIN_EXTENSION:
            # struct.unpack returns a tuple; take the single float it holds.
            return lambda file: struct.unpack('<f', file.read(BigANNVectorDataSet.BYTES_PER_FLOAT))[0]

        raise Exception("Unknown extension")
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.

import copy

from .data_set import Context, HDF5DataSet, DataSet, BigANNVectorDataSet
from .util import bulk_transform, parse_string_parameter, parse_int_parameter, \
    ConfigurationError


def register(registry):
    registry.register_param_source(
        "bulk-from-data-set", BulkVectorsFromDataSetParamSource
    )


class BulkVectorsFromDataSetParamSource:
    """ Parameter source that turns a vector data-set into OpenSearch bulk
    request bodies. Each call to params() yields one bulk request.
    """

    def __init__(self, workload, params, **kwargs):
        self.data_set_format = parse_string_parameter("data_set_format", params)
        self.data_set_path = parse_string_parameter("data_set_path", params)
        self.data_set: DataSet = self._read_data_set()

        self.field_name: str = parse_string_parameter("field", params)
        self.index_name: str = parse_string_parameter("index", params)
        self.bulk_size: int = parse_int_parameter("bulk_size", params)
        self.retries: int = parse_int_parameter("retries", params, 10)
        self.num_vectors: int = parse_int_parameter(
            "num_vectors", params, self.data_set.size()
        )
        self.total = self.num_vectors
        self.current = 0
        self.infinite = False
        self.percent_completed = 0
        self.offset = 0

    def _read_data_set(self):
        if self.data_set_format == HDF5DataSet.FORMAT_NAME:
            return HDF5DataSet(self.data_set_path, Context.INDEX)
        if self.data_set_format == BigANNVectorDataSet.FORMAT_NAME:
            return BigANNVectorDataSet(self.data_set_path)
        raise ConfigurationError("Invalid data set format")

    def partition(self, partition_index, total_partitions):
        if self.data_set.size() % total_partitions != 0:
            raise ValueError("Data set must be divisible by number of clients")

        partition_x = copy.copy(self)
        partition_x.num_vectors = int(self.num_vectors / total_partitions)
        partition_x.offset = int(partition_index * partition_x.num_vectors)

        # We need to create a new instance of the data set for each client
        partition_x.data_set = partition_x._read_data_set()
        partition_x.data_set.seek(partition_x.offset)
        partition_x.current = partition_x.offset
        return partition_x

    def params(self):
        # Signal to OpenSearch Benchmark that this client has no more requests.
        if self.current >= self.num_vectors + self.offset:
            raise StopIteration

        def action(doc_id):
            return {'index': {'_index': self.index_name, '_id': doc_id}}

        partition = self.data_set.read(self.bulk_size)
        body = bulk_transform(partition, self.field_name, action, self.current)
        size = len(body) // 2
        self.current += size
        self.percent_completed = self.current / self.total

        return {
            "body": body,
            "retries": self.retries,
            "size": size
        }
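In a real run, OpenSearch Benchmark constructs this class itself once a workload operation references the registered "bulk-from-data-set" name; the sketch below simply drives it directly with an assumed parameter set (all values are placeholders):

    params = {
        "data_set_format": "hdf5",
        "data_set_path": "/tmp/sift-128-euclidean.hdf5",
        "index": "target_index",
        "field": "target_field",
        "bulk_size": 500,
        # "num_vectors" and "retries" fall back to the data-set size and 10.
    }
    source = BulkVectorsFromDataSetParamSource(workload=None, params=params)
    request = source.params()   # one bulk request: body, retries, size
    print(request["size"])      # 500

A multi-client run would instead call source.partition(i, n) once per client, giving each client a disjoint, seek-aligned slice of the data set.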
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.

from .param_sources import register as param_sources_register
from .runners import register as runners_register


def register(registry):
    param_sources_register(registry)
    runners_register(registry)
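To illustrate what this hook wires together, here is a sketch with a stand-in registry (FakeRegistry and its method signatures are hypothetical; the real registry object is supplied by OpenSearch Benchmark, and the .runners module is part of this commit but not shown in the excerpt above):

    # Hypothetical stand-in for the registry OpenSearch Benchmark passes in.
    class FakeRegistry:
        def register_param_source(self, name, param_source):
            print(f"param source registered: {name}")

        def register_runner(self, name, runner, **kwargs):
            print(f"runner registered: {name}")

    register(FakeRegistry())
    # Expected to print "bulk-from-data-set" plus whatever .runners registers.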