Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Michał Sośnicki committed Nov 13, 2024
1 parent ed4fe80 commit 77c7f6e
Showing 1 changed file with 11 additions and 19 deletions.
30 changes: 11 additions & 19 deletions src/neptune_fetcher/api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Tuple,
Union,
)
from tqdm import tqdm

from neptune_api import (
AuthenticatedClient,
Expand Down Expand Up @@ -51,10 +50,10 @@
FloatTimeSeriesValuesRequest,
FloatTimeSeriesValuesRequestOrder,
FloatTimeSeriesValuesRequestSeries,
OpenRangeDTO,
QueryAttributeDefinitionsBodyDTO,
QueryAttributeDefinitionsResultDTO,
QueryAttributesBodyDTO,
OpenRangeDTO,
SearchLeaderboardEntriesParamsDTO,
TimeSeries,
TimeSeriesLineage,
Expand Down Expand Up @@ -126,11 +125,11 @@ def fetch_series_values(
last_step_value = batch[-1].step if batch else None

def fetch_multiple_series_values(
self,
paths: List[str],
include_inherited: bool,
container_id: str,
step_range: Tuple[Union[float, None], Union[float, None]] = (None, None),
self,
paths: List[str],
include_inherited: bool,
container_id: str,
step_range: Tuple[Union[float, None], Union[float, None]] = (None, None),
) -> Iterator[(str, List[FloatPointValue])]:
max_paths_per_request: int = 100
total_step_size: int = 100_000
Expand All @@ -139,7 +138,7 @@ def fetch_multiple_series_values(
if paths_len > max_paths_per_request:
results = {}
for i in range(0, paths_len, max_paths_per_request):
batch_paths = paths[i:i + max_paths_per_request]
batch_paths = paths[i : i + max_paths_per_request]
batch_result = self.fetch_multiple_series_values(
paths=batch_paths,
include_inherited=include_inherited,
Expand All @@ -149,12 +148,8 @@ def fetch_multiple_series_values(
results.update(batch_result)
return results

results = {
path: [] for path in paths
}
attribute_steps = {
path: None for path in paths
}
results = {path: [] for path in paths}
attribute_steps = {path: None for path in paths}

while attribute_steps:
step_size = total_step_size // len(attribute_steps)
Expand Down Expand Up @@ -202,7 +197,7 @@ def _fetch_series_values(
identifier=request.container_id,
type="experiment",
),
lineage=TimeSeriesLineage.FULL if request.include_inherited else TimeSeriesLineage.NONE
lineage=TimeSeriesLineage.FULL if request.include_inherited else TimeSeriesLineage.NONE,
),
after_step=request.after_step,
)
Expand All @@ -216,10 +211,7 @@ def _fetch_series_values(
)

response = backoff_retry(
lambda: get_multiple_float_series_values_proto.sync_detailed(
client=self._backend,
body=request
)
lambda: get_multiple_float_series_values_proto.sync_detailed(client=self._backend, body=request)
)

data: ProtoFloatSeriesValuesResponseDTO = ProtoFloatSeriesValuesResponseDTO.FromString(response.content)
Expand Down

0 comments on commit 77c7f6e

Please sign in to comment.