Skip to content

Commit

Permalink
Merge pull request #119 from neptune-ai/pg/use-threads
Browse files Browse the repository at this point in the history
feat: Use threads
  • Loading branch information
PatrykGala authored Nov 15, 2024
2 parents fd4aebc + b977bb4 commit ee97770
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 30 deletions.
23 changes: 6 additions & 17 deletions src/neptune_fetcher/api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,17 @@ def fetch_multiple_series_values(
container_id: str,
step_range: Tuple[Union[float, None], Union[float, None]] = (None, None),
) -> Iterator[(str, List[FloatPointValue])]:
max_paths_per_request: int = 100
total_step_size: int = 1_000_000
total_step_limit: int = 1_000_000

paths_len = len(paths)
if paths_len > max_paths_per_request:
results = {}
for i in range(0, paths_len, max_paths_per_request):
batch_paths = paths[i : i + max_paths_per_request]
batch_result = self.fetch_multiple_series_values(
paths=batch_paths,
include_inherited=include_inherited,
container_id=container_id,
step_range=step_range,
)
results.update(batch_result)
return results
if paths_len > total_step_limit:
raise ValueError(f"The number of paths ({paths_len}) exceeds the step limit ({total_step_limit})")

results = {path: [] for path in paths}
attribute_steps = {path: None for path in paths}

while attribute_steps:
step_size = total_step_size // len(attribute_steps)
series_step_limit = total_step_limit // len(attribute_steps)
requests = [
_SeriesRequest(
path=path,
Expand All @@ -166,14 +155,14 @@ def fetch_multiple_series_values(
values = self._fetch_series_values(
requests=requests,
step_range=step_range,
limit=step_size,
limit=series_step_limit,
)

new_attribute_steps = {}
for request, series_values in zip(requests, values):
path = request.path
results[path].extend(series_values)
if len(series_values) == step_size:
if len(series_values) == series_step_limit:
new_attribute_steps[path] = series_values[-1].step
else:
path_result = results.pop(path)
Expand Down
44 changes: 31 additions & 13 deletions src/neptune_fetcher/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import datetime
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import (
Dict,
List,
Expand Down Expand Up @@ -86,19 +89,34 @@ def prefetch_series_values(

float_series_paths = [path for path in paths if isinstance(self[path], FloatSeries)]

with tqdm(desc="Fetching metrics", total=len(float_series_paths), unit="metrics") as progress_bar:
result = self._backend.fetch_multiple_series_values(
float_series_paths,
include_inherited=include_inherited,
container_id=self._container_id,
step_range=step_range,
)

for path, points in result:
self[path].include_inherited = include_inherited
self[path].step_range = step_range
self[path].prefetched_data = list(points)
progress_bar.update()
max_workers = int(os.getenv("NEPTUNE_FETCHER_MAX_WORKERS", 10))

with tqdm(
desc="Fetching metrics", total=len(float_series_paths), unit="metrics"
) as progress_bar, ThreadPoolExecutor(max_workers) as executor:
lock = threading.Lock()
batch_size = 300

def fetch(start_index: int):
result = self._backend.fetch_multiple_series_values(
float_series_paths[start_index : start_index + batch_size],
include_inherited=include_inherited,
container_id=self._container_id,
step_range=step_range,
)

for path, points in result:
points = list(points)
with lock: # lock is inside the loop because result is a generator that fetches data lazily
self[path].include_inherited = include_inherited
self[path].step_range = step_range
self[path].prefetched_data = points
progress_bar.update()

futures = executor.map(fetch, range(0, len(float_series_paths), batch_size))

# Wait for all futures to finish
list(futures)

def __getitem__(self, path: str) -> Union[Field, FloatSeries]:
self.cache_miss(
Expand Down
159 changes: 159 additions & 0 deletions tests/unit/test_api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from datetime import (
datetime,
timezone,
)
from typing import (
List,
Tuple,
)
from unittest.mock import (
Mock,
patch,
)

from neptune_retrieval_api.proto.neptune_pb.api.model.series_values_pb2 import (
ProtoFloatPointValueDTO,
ProtoFloatSeriesValuesDTO,
ProtoFloatSeriesValuesResponseDTO,
ProtoFloatSeriesValuesSingleSeriesResponseDTO,
)
from pytest import fixture

from neptune_fetcher.api.api_client import ApiClient
from neptune_fetcher.fields import FloatPointValue


@fixture
def get_multiple_float_series_values_proto():
    """Patch the retrieval API's ``sync_detailed`` entry point and yield the mock.

    The patch stays active while the consuming test runs and is reverted when
    the fixture is torn down.
    """
    target = "neptune_retrieval_api.api.default.get_multiple_float_series_values_proto.sync_detailed"
    with patch(target) as sync_detailed_mock:
        yield sync_detailed_mock


class TestApiClient(ApiClient):
    """``ApiClient`` stand-in for unit tests that skips the real constructor."""

    # The "Test" prefix makes pytest treat this helper as a test class and warn
    # that it cannot be collected because it defines __init__; opt out explicitly.
    __test__ = False

    def __init__(self):
        # don't call super().__init__ to avoid an attempt to authenticate
        self._backend = None


def response(body: ProtoFloatSeriesValuesResponseDTO, status_code: int = 200):
    """Build a mock HTTP response whose content is *body* serialized to bytes.

    The returned mock exposes ``status_code.value`` (the given *status_code*)
    and ``content`` (the serialized protobuf message).
    """
    payload = body.SerializeToString()
    status = Mock(value=status_code)
    return Mock(status_code=status, content=payload)


def values_model(steps_values: List[Tuple[float, float]]) -> List[FloatPointValue]:
    """Return the ``FloatPointValue`` objects expected for ``(step, value)`` pairs.

    The timestamp of the i-th point is ``i`` milliseconds after the Unix epoch
    in UTC, matching the ``timestamp_millis=i`` used by ``values_dto``.
    """
    points = []
    for index, (step, value) in enumerate(steps_values):
        timestamp = datetime.fromtimestamp(index / 1000.0, tz=timezone.utc)
        points.append(FloatPointValue(timestamp=timestamp, step=step, value=value))
    return points


def values_dto(steps_values: List[Tuple[float, float]]) -> List[ProtoFloatPointValueDTO]:
    """Convert ``(step, value)`` pairs into protobuf point DTOs.

    The i-th point gets ``timestamp_millis=i``, so timestamps are simply the
    position of the point within *steps_values*.
    """
    return [
        ProtoFloatPointValueDTO(
            timestamp_millis=i,
            step=step,
            value=value,
        )
        for i, (step, value) in enumerate(steps_values)
    ]


def single_series_dto(
    steps_values: List[Tuple[float, float]], request_id: str = "0"
) -> ProtoFloatSeriesValuesSingleSeriesResponseDTO:
    """Wrap ``(step, value)`` pairs in a single-series response DTO.

    ``total_item_count`` is set to the number of pairs and ``requestId`` to
    *request_id*.
    """
    series = ProtoFloatSeriesValuesDTO(
        total_item_count=len(steps_values),
        values=values_dto(steps_values),
    )
    return ProtoFloatSeriesValuesSingleSeriesResponseDTO(requestId=request_id, series=series)


def multiple_series_dto(
    steps_values: List[List[Tuple[float, float]]],
) -> ProtoFloatSeriesValuesResponseDTO:
    """Build a multi-series response DTO, one entry per inner list of pairs.

    Each series is given its position in *steps_values* (as a string) as its
    request id.
    """
    per_series = []
    for index, series_points in enumerate(steps_values):
        per_series.append(single_series_dto(series_points, request_id=str(index)))
    return ProtoFloatSeriesValuesResponseDTO(series=per_series)


def test_fetch_multiple_series_values__single_path__returns_empty_series(get_multiple_float_series_values_proto):
    """A backend response with no points yields an empty list for the path."""
    client = TestApiClient()
    get_multiple_float_series_values_proto.return_value = response(multiple_series_dto([[]]))

    fetched = dict(
        client.fetch_multiple_series_values(
            paths=["path1"],
            include_inherited=True,
            container_id="container_id",
        )
    )

    assert fetched["path1"] == []


def test_fetch_multiple_series_values__single_path__returns_values(get_multiple_float_series_values_proto):
    """Ten points returned by the backend for one path come back as model values."""
    client = TestApiClient()
    points = [(step, step * 2) for step in range(10)]
    get_multiple_float_series_values_proto.return_value = response(multiple_series_dto([points]))

    fetched = dict(
        client.fetch_multiple_series_values(
            paths=["path1"],
            include_inherited=True,
            container_id="container_id",
        )
    )

    assert fetched["path1"] == values_model(steps_values=points)


def test_fetch_multiple_series_values__multiple_paths__returns_values(get_multiple_float_series_values_proto):
    """Each of three requested paths receives its own series from one response."""
    client = TestApiClient()
    paths = ["path1", "path2", "path3"]
    # Series i uses value = step * 10**i so the three series are distinguishable.
    values = [[(step, step * (10**i)) for step in range(10)] for i in range(len(paths))]
    get_multiple_float_series_values_proto.return_value = response(multiple_series_dto(values))

    fetched = dict(
        client.fetch_multiple_series_values(
            paths=paths,
            include_inherited=True,
            container_id="container_id",
        )
    )

    for path, series_points in zip(paths, values):
        assert fetched[path] == values_model(steps_values=series_points)


def test_fetch_multiple_series_values__single_path__returns_values_exceeding_batch(
    get_multiple_float_series_values_proto,
):
    """A series delivered over three successive backend responses is returned whole.

    The mock serves 1,000,000 + 1,000,000 + 300,000 points, one batch per call.
    """
    client = TestApiClient()
    values = [(step, step * 2) for step in range(2_300_000)]
    batch_bounds = [(0, 1_000_000), (1_000_000, 2_000_000), (2_000_000, len(values))]
    batches = [values[start:end] for start, end in batch_bounds]
    get_multiple_float_series_values_proto.side_effect = [
        response(multiple_series_dto([batch])) for batch in batches
    ]

    fetched = dict(
        client.fetch_multiple_series_values(
            paths=["path1"],
            include_inherited=True,
            container_id="container_id",
        )
    )

    assert len(fetched["path1"]) == len(values)
    # Timestamps restart at 0 within every batch, so expectations are built per batch.
    expected_values = []
    for batch in batches:
        expected_values.extend(values_model(steps_values=batch))
    for actual, expected in zip(fetched["path1"], expected_values):
        assert actual == expected

0 comments on commit ee97770

Please sign in to comment.