
Commit: Address comments
Signed-off-by: John Mazanec <jmazane@amazon.com>
jmazanec15 committed Apr 21, 2022
1 parent 7b1c5a5 commit dec66f1
Showing 3 changed files with 41 additions and 44 deletions.
42 changes: 25 additions & 17 deletions benchmarks/osb/extensions/data_set.py
@@ -57,6 +57,8 @@ class HDF5DataSet(DataSet):
     <https://github.com/erikbern/ann-benchmarks#data-sets>`_
     """
 
+    FORMAT_NAME = "hdf5"
+
     def __init__(self, dataset_path: str, context: Context):
         file = h5py.File(dataset_path)
         self.data = cast(h5py.Dataset, file[self._parse_context(context)])
@@ -110,26 +112,32 @@ class BigANNVectorDataSet(DataSet):
     """
 
     DATA_SET_HEADER_LENGTH = 8
+    U8BIN_EXTENSION = "u8bin"
+    FBIN_EXTENSION = "fbin"
+    FORMAT_NAME = "bigann"
+
+    BYTES_PER_U8INT = 1
+    BYTES_PER_FLOAT = 4
 
     def __init__(self, dataset_path: str):
         self.file = open(dataset_path, 'rb')
-        self.file.seek(self.BEGINNING, os.SEEK_END)
+        self.file.seek(BigANNVectorDataSet.BEGINNING, os.SEEK_END)
         num_bytes = self.file.tell()
-        self.file.seek(self.BEGINNING)
+        self.file.seek(BigANNVectorDataSet.BEGINNING)
 
-        if num_bytes < self.DATA_SET_HEADER_LENGTH:
+        if num_bytes < BigANNVectorDataSet.DATA_SET_HEADER_LENGTH:
             raise Exception("File is invalid")
 
         self.num_points = int.from_bytes(self.file.read(4), "little")
         self.dimension = int.from_bytes(self.file.read(4), "little")
         self.bytes_per_num = self._get_data_size(dataset_path)
 
-        if (num_bytes - self.DATA_SET_HEADER_LENGTH) != self.num_points * \
+        if (num_bytes - BigANNVectorDataSet.DATA_SET_HEADER_LENGTH) != self.num_points * \
                 self.dimension * self.bytes_per_num:
             raise Exception("File is invalid")
 
         self.reader = self._value_reader(dataset_path)
-        self.current = self.BEGINNING
+        self.current = BigANNVectorDataSet.BEGINNING
 
     def read(self, chunk_size: int):
         if self.current >= self.size():
@@ -152,8 +160,8 @@ def seek(self, offset: int):
         if offset >= self.size():
             raise Exception("Offset must be less than the data set size")
 
-        bytes_offset = self.DATA_SET_HEADER_LENGTH + self.dimension * \
-            self.bytes_per_num * offset
+        bytes_offset = BigANNVectorDataSet.DATA_SET_HEADER_LENGTH + \
+            self.dimension * self.bytes_per_num * offset
         self.file.seek(bytes_offset)
         self.current = offset
@@ -165,27 +173,27 @@ def size(self):
         return self.num_points
 
     def reset(self):
-        self.file.seek(self.DATA_SET_HEADER_LENGTH)
-        self.current = self.BEGINNING
+        self.file.seek(BigANNVectorDataSet.DATA_SET_HEADER_LENGTH)
+        self.current = BigANNVectorDataSet.BEGINNING
 
     @staticmethod
     def _get_data_size(file_name):
         ext = file_name.split('.')[-1]
-        if ext == "u8bin":
-            return 1
+        if ext == BigANNVectorDataSet.U8BIN_EXTENSION:
+            return BigANNVectorDataSet.BYTES_PER_U8INT
 
-        if ext == "fbin":
-            return 4
+        if ext == BigANNVectorDataSet.FBIN_EXTENSION:
+            return BigANNVectorDataSet.BYTES_PER_FLOAT
 
         raise Exception("Unknown extension")
 
     @staticmethod
     def _value_reader(file_name):
         ext = file_name.split('.')[-1]
-        if ext == "u8bin":
-            return lambda file: float(int.from_bytes(file.read(1), "little"))
+        if ext == BigANNVectorDataSet.U8BIN_EXTENSION:
+            return lambda file: float(int.from_bytes(file.read(BigANNVectorDataSet.BYTES_PER_U8INT), "little"))
 
-        if ext == "fbin":
-            return lambda file: struct.unpack('<f', file.read(4))
+        if ext == BigANNVectorDataSet.FBIN_EXTENSION:
+            return lambda file: struct.unpack('<f', file.read(BigANNVectorDataSet.BYTES_PER_FLOAT))
 
         raise Exception("Unknown extension")
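
To make the BigANN binary layout concrete, here is a minimal sketch (not part of the commit) that writes a tiny .fbin file in the format the constructor above parses: an 8-byte header holding the point count and dimension as little-endian 4-byte integers, followed by the float32 vector data, BYTES_PER_FLOAT (4) bytes per value.

import struct

# Hypothetical example: write a 2-point, 3-dimensional .fbin file.
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
with open("tiny.fbin", "wb") as f:
    f.write(struct.pack('<i', len(vectors)))     # num_points, little-endian
    f.write(struct.pack('<i', len(vectors[0])))  # dimension, little-endian
    for vector in vectors:
        for value in vector:
            f.write(struct.pack('<f', value))    # one float32 per value

The resulting file is 8 + 2 * 3 * 4 = 32 bytes, so it passes the size check in the constructor, and BigANNVectorDataSet("tiny.fbin") would report size() == 2 with dimension 3.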
4 changes: 2 additions & 2 deletions benchmarks/osb/extensions/param_sources.py
@@ -36,9 +36,9 @@ def __init__(self, workload, params, **kwargs):
         self.offset = 0
 
     def _read_data_set(self):
-        if self.data_set_format == "hdf5":
+        if self.data_set_format == HDF5DataSet.FORMAT_NAME:
             return HDF5DataSet(self.data_set_path, Context.INDEX)
-        if self.data_set_format == "bigann":
+        if self.data_set_format == BigANNVectorDataSet.FORMAT_NAME:
             return BigANNVectorDataSet(self.data_set_path)
         raise ConfigurationError("Invalid data set format")

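The format strings themselves are unchanged, so existing workload parameters keep working; a hypothetical params dict (key names inferred from the data_set_format and data_set_path attributes used above, not shown in this commit) might look like:

params = {
    "data_set_format": HDF5DataSet.FORMAT_NAME,  # same string as before: "hdf5"
    "data_set_path": "/path/to/data-set.hdf5",   # hypothetical path
}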
39 changes: 14 additions & 25 deletions benchmarks/osb/extensions/runners.py
@@ -29,19 +29,13 @@ async def __call__(self, opensearch, params):
         size = parse_int_parameter("size", params)
         retries = parse_int_parameter("retries", params, 0) + 1
 
-        for _ in range(retries):
-            try:
-                await opensearch.bulk(
-                    body=params["body"],
-                    timeout='5m'
-                )
-
-                return size, "docs"
-            except:
-                pass
-
-        raise TimeoutError("Failed to submit bulk request in specified number "
-                           "of retries: {}".format(retries))
+        await opensearch.bulk(
+            body=params["body"],
+            timeout='5m',
+            max_retries=retries
+        )
+
+        return size, "docs"
 
     def __repr__(self, *args, **kwargs):
         return "custom-vector-bulk"
@@ -52,18 +46,12 @@ class CustomRefreshRunner:
     async def __call__(self, opensearch, params):
         retries = parse_int_parameter("retries", params, 0) + 1
 
-        for _ in range(retries):
-            try:
-                await opensearch.indices.refresh(
-                    index=parse_string_parameter("index", params)
-                )
-
-                return
-            except:
-                pass
-
-        raise TimeoutError("Failed to refresh the index in specified number "
-                           "of retries: {}".format(retries))
+        await opensearch.indices.refresh(
+            index=parse_string_parameter("index", params),
+            max_retries=retries
+        )
+
+        return
 
     def __repr__(self, *args, **kwargs):
         return "custom-refresh"
@@ -93,11 +81,12 @@ async def __call__(self, opensearch, params):
                 return 1, "models_trained"
 
             if model_response['state'] == 'failed':
-                raise Error("Failed to create model: {}".format(model_response))
+                raise Exception("Failed to create model: {}".format(model_response))
 
             i += 1
 
-        raise TimeoutError('Failed to create model: {}'.format(model_id))
+        raise Exception('Failed to create model: {} within timeout {} seconds'
+                        .format(model_id, timeout))
 
     def __repr__(self, *args, **kwargs):
         return "train-model"
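The train-model runner polls the model state until it is created, failed, or the timeout elapses; a minimal sketch of that polling pattern (get_model_state and the one-second poll interval are assumptions for illustration, not code from this commit):

import asyncio

async def wait_for_model(get_model_state, model_id: str, timeout: int):
    # Poll once per second until a terminal state or timeout.
    for _ in range(timeout):
        state = await get_model_state(model_id)
        if state == 'created':
            return 1, "models_trained"
        if state == 'failed':
            raise Exception("Failed to create model: {}".format(model_id))
        await asyncio.sleep(1)
    raise Exception('Failed to create model: {} within timeout {} seconds'
                    .format(model_id, timeout))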
