Skip to content

Commit

Permalink
Refactor visualize dataset stats from DataNodeMetadata to DataNode (#…
Browse files Browse the repository at this point in the history
…1499)

* add stats to data node

* lint and format check fix

* fix pytests

* fix layout issue

* fix transcoded data stats
  • Loading branch information
ravi-kumar-pilla authored Aug 24, 2023
1 parent 67cc993 commit bc46f43
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 77 deletions.
2 changes: 2 additions & 0 deletions package/kedro_viz/api/rest/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Config:
class DataNodeAPIResponse(BaseGraphNodeAPIResponse):
layer: Optional[str]
dataset_type: Optional[str]
stats: Optional[Dict]

class Config:
schema_extra = {
Expand All @@ -75,6 +76,7 @@ class Config:
"type": "data",
"layer": "primary",
"dataset_type": "kedro.extras.datasets.pandas.csv_dataset.CSVDataSet",
"stats": {"rows": 10, "columns": 2, "file_size": 2300},
}
}

Expand Down
6 changes: 2 additions & 4 deletions package/kedro_viz/api/rest/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,10 @@ async def get_single_node_metadata(node_id: str):
return TaskNodeMetadata(node)

if isinstance(node, DataNode):
dataset_stats = data_access_manager.get_stats_for_data_node(node)
return DataNodeMetadata(node, dataset_stats)
return DataNodeMetadata(node)

if isinstance(node, TranscodedDataNode):
dataset_stats = data_access_manager.get_stats_for_data_node(node)
return TranscodedDataNodeMetadata(node, dataset_stats)
return TranscodedDataNodeMetadata(node)

return ParametersNodeMetadata(node)

Expand Down
13 changes: 6 additions & 7 deletions package/kedro_viz/data_access/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline as KedroPipeline
from kedro.pipeline.node import Node as KedroNode
from kedro.pipeline.pipeline import _strip_transcoding
from sqlalchemy.orm import sessionmaker

from kedro_viz.constants import DEFAULT_REGISTERED_PIPELINE_ID, ROOT_MODULAR_PIPELINE_ID
Expand Down Expand Up @@ -102,17 +103,14 @@ def add_dataset_stats(self, stats_dict: Dict):

self.dataset_stats = stats_dict

def get_stats_for_data_node(
self, data_node: Union[DataNode, TranscodedDataNode]
) -> Dict:
"""Returns the dataset statistics for the data node if found else returns an
empty dictionary
def get_stats_for_data_node(self, data_node_name: str) -> Union[Dict, None]:
"""Returns the dataset statistics for the data node if found
Args:
The data node for which we need the statistics
The data node name for which we need the statistics
"""

return self.dataset_stats.get(data_node.name, {})
return self.dataset_stats.get(data_node_name, None)

def add_pipeline(self, registered_pipeline_id: str, pipeline: KedroPipeline):
"""Iterate through all the nodes and datasets in a "registered" pipeline
Expand Down Expand Up @@ -278,6 +276,7 @@ def add_dataset(
layer=layer,
tags=set(),
dataset=obj,
stats=self.get_stats_for_data_node(_strip_transcoding(dataset_name)),
is_free_input=is_free_input,
)
graph_node = self.nodes.add_node(graph_node)
Expand Down
23 changes: 15 additions & 8 deletions package/kedro_viz/models/flowchart.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def create_data_node(
layer: Optional[str],
tags: Set[str],
dataset: AbstractDataset,
stats: Optional[Dict],
is_free_input: bool = False,
) -> Union["DataNode", "TranscodedDataNode"]:
"""Create a graph node of type DATA for a given Kedro DataSet instance.
Expand All @@ -188,6 +189,8 @@ def create_data_node(
tags: The set of tags assigned to assign to the graph representation
of this dataset. N.B. currently it's derived from the node's tags.
dataset: A dataset in a Kedro pipeline.
stats: The dictionary of dataset statistics, e.g.
{"rows":2, "columns":3, "file_size":100}
is_free_input: Whether the dataset is a free input in the pipeline
Returns:
An instance of DataNode.
Expand All @@ -201,6 +204,7 @@ def create_data_node(
tags=tags,
layer=layer,
is_free_input=is_free_input,
stats=stats,
)

return DataNode(
Expand All @@ -210,6 +214,7 @@ def create_data_node(
layer=layer,
kedro_obj=dataset,
is_free_input=is_free_input,
stats=stats,
)

@classmethod
Expand Down Expand Up @@ -434,6 +439,9 @@ class DataNode(GraphNode):
# the type of this graph node, which is DATA
type: str = GraphNodeType.DATA.value

# statistics for the data node
stats: Optional[Dict] = field(default=None)

def __post_init__(self):
self.dataset_type = get_dataset_type(self.kedro_obj)

Expand Down Expand Up @@ -517,6 +525,9 @@ class TranscodedDataNode(GraphNode):
# the type of this graph node, which is DATA
type: str = GraphNodeType.DATA.value

# statistics for the data node
stats: Optional[Dict] = field(default=None)

def has_metadata(self) -> bool:
return True

Expand All @@ -541,7 +552,6 @@ class DataNodeMetadata(GraphNodeMetadata):

# the underlying data node to which this metadata belongs
data_node: InitVar[DataNode]
dataset_stats: InitVar[Dict]

# the optional plot data if the underlying dataset has a plot.
# currently only applicable for PlotlyDataSet
Expand All @@ -561,12 +571,12 @@ class DataNodeMetadata(GraphNodeMetadata):
stats: Optional[Dict] = field(init=False, default=None)

# TODO: improve this scheme.
def __post_init__(self, data_node: DataNode, dataset_stats: Dict):
def __post_init__(self, data_node: DataNode):
self.type = data_node.dataset_type
dataset = cast(AbstractDataset, data_node.kedro_obj)
dataset_description = dataset._describe()
self.filepath = _parse_filepath(dataset_description)
self.stats = dataset_stats
self.stats = data_node.stats

# Run command is only available if a node is an output, i.e. not a free input
if not data_node.is_free_input:
Expand Down Expand Up @@ -625,11 +635,8 @@ class TranscodedDataNodeMetadata(GraphNodeMetadata):

# the underlying data node to which this metadata belongs
transcoded_data_node: InitVar[TranscodedDataNode]
dataset_stats: InitVar[Dict]

def __post_init__(
self, transcoded_data_node: TranscodedDataNode, dataset_stats: Dict
):
def __post_init__(self, transcoded_data_node: TranscodedDataNode):
original_version = transcoded_data_node.original_version

self.original_type = get_dataset_type(original_version)
Expand All @@ -640,7 +647,7 @@ def __post_init__(

dataset_description = original_version._describe()
self.filepath = _parse_filepath(dataset_description)
self.stats = dataset_stats
self.stats = transcoded_data_node.stats

if not transcoded_data_node.is_free_input:
self.run_command = (
Expand Down
5 changes: 4 additions & 1 deletion package/kedro_viz/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ def populate_data(
data_access_manager.set_db_session(session_class)

data_access_manager.add_catalog(catalog)
data_access_manager.add_pipelines(pipelines)

# add dataset stats before adding pipelines
data_access_manager.add_dataset_stats(stats_dict)

data_access_manager.add_pipelines(pipelines)


def run_server(
host: str = DEFAULT_HOST,
Expand Down
20 changes: 18 additions & 2 deletions package/tests/test_api/test_rest/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def assert_example_data(response_data):
"type": "data",
"layer": "raw",
"dataset_type": "pandas.csv_dataset.CSVDataSet",
"stats": None,
},
{
"id": "f0ebef01",
Expand All @@ -118,6 +119,7 @@ def assert_example_data(response_data):
"type": "parameters",
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "0ecea0de",
Expand All @@ -128,6 +130,7 @@ def assert_example_data(response_data):
"type": "data",
"layer": "model_inputs",
"dataset_type": "pandas.csv_dataset.CSVDataSet",
"stats": {"columns": 12, "rows": 29768},
},
{
"id": "7b140b3f",
Expand All @@ -150,6 +153,7 @@ def assert_example_data(response_data):
"type": "parameters",
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "d5a8b994",
Expand All @@ -160,6 +164,7 @@ def assert_example_data(response_data):
"type": "data",
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
{
"id": "uk.data_processing",
Expand All @@ -170,6 +175,7 @@ def assert_example_data(response_data):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "uk.data_science",
Expand All @@ -180,6 +186,7 @@ def assert_example_data(response_data):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "uk",
Expand All @@ -190,6 +197,7 @@ def assert_example_data(response_data):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
]
assert_nodes_equal(response_data.pop("nodes"), expected_nodes)
Expand Down Expand Up @@ -480,6 +488,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
{
"id": "f0ebef01",
Expand All @@ -490,6 +499,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": ["uk", "uk.data_processing"],
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "0ecea0de",
Expand All @@ -500,6 +510,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "2302ea78",
Expand All @@ -519,6 +530,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "1d06a0d7",
Expand All @@ -529,6 +541,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
]

Expand Down Expand Up @@ -572,7 +585,6 @@ def test_transcoded_data_node_metadata(self, example_transcoded_api):
"pandas.parquet_dataset.ParquetDataSet",
],
"run_command": "kedro run --to-outputs=model_inputs@pandas2",
"stats": {},
}


Expand Down Expand Up @@ -614,7 +626,6 @@ def test_data_node_metadata_for_free_input(self, client):
assert response.json() == {
"filepath": "raw_data.csv",
"type": "pandas.csv_dataset.CSVDataSet",
"stats": {},
}

def test_parameters_node_metadata(self, client):
Expand Down Expand Up @@ -664,6 +675,7 @@ def test_get_pipeline(self, client):
"type": "data",
"layer": "model_inputs",
"dataset_type": "pandas.csv_dataset.CSVDataSet",
"stats": {"columns": 12, "rows": 29768},
},
{
"id": "7b140b3f",
Expand All @@ -686,6 +698,7 @@ def test_get_pipeline(self, client):
"type": "parameters",
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "d5a8b994",
Expand All @@ -696,6 +709,7 @@ def test_get_pipeline(self, client):
"type": "data",
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
{
"id": "uk",
Expand All @@ -706,6 +720,7 @@ def test_get_pipeline(self, client):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "uk.data_science",
Expand All @@ -716,6 +731,7 @@ def test_get_pipeline(self, client):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
]
assert_nodes_equal(response_data.pop("nodes"), expected_nodes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_add_input(self):
layer="model",
tags=set(),
dataset=kedro_dataset,
stats=None,
)
modular_pipelines.add_input("data_science", data_node)
assert data_node.id in data_science_pipeline.inputs
Expand All @@ -62,6 +63,7 @@ def test_add_output(self):
layer="model",
tags=set(),
dataset=kedro_dataset,
stats=None,
)
modular_pipelines.add_output("data_science", data_node)
assert data_node.id in data_science_pipeline.outputs
Expand Down
Loading

0 comments on commit bc46f43

Please sign in to comment.