diff --git a/package/kedro_viz/api/graphql/schema.py b/package/kedro_viz/api/graphql/schema.py index 1515ede008..43fdf381bc 100644 --- a/package/kedro_viz/api/graphql/schema.py +++ b/package/kedro_viz/api/graphql/schema.py @@ -18,8 +18,14 @@ from kedro_viz.data_access import data_access_manager from kedro_viz.integrations.pypi import get_latest_version, is_running_outdated_version -from .serializers import format_run, format_run_tracking_data, format_runs +from .serializers import ( + format_run, + format_run_metric_data, + format_run_tracking_data, + format_runs, +) from .types import ( + MetricPlotDataset, Run, RunInput, TrackingDataset, @@ -95,6 +101,27 @@ def run_tracking_data( return all_tracking_datasets + @strawberry.field( + description="Get metrics data for a limited number of recent runs" + ) + def run_metrics_data(self, limit: Optional[int] = 25) -> MetricPlotDataset: + run_ids = [ + run.id for run in data_access_manager.runs.get_all_runs(limit_amount=limit) + ] + group = TrackingDatasetGroup.METRIC + + # pylint: disable=line-too-long + metric_dataset_models = data_access_manager.tracking_datasets.get_tracking_datasets_by_group_by_run_ids( + run_ids, group + ) + + metric_data = {} + for dataset in metric_dataset_models: + metric_data[dataset.dataset_name] = dataset.runs + + formatted_metric_data = format_run_metric_data(metric_data) + return MetricPlotDataset(data=formatted_metric_data) + @strawberry.type class Mutation: diff --git a/package/kedro_viz/api/graphql/serializers.py b/package/kedro_viz/api/graphql/serializers.py index a014df2c91..ce5ec774bf 100644 --- a/package/kedro_viz/api/graphql/serializers.py +++ b/package/kedro_viz/api/graphql/serializers.py @@ -5,6 +5,7 @@ import json from collections import defaultdict +from itertools import product from typing import Dict, Iterable, List, Optional, cast from strawberry import ID @@ -128,3 +129,69 @@ def format_run_tracking_data( del formatted_tracking_data[tracking_key] return formatted_tracking_data 
+ + +def format_run_metric_data(metric_data: Dict) -> Dict: + """Format metric data to conform to the schema required by plots on the front + end. Parallel Coordinate plots and Timeseries plots are supported. + + Arguments: + metric_data: the data to format + + Returns: + a dictionary containing metric data in two sub-dictionaries, containing + metric data aggregated by run_id and by metric respectively. + """ + formatted_metric_data = _initialise_metric_data_template(metric_data) + _populate_metric_data_template(metric_data, **formatted_metric_data) + return formatted_metric_data + + +def _initialise_metric_data_template(metric_data: Dict) -> Dict: + """Initialise a dictionary to store formatted metric data. + + Arguments: + metric_data: the data being formatted + + Returns: + A dictionary with two sub-dictionaries containing lists (initialised + with `None` values) of the correct length for holding metric data + """ + runs: Dict = {} + metrics: Dict = {} + for dataset_name in metric_data: + dataset = metric_data[dataset_name] + for run_id in dataset: + runs[run_id] = [] + for metric in dataset[run_id]: + metric_name = f"{dataset_name}.{metric}" + metrics[metric_name] = [] + + for empty_list in runs.values(): + empty_list.extend([None] * len(metrics)) + for empty_list in metrics.values(): + empty_list.extend([None] * len(runs)) + + return {"metrics": metrics, "runs": runs} + + +def _populate_metric_data_template( + metric_data: Dict, runs: Dict, metrics: Dict +) -> None: + """Populates two dictionaries containing uninitialised lists of + the correct length with metric data. Changes made in-place. 
+ + Arguments: + metric_data: the data being formatted + runs: a dictionary to store metric data aggregated by run + metrics: a dictionary to store metric data aggregated by metric + """ + + for (run_idx, run_id), (metric_idx, metric) in product( + enumerate(runs), enumerate(metrics) + ): + dataset_name_root, _, metric_name = metric.rpartition(".") + for dataset_name in metric_data: + if dataset_name_root == dataset_name: + value = metric_data[dataset_name][run_id].get(metric_name, None) + runs[run_id][metric_idx] = metrics[metric][run_idx] = value diff --git a/package/kedro_viz/api/graphql/types.py b/package/kedro_viz/api/graphql/types.py index 312777f78d..e45a9e6cb3 100644 --- a/package/kedro_viz/api/graphql/types.py +++ b/package/kedro_viz/api/graphql/types.py @@ -34,6 +34,11 @@ class TrackingDataset: run_ids: List[ID] +@strawberry.type(description="Metric data") +class MetricPlotDataset: + data: JSON + + TrackingDatasetGroup = strawberry.enum( TrackingDatasetGroupModel, description="Group to show kind of tracking data" ) diff --git a/package/kedro_viz/data_access/repositories/runs.py b/package/kedro_viz/data_access/repositories/runs.py index 429a2a589f..1cbee55ec4 100644 --- a/package/kedro_viz/data_access/repositories/runs.py +++ b/package/kedro_viz/data_access/repositories/runs.py @@ -43,13 +43,19 @@ def add_run(self, run: RunModel): session.add(run) @check_db_session - def get_all_runs(self) -> Optional[Iterable[RunModel]]: + def get_all_runs( + self, limit_amount: Optional[int] = None + ) -> Optional[Iterable[RunModel]]: all_runs = ( + self._db_session_class() # type: ignore + .query(RunModel) + .order_by(RunModel.id.desc()) - .all() ) + + if limit_amount: + all_runs = all_runs.limit(limit_amount) + all_runs = all_runs.all() + if all_runs: self.last_run_id = all_runs[0].id return all_runs diff --git a/package/tests/test_api/test_graphql/test_queries.py b/package/tests/test_api/test_graphql/test_queries.py index 13bb1a796d..b79d09946a 
100644 --- a/package/tests/test_api/test_graphql/test_queries.py +++ b/package/tests/test_api/test_graphql/test_queries.py @@ -171,6 +171,51 @@ def test_run_tracking_data_query( assert response.json() == expected_response + def test_metrics_data( + self, + client, + example_tracking_catalog, + data_access_manager_with_runs, + ): + data_access_manager_with_runs.add_catalog(example_tracking_catalog) + + response = client.post( + "/graphql", + json={ + "query": "query MyQuery {\n runMetricsData(limit: 3) {\n data\n }\n}\n" + }, + ) + + expected = { + "data": { + "runMetricsData": { + "data": { + "metrics": { + "metrics.col1": [1.0, None], + "metrics.col2": [2.0, None], + "metrics.col3": [3.0, None], + "more_metrics.col4": [4.0, None], + "more_metrics.col5": [5.0, None], + "more_metrics.col6": [6.0, None], + }, + "runs": { + "2021-11-02T18.24.24.379Z": [ + None, + None, + None, + None, + None, + None, + ], + "2021-11-03T18.24.24.379Z": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + }, + } + } + } + } + + assert response.json() == expected + @pytest.mark.parametrize( "show_diff,expected_response", [ diff --git a/src/apollo/schema.graphql b/src/apollo/schema.graphql index a0bc5162d8..1cceb6035d 100644 --- a/src/apollo/schema.graphql +++ b/src/apollo/schema.graphql @@ -3,6 +3,11 @@ The `JSON` scalar type represents JSON values as specified by [ECMA-404](http:// """ scalar JSON @specifiedBy(url: "http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-404.pdf") +"""Metric data""" +type MetricPlotDataset { + data: JSON! +} + type Mutation { """Update run metadata""" updateRunDetails(runId: ID!, runInput: RunInput!): UpdateRunDetailsResponse! @@ -18,6 +23,9 @@ type Query { """Get tracking datasets for specified group and run_ids""" runTrackingData(runIds: [ID!]!, group: TrackingDatasetGroup!, showDiff: Boolean = true): [TrackingDataset!]! + """Get metrics data for a limited number of recent runs""" + runMetricsData(limit: Int = 25): MetricPlotDataset! 
+ """Get the installed and latest Kedro-Viz versions""" version: Version! }