Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create query for tracking metrics #1150

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
fa93f33
Add MetricPlotType
jmholzer Nov 1, 2022
c3d9c00
Add MetricPlotType as a strawberry enum type
jmholzer Nov 1, 2022
a068a94
Add option to limit results for get_all_runs
jmholzer Nov 1, 2022
bd8141d
Add endpoint for metric data
jmholzer Nov 2, 2022
0847cea
Add implementations for formatting metrics data
jmholzer Nov 2, 2022
bb3624a
Refactor run_metrics_data query to return a MetricDataset
jmholzer Nov 7, 2022
f0b9305
Add MetricDataset type
jmholzer Nov 7, 2022
1dca121
Remove MetricPlotType type
jmholzer Nov 7, 2022
40970fe
Refactor format_run_metric_data
jmholzer Nov 7, 2022
77ba70c
Merge branch 'main' into create-query-for-tracking-metrics
jmholzer Nov 7, 2022
7b4f913
Merge branch 'main' of https://github.com/kedro-org/kedro-viz into cr…
tynandebold Nov 9, 2022
7ae66b0
Merge branch 'feature/divide-exp-tracking-details-into-tabs' of https…
tynandebold Nov 9, 2022
7ed996c
Add run_metrics_data to schema
jmholzer Nov 9, 2022
8c998dc
Add doc strings
jmholzer Nov 9, 2022
b62bb6e
Merge branch 'create-query-for-tracking-metrics' of github.com:kedro-…
jmholzer Nov 10, 2022
f08a227
Add MetricDataset to schema
jmholzer Nov 10, 2022
3c7c10a
Add pylint ignore
jmholzer Nov 10, 2022
164bdbf
Modify runs and metrics initialisation to use .values()
jmholzer Nov 10, 2022
c95de0c
Remove unused import
jmholzer Nov 10, 2022
1d4e015
Refactor metric naming not to account for dataset name 'root'
jmholzer Nov 11, 2022
f79540b
Add end-to-end test for the runMetricsData query
jmholzer Nov 11, 2022
699c9dd
Rename MetricDataset to MetricPlotDataset to avoid confusion with Met…
jmholzer Nov 11, 2022
5d287d6
Rename MetricDataset to MetricPlotDataset in schema
jmholzer Nov 11, 2022
5995b87
Add type hints for runs and metrics in _initialise_metric_data_template
jmholzer Nov 11, 2022
f38adc8
Refactor return type of runMetricsData query
jmholzer Nov 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion package/kedro_viz/api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@
from kedro_viz.data_access import data_access_manager
from kedro_viz.integrations.pypi import get_latest_version, is_running_outdated_version

from .serializers import format_run, format_run_tracking_data, format_runs
from .serializers import (
format_run,
format_run_metric_data,
format_run_tracking_data,
format_runs,
)
from .types import (
MetricPlotDataset,
Run,
RunInput,
TrackingDataset,
Expand Down Expand Up @@ -95,6 +101,27 @@ def run_tracking_data(

return all_tracking_datasets

@strawberry.field(
    description="Get metrics data for a limited number of recent runs"
)
def run_metrics_data(self, limit: Optional[int] = 25) -> MetricPlotDataset:
    """Aggregate metric tracking data for the most recent runs.

    Args:
        limit: maximum number of recent runs to include in the result.

    Returns:
        A ``MetricPlotDataset`` wrapping the formatted metric data for the
        selected runs.
    """
    recent_runs = data_access_manager.runs.get_all_runs(limit_amount=limit)
    run_ids = [run.id for run in recent_runs]

    # pylint: disable=line-too-long
    dataset_models = data_access_manager.tracking_datasets.get_tracking_datasets_by_group_by_run_ids(
        run_ids, TrackingDatasetGroup.METRIC
    )

    # Map each metric dataset's name to its per-run data.
    metric_data = {
        dataset.dataset_name: dataset.runs for dataset in dataset_models
    }

    return MetricPlotDataset(data=format_run_metric_data(metric_data))


@strawberry.type
class Mutation:
Expand Down
67 changes: 67 additions & 0 deletions package/kedro_viz/api/graphql/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import json
from collections import defaultdict
from itertools import product
from typing import Dict, Iterable, List, Optional, cast

from strawberry import ID
Expand Down Expand Up @@ -128,3 +129,69 @@ def format_run_tracking_data(
del formatted_tracking_data[tracking_key]

return formatted_tracking_data


def format_run_metric_data(metric_data: Dict) -> Dict:
    """Format metric data to conform to the schema required by plots on the
    front end. Parallel Coordinate plots and Timeseries plots are supported.

    Arguments:
        metric_data: the data to format

    Returns:
        A dictionary with two sub-dictionaries containing the metric data
        aggregated by run_id and by metric respectively.
    """
    template = _initialise_metric_data_template(metric_data)
    _populate_metric_data_template(metric_data, **template)
    return template


def _initialise_metric_data_template(metric_data: Dict) -> Dict:
"""Initialise a dictionary to store formatted metric data.

Arguments:
metric_data: the data being formatted

Returns:
A dictionary with two sub-dictionaries containing lists (initialised
with `None` values) of the correct length for holding metric data
"""
runs: Dict = {}
metrics: Dict = {}
for dataset_name in metric_data:
dataset = metric_data[dataset_name]
for run_id in dataset:
runs[run_id] = []
for metric in dataset[run_id]:
metric_name = f"{dataset_name}.{metric}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think metric_name can be just dataset_name instead of dataset_name.metric

metrics[metric_name] = []

for empty_list in runs.values():
empty_list.extend([None] * len(metrics))
for empty_list in metrics.values():
empty_list.extend([None] * len(runs))

return {"metrics": metrics, "runs": runs}


def _populate_metric_data_template(
metric_data: Dict, runs: Dict, metrics: Dict
) -> None:
"""Populates two dictionaries containing uninitialised lists of
the correct length with metric data. Changes made in-place.

Arguments:
metric_data: the data to be being formatted
runs: a dictionary to store metric data aggregated by run
metrics: a dictionary to store metric data aggregated by metric
"""
print(metric_data)
for (run_idx, run_id), (metric_idx, metric) in product(
enumerate(runs), enumerate(metrics)
):
dataset_name_root, _, metric_name = metric.rpartition(".")
for dataset_name in metric_data:
if dataset_name_root == dataset_name:
value = metric_data[dataset_name][run_id].get(metric_name, None)
runs[run_id][metric_idx] = metrics[metric][run_idx] = value
5 changes: 5 additions & 0 deletions package/kedro_viz/api/graphql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class TrackingDataset:
run_ids: List[ID]


@strawberry.type(description="Metric data")
class MetricPlotDataset:
    """Container for formatted metric data returned by the runMetricsData query."""

    # JSON payload; populated with the output of format_run_metric_data
    # ("metrics" and "runs" sub-dictionaries) by the run_metrics_data resolver.
    data: JSON


TrackingDatasetGroup = strawberry.enum(
TrackingDatasetGroupModel, description="Group to show kind of tracking data"
)
Expand Down
10 changes: 8 additions & 2 deletions package/kedro_viz/data_access/repositories/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,19 @@ def add_run(self, run: RunModel):
session.add(run)

@check_db_session
def get_all_runs(
    self, limit_amount: Optional[int] = None
) -> Optional[Iterable[RunModel]]:
    """Return all runs, most recent first, optionally limited in number.

    Args:
        limit_amount: maximum number of runs to return; ``None`` (or 0)
            returns every run.

    Returns:
        The matching ``RunModel`` rows, newest first. As a side effect,
        records the id of the most recent run on ``self.last_run_id``.
    """
    query = (
        self._db_session_class()  # type: ignore
        .query(RunModel)
        .order_by(RunModel.id.desc())
    )

    # Apply the limit on the query itself, before materialising results:
    # the previous version called .limit() on the list returned by .all(),
    # which raises AttributeError whenever a limit is supplied.
    if limit_amount:
        query = query.limit(limit_amount)

    all_runs = query.all()

    if all_runs:
        self.last_run_id = all_runs[0].id
    return all_runs
Expand Down
45 changes: 45 additions & 0 deletions package/tests/test_api/test_graphql/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,51 @@ def test_run_tracking_data_query(

assert response.json() == expected_response

def test_metrics_data(
    self,
    client,
    example_tracking_catalog,
    data_access_manager_with_runs,
):
    """End-to-end check of the runMetricsData GraphQL query.

    Registers a catalog containing metric datasets, then verifies the
    response aggregates metric values both by metric name and by run id.
    """
    data_access_manager_with_runs.add_catalog(example_tracking_catalog)

    graphql_query = (
        "query MyQuery {\n runMetricsData(limit: 3) {\n data\n }\n}\n"
    )
    response = client.post("/graphql", json={"query": graphql_query})

    # One of the two runs has no value recorded for any metric.
    empty_run = [None, None, None, None, None, None]
    expected = {
        "data": {
            "runMetricsData": {
                "data": {
                    "metrics": {
                        "metrics.col1": [1.0, None],
                        "metrics.col2": [2.0, None],
                        "metrics.col3": [3.0, None],
                        "more_metrics.col4": [4.0, None],
                        "more_metrics.col5": [5.0, None],
                        "more_metrics.col6": [6.0, None],
                    },
                    "runs": {
                        "2021-11-02T18.24.24.379Z": empty_run,
                        "2021-11-03T18.24.24.379Z": [
                            1.0,
                            2.0,
                            3.0,
                            4.0,
                            5.0,
                            6.0,
                        ],
                    },
                }
            }
        }
    }

    assert response.json() == expected

@pytest.mark.parametrize(
"show_diff,expected_response",
[
Expand Down
8 changes: 8 additions & 0 deletions src/apollo/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ The `JSON` scalar type represents JSON values as specified by [ECMA-404](http://
"""
scalar JSON @specifiedBy(url: "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf")

"""Metric data"""
type MetricPlotDataset {
data: JSON!
}

type Mutation {
"""Update run metadata"""
updateRunDetails(runId: ID!, runInput: RunInput!): UpdateRunDetailsResponse!
Expand All @@ -18,6 +23,9 @@ type Query {
"""Get tracking datasets for specified group and run_ids"""
runTrackingData(runIds: [ID!]!, group: TrackingDatasetGroup!, showDiff: Boolean = true): [TrackingDataset!]!

"""Get metrics data for a limited number of recent runs"""
runMetricsData(limit: Int = 25): MetricPlotDataset!

"""Get the installed and latest Kedro-Viz versions"""
version: Version!
}
Expand Down