Add unit tests and fix failing unit tests
Signed-off-by: Finn Roblin <finnrobl@amazon.com>
finnroblin committed Aug 27, 2024
1 parent 5b21773 commit d6a45c3
Showing 4 changed files with 271 additions and 9 deletions.
4 changes: 2 additions & 2 deletions osbenchmark/worker_coordinator/runner.py
@@ -37,7 +37,7 @@
from functools import total_ordering
from io import BytesIO
from os.path import commonprefix
-from os import cpu_count as os_cpu_count
+import os
from typing import List, Optional

import ijson
@@ -1345,7 +1345,7 @@ def _get_should_calculate_recall(params: dict) -> bool:
num_clients = params.get("num_clients", 0)
if num_clients == 0:
self.logger.debug("Expected num_clients to be specified but was not.")
-cpu_count = os_cpu_count()
+cpu_count = os.cpu_count()
if cpu_count < num_clients:
self.logger.warning("Number of clients, %s, specified is greater than the number of CPUs, %s, available."\
"This will lead to unperformant context switching on load generation host. Performance "\
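The import change above is what lets the new tests take effect: with a plain `import os`, the runner resolves `os.cpu_count` at call time, so `mock.patch("os.cpu_count", ...)` in runner_test.py is visible inside `_get_should_calculate_recall`, whereas the old `from os import cpu_count as os_cpu_count` bound the function at import time and bypassed the patch. For orientation, the surrounding heuristic behaves roughly like the following minimal sketch (reconstructed from the visible context and the params_test.py changes below; the standalone function form and exact return values are assumptions, not the actual method body):

import logging
import os

logger = logging.getLogger(__name__)

def should_calculate_recall(params: dict) -> bool:
    # Honour an explicit "calculate-recall" flag first (assumed default: True).
    if not params.get("calculate-recall", True):
        return False
    # Skip recall when the load-generation host is oversubscribed (more clients
    # than CPUs), since the extra work leads to costly context switching.
    num_clients = params.get("num_clients", 0)
    if num_clients == 0:
        logger.debug("Expected num_clients to be specified but was not.")
        return True
    cpu_count = os.cpu_count()  # may be None on some platforms
    if cpu_count is not None and cpu_count < num_clients:
        logger.warning("Number of clients, %s, is greater than the number of CPUs, %s, available. "
                       "Skipping recall calculation to avoid unperformant context switching.",
                       num_clients, cpu_count)
        return False
    return True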
3 changes: 2 additions & 1 deletion osbenchmark/worker_coordinator/worker_coordinator.py
@@ -1626,7 +1626,8 @@ async def __call__(self, *args, **kwargs):
async with self.opensearch["default"].new_request_context() as request_context:
# add num_clients to the parameter so that vector search runner can skip calculating recall
# if num_clients > cpu_count().
params.update({"num_clients": self.task.clients})
if params:
params.update({"num_clients": self.task.clients})
total_ops, total_ops_unit, request_meta_data = await execute_single(runner, self.opensearch, params, self.on_error)
request_start = request_context.request_start
request_end = request_context.request_end
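The one-line guard added above accounts for operations whose parameter source yields no parameters (the assumed failure mode behind this fix): if `params` is None, calling `params.update(...)` unconditionally raises an AttributeError, so `num_clients` is only attached when there is an actual dict to extend. A minimal standalone illustration of the pattern (hypothetical helper, not OSB code):

def annotate_with_num_clients(params, clients):
    # Attach "num_clients" only when there is a real parameter dict;
    # a None or empty params object passes through untouched.
    if params:
        params.update({"num_clients": clients})
    return params

assert annotate_with_num_clients(None, 8) is None
assert annotate_with_num_clients({"k": 3}, 8) == {"k": 3, "num_clients": 8}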
258 changes: 257 additions & 1 deletion tests/worker_coordinator/runner_test.py
@@ -2578,7 +2578,6 @@ async def test_train_timeout(self, opensearch, sleep, on_client_request_start, o
with self.assertRaisesRegex(TimeoutError, f'Failed to create model: {self.model_id} within {self.retries} retries'):
await runner_under_test(opensearch, self.request)


class VectorSearchQueryRunnerTests(TestCase):
@mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_end')
@mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_start')
@@ -3279,6 +3278,263 @@ async def test_query_vector_radial_search_with_max_distance(self, opensearch, on
headers={"Accept-Encoding": "identity"}
)

@mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_end')
@mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_start')
@mock.patch("opensearchpy.OpenSearch")
@run_async
async def test_query_vector_search_should_skip_calculate_recall(self, opensearch, on_client_request_start, on_client_request_end):
with mock.patch("os.cpu_count", return_value=8):
num_clients = 9
class WorkerCoordinatorTestParamSource:
def __init__(self, workload=None, params=None, **kwargs):
if params is None:
params = {}
self._indices = workload.indices
self._params = params
self._current = 1
self._total = params.get("size")
self.infinite = self._total is None

def partition(self, partition_index, total_partitions):
return self

@property
def percent_completed(self):
if self.infinite:
return None
return self._current / self._total

def params(self):
if not self.infinite and self._current > self._total:
raise StopIteration()
self._current += 1
return self._params
# pylint: disable=C0415
from osbenchmark.worker_coordinator import worker_coordinator
# pylint: disable=C0415
from osbenchmark.workload import params
# pylint: disable=C0415
from osbenchmark import workload

            # Build a real task, sampler, and schedule so the executor drives the
            # vector search runner end to end with the patched cpu_count.

opensearch.init_request_context.return_value = {
"client_request_start": 0,
"request_start": 1,
"request_end": 11,
"client_request_end": 12
}

search_response = {
"timed_out": False,
"took": 5,
"hits": {
"total": {
"value": 3,
"relation": "eq"
},
"hits": [
{
"_id": 101,
"_score": 0.95
},
{
"_id": 102,
"_score": 0.88
},
{
"_id": 103,
"_score": 0.1
}
]
}
}
opensearch.transport.perform_request = mock.AsyncMock(return_value=io.StringIO(json.dumps(search_response)))

params.register_param_source_for_name("worker-coordinator-test-param-source", WorkerCoordinatorTestParamSource)
test_workload = workload.Workload(name="unittest", description="unittest workload",
indices=None,
test_procedures=None)

task = workload.Task("time-based", workload.Operation("time-based",
workload.OperationType.VectorSearch.to_hyphenated_string(),
params={
"index": "_all",
"type": None,
"operation-type": "vector-search",
"detailed-results": True,
"response-compression-enabled": False,
"k": 3,
"neighbors": [101, 102, 103],
"body": {
"query": {
"knn": {
"location": {
"vector": [
5,
4
],
"k": 3
}
}}
},
"request-params": {},
"cache": False
},
param_source="worker-coordinator-test-param-source"),
warmup_time_period=0.5, time_period=0.5, clients=num_clients,
params={ "clients": num_clients},
completes_parent=False)

sampler = worker_coordinator.Sampler(start_timestamp=0)

runner.register_runner(operation_type=workload.OperationType.VectorSearch, runner=runner.Query(), async_runner=True)
param_source = workload.operation_parameters(test_workload, task)
# pylint: disable=C0415
import threading
schedule = worker_coordinator.schedule_for(task, 0, param_source)
executor = worker_coordinator.AsyncExecutor(client_id=0, task=task, schedule=schedule, opensearch={"default": opensearch},
sampler=sampler, cancel=threading.Event(), complete=threading.Event(),
on_error="continue")
# will run executor + vector search query runner.
await executor()

            # Grab the samples once; they disappear after they are first accessed.
samples = sampler.samples
recall_k = samples[0].request_meta_data.get("recall@k")
self.assertEqual(recall_k, None)

@mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_end')
@mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_start')
@mock.patch("opensearchpy.OpenSearch")
@run_async
async def test_query_vector_search_should_actually_calculate_recall_with_default_value(self, opensearch,
on_client_request_start, on_client_request_end):
with mock.patch("os.cpu_count", return_value=8):
num_clients = 8
class WorkerCoordinatorTestParamSource:
def __init__(self, workload=None, params=None, **kwargs):
if params is None:
params = {}
self._indices = workload.indices
self._params = params
self._current = 1
self._total = params.get("size")
self.infinite = self._total is None

def partition(self, partition_index, total_partitions):
return self

@property
def percent_completed(self):
if self.infinite:
return None
return self._current / self._total

def params(self):
if not self.infinite and self._current > self._total:
raise StopIteration()
self._current += 1
return self._params
# pylint: disable=C0415
from osbenchmark.worker_coordinator import worker_coordinator
# pylint: disable=C0415
from osbenchmark.workload import params
# pylint: disable=C0415
from osbenchmark import workload

            # Build a real task, sampler, and schedule so the executor drives the
            # vector search runner end to end with the patched cpu_count.

opensearch.init_request_context.return_value = {
"client_request_start": 0,
"request_start": 1,
"request_end": 11,
"client_request_end": 12
}

search_response = {
"timed_out": False,
"took": 5,
"hits": {
"total": {
"value": 3,
"relation": "eq"
},
"hits": [
{
"_id": 101,
"_score": 0.95
},
{
"_id": 102,
"_score": 0.88
},
{
"_id": 103,
"_score": 0.1
}
]
}
}
opensearch.transport.perform_request = mock.AsyncMock(return_value=io.StringIO(json.dumps(search_response)))

params.register_param_source_for_name("worker-coordinator-test-param-source", WorkerCoordinatorTestParamSource)
test_workload = workload.Workload(name="unittest", description="unittest workload",
indices=None,
test_procedures=None)

task = workload.Task("time-based", workload.Operation("time-based",
workload.OperationType.VectorSearch.to_hyphenated_string(),
params={
"index": "_all",
"type": None,
"operation-type": "vector-search",
"detailed-results": True,
"response-compression-enabled": False,
"k": 3,
"neighbors": [101, 102, 103],
"body": {
"query": {
"knn": {
"location": {
"vector": [
5,
4
],
"k": 3
}
}}
},
"request-params": {},
"cache": False
},
param_source="worker-coordinator-test-param-source"),
warmup_time_period=0.5, time_period=0.5, clients=num_clients,
params={ "clients": num_clients},
completes_parent=False)

sampler = worker_coordinator.Sampler(start_timestamp=0)

runner.register_runner(operation_type=workload.OperationType.VectorSearch, runner=runner.Query(), async_runner=True)
param_source = workload.operation_parameters(test_workload, task)
# pylint: disable=C0415
import threading
schedule = worker_coordinator.schedule_for(task, 0, param_source)
executor = worker_coordinator.AsyncExecutor(client_id=0, task=task, schedule=schedule, opensearch={"default": opensearch},
sampler=sampler, cancel=threading.Event(), complete=threading.Event(),
on_error="continue")
# will run executor + vector search query runner.
await executor()

            # Grab the samples once; they disappear after they are first accessed.
samples = sampler.samples
recall_k = samples[0].request_meta_data.get("recall@k")
self.assertEqual(recall_k, 1.0)


class PutPipelineRunnerTests(TestCase):
@mock.patch('osbenchmark.client.RequestContextHolder.on_client_request_end')
15 changes: 10 additions & 5 deletions tests/workload/params_test.py
@@ -2271,7 +2271,8 @@ def test_passes_cache(self):
})
p = source.params()

-self.assertEqual(10, len(p))
+self.assertEqual(11, len(p))
+self.assertEqual(True, p["calculate-recall"])
self.assertEqual("index1", p["index"])
self.assertIsNone(p["type"])
self.assertIsNone(p["request-timeout"])
@@ -2307,7 +2308,8 @@ def test_uses_data_stream(self):
})
p = source.params()

-self.assertEqual(10, len(p))
+self.assertEqual(11, len(p))
+self.assertEqual(True, p["calculate-recall"])
self.assertEqual("data-stream-1", p["index"])
self.assertIsNone(p["type"])
self.assertEqual(1.0, p["request-timeout"])
@@ -2354,7 +2356,8 @@ def test_passes_request_parameters(self):
})
p = source.params()

-self.assertEqual(10, len(p))
+self.assertEqual(11, len(p))
+self.assertEqual(True, p["calculate-recall"])
self.assertEqual("index1", p["index"])
self.assertIsNone(p["type"])
self.assertIsNone(p["request-timeout"])
@@ -2390,7 +2393,8 @@ def test_user_specified_overrides_defaults(self):
})
p = source.params()

-self.assertEqual(10, len(p))
+self.assertEqual(11, len(p))
+self.assertEqual(True, p["calculate-recall"])
self.assertEqual("_all", p["index"])
self.assertEqual("type1", p["type"])
self.assertDictEqual({}, p["request-params"])
@@ -2423,7 +2427,8 @@ def test_user_specified_data_stream_overrides_defaults(self):
})
p = source.params()

-self.assertEqual(10, len(p))
+self.assertEqual(11, len(p))
+self.assertEqual(True, p["calculate-recall"])
self.assertEqual("data-stream-2", p["index"])
self.assertIsNone(p["type"])
self.assertEqual(1.0, p["request-timeout"])
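All five updated assertions in this file reflect the same change: the query parameter source now emits an eleventh key, `calculate-recall`, defaulting to `True`, which presumably feeds the recall decision in the vector search runner together with `num_clients`. An explicit user-specified value would be expected to win over the default, in the spirit of the *_overrides_defaults tests above; a tiny sketch of that defaulting rule (hypothetical helper, not the actual parameter source):

def resolve_calculate_recall(user_params: dict) -> bool:
    # Default to True, matching the updated assertions; an explicit value,
    # e.g. {"calculate-recall": False}, overrides it (assumed behaviour).
    return bool(user_params.get("calculate-recall", True))

assert resolve_calculate_recall({}) is True
assert resolve_calculate_recall({"calculate-recall": False}) is False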
