Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor visualize dataset stats from DataNodeMetadata to DataNode #1499

Merged
merged 6 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package/kedro_viz/api/rest/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Config:
class DataNodeAPIResponse(BaseGraphNodeAPIResponse):
layer: Optional[str]
dataset_type: Optional[str]
stats: Optional[Dict]

class Config:
schema_extra = {
Expand All @@ -75,6 +76,7 @@ class Config:
"type": "data",
"layer": "primary",
"dataset_type": "kedro.extras.datasets.pandas.csv_dataset.CSVDataSet",
"stats": {"rows": 10, "columns": 2, "file_size": 2300},
}
}

Expand Down
6 changes: 2 additions & 4 deletions package/kedro_viz/api/rest/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,10 @@ async def get_single_node_metadata(node_id: str):
return TaskNodeMetadata(node)

if isinstance(node, DataNode):
dataset_stats = data_access_manager.get_stats_for_data_node(node)
return DataNodeMetadata(node, dataset_stats)
return DataNodeMetadata(node)

if isinstance(node, TranscodedDataNode):
dataset_stats = data_access_manager.get_stats_for_data_node(node)
return TranscodedDataNodeMetadata(node, dataset_stats)
return TranscodedDataNodeMetadata(node)

return ParametersNodeMetadata(node)

Expand Down
13 changes: 6 additions & 7 deletions package/kedro_viz/data_access/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from kedro.io import DataCatalog
from kedro.pipeline import Pipeline as KedroPipeline
from kedro.pipeline.node import Node as KedroNode
from kedro.pipeline.pipeline import _strip_transcoding
from sqlalchemy.orm import sessionmaker

from kedro_viz.constants import DEFAULT_REGISTERED_PIPELINE_ID, ROOT_MODULAR_PIPELINE_ID
Expand Down Expand Up @@ -102,17 +103,14 @@ def add_dataset_stats(self, stats_dict: Dict):

self.dataset_stats = stats_dict

def get_stats_for_data_node(
self, data_node: Union[DataNode, TranscodedDataNode]
) -> Dict:
"""Returns the dataset statistics for the data node if found else returns an
empty dictionary
def get_stats_for_data_node(self, data_node_name: str) -> Union[Dict, None]:
"""Returns the dataset statistics for the data node if found

Args:
The data node for which we need the statistics
data_node_name: The data node name for which we need the statistics
"""

return self.dataset_stats.get(data_node.name, {})
return self.dataset_stats.get(data_node_name, None)

def add_pipeline(self, registered_pipeline_id: str, pipeline: KedroPipeline):
"""Iterate through all the nodes and datasets in a "registered" pipeline
Expand Down Expand Up @@ -278,6 +276,7 @@ def add_dataset(
layer=layer,
tags=set(),
dataset=obj,
stats=self.get_stats_for_data_node(_strip_transcoding(dataset_name)),
ravi-kumar-pilla marked this conversation as resolved.
Show resolved Hide resolved
is_free_input=is_free_input,
)
graph_node = self.nodes.add_node(graph_node)
Expand Down
23 changes: 15 additions & 8 deletions package/kedro_viz/models/flowchart.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def create_data_node(
layer: Optional[str],
tags: Set[str],
dataset: AbstractDataset,
stats: Optional[Dict],
is_free_input: bool = False,
) -> Union["DataNode", "TranscodedDataNode"]:
"""Create a graph node of type DATA for a given Kedro DataSet instance.
Expand All @@ -188,6 +189,8 @@ def create_data_node(
tags: The set of tags to assign to the graph representation
of this dataset. N.B. currently it's derived from the node's tags.
dataset: A dataset in a Kedro pipeline.
stats: The dictionary of dataset statistics, e.g.
{"rows":2, "columns":3, "file_size":100}
is_free_input: Whether the dataset is a free input in the pipeline
Returns:
An instance of DataNode or TranscodedDataNode.
Expand All @@ -201,6 +204,7 @@ def create_data_node(
tags=tags,
layer=layer,
is_free_input=is_free_input,
stats=stats,
)

return DataNode(
Expand All @@ -210,6 +214,7 @@ def create_data_node(
layer=layer,
kedro_obj=dataset,
is_free_input=is_free_input,
stats=stats,
)

@classmethod
Expand Down Expand Up @@ -434,6 +439,9 @@ class DataNode(GraphNode):
# the type of this graph node, which is DATA
type: str = GraphNodeType.DATA.value

# statistics for the data node
stats: Optional[Dict] = field(default=None)

def __post_init__(self):
self.dataset_type = get_dataset_type(self.kedro_obj)

Expand Down Expand Up @@ -517,6 +525,9 @@ class TranscodedDataNode(GraphNode):
# the type of this graph node, which is DATA
type: str = GraphNodeType.DATA.value

# statistics for the data node
stats: Optional[Dict] = field(default=None)

def has_metadata(self) -> bool:
return True

Expand All @@ -541,7 +552,6 @@ class DataNodeMetadata(GraphNodeMetadata):

# the underlying data node to which this metadata belongs
data_node: InitVar[DataNode]
dataset_stats: InitVar[Dict]

# the optional plot data if the underlying dataset has a plot.
# currently only applicable for PlotlyDataSet
Expand All @@ -561,12 +571,12 @@ class DataNodeMetadata(GraphNodeMetadata):
stats: Optional[Dict] = field(init=False, default=None)

# TODO: improve this scheme.
def __post_init__(self, data_node: DataNode, dataset_stats: Dict):
def __post_init__(self, data_node: DataNode):
self.type = data_node.dataset_type
dataset = cast(AbstractDataset, data_node.kedro_obj)
dataset_description = dataset._describe()
self.filepath = _parse_filepath(dataset_description)
self.stats = dataset_stats
self.stats = data_node.stats

# Run command is only available if a node is an output, i.e. not a free input
if not data_node.is_free_input:
Expand Down Expand Up @@ -625,11 +635,8 @@ class TranscodedDataNodeMetadata(GraphNodeMetadata):

# the underlying data node to which this metadata belongs
transcoded_data_node: InitVar[TranscodedDataNode]
dataset_stats: InitVar[Dict]

def __post_init__(
self, transcoded_data_node: TranscodedDataNode, dataset_stats: Dict
):
def __post_init__(self, transcoded_data_node: TranscodedDataNode):
original_version = transcoded_data_node.original_version

self.original_type = get_dataset_type(original_version)
Expand All @@ -640,7 +647,7 @@ def __post_init__(

dataset_description = original_version._describe()
self.filepath = _parse_filepath(dataset_description)
self.stats = dataset_stats
self.stats = transcoded_data_node.stats

if not transcoded_data_node.is_free_input:
self.run_command = (
Expand Down
5 changes: 4 additions & 1 deletion package/kedro_viz/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ def populate_data(
data_access_manager.set_db_session(session_class)

data_access_manager.add_catalog(catalog)
data_access_manager.add_pipelines(pipelines)

# add dataset stats before adding pipelines, because each data node
# looks up its stats at the time it is created
data_access_manager.add_dataset_stats(stats_dict)

data_access_manager.add_pipelines(pipelines)


def run_server(
host: str = DEFAULT_HOST,
Expand Down
20 changes: 18 additions & 2 deletions package/tests/test_api/test_rest/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def assert_example_data(response_data):
"type": "data",
"layer": "raw",
"dataset_type": "pandas.csv_dataset.CSVDataSet",
"stats": None,
},
{
"id": "f0ebef01",
Expand All @@ -118,6 +119,7 @@ def assert_example_data(response_data):
"type": "parameters",
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "0ecea0de",
Expand All @@ -128,6 +130,7 @@ def assert_example_data(response_data):
"type": "data",
"layer": "model_inputs",
"dataset_type": "pandas.csv_dataset.CSVDataSet",
"stats": {"columns": 12, "rows": 29768},
},
{
"id": "7b140b3f",
Expand All @@ -150,6 +153,7 @@ def assert_example_data(response_data):
"type": "parameters",
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "d5a8b994",
Expand All @@ -160,6 +164,7 @@ def assert_example_data(response_data):
"type": "data",
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
{
"id": "uk.data_processing",
Expand All @@ -170,6 +175,7 @@ def assert_example_data(response_data):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "uk.data_science",
Expand All @@ -180,6 +186,7 @@ def assert_example_data(response_data):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "uk",
Expand All @@ -190,6 +197,7 @@ def assert_example_data(response_data):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
]
assert_nodes_equal(response_data.pop("nodes"), expected_nodes)
Expand Down Expand Up @@ -480,6 +488,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
{
"id": "f0ebef01",
Expand All @@ -490,6 +499,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": ["uk", "uk.data_processing"],
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "0ecea0de",
Expand All @@ -500,6 +510,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "2302ea78",
Expand All @@ -519,6 +530,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "1d06a0d7",
Expand All @@ -529,6 +541,7 @@ def assert_example_transcoded_data(response_data):
"modular_pipelines": [],
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
]

Expand Down Expand Up @@ -572,7 +585,6 @@ def test_transcoded_data_node_metadata(self, example_transcoded_api):
"pandas.parquet_dataset.ParquetDataSet",
],
"run_command": "kedro run --to-outputs=model_inputs@pandas2",
"stats": {},
}


Expand Down Expand Up @@ -614,7 +626,6 @@ def test_data_node_metadata_for_free_input(self, client):
assert response.json() == {
"filepath": "raw_data.csv",
"type": "pandas.csv_dataset.CSVDataSet",
"stats": {},
}

def test_parameters_node_metadata(self, client):
Expand Down Expand Up @@ -664,6 +675,7 @@ def test_get_pipeline(self, client):
"type": "data",
"layer": "model_inputs",
"dataset_type": "pandas.csv_dataset.CSVDataSet",
"stats": {"columns": 12, "rows": 29768},
},
{
"id": "7b140b3f",
Expand All @@ -686,6 +698,7 @@ def test_get_pipeline(self, client):
"type": "parameters",
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "d5a8b994",
Expand All @@ -696,6 +709,7 @@ def test_get_pipeline(self, client):
"type": "data",
"layer": None,
"dataset_type": "io.memory_dataset.MemoryDataset",
"stats": None,
},
{
"id": "uk",
Expand All @@ -706,6 +720,7 @@ def test_get_pipeline(self, client):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
{
"id": "uk.data_science",
Expand All @@ -716,6 +731,7 @@ def test_get_pipeline(self, client):
"modular_pipelines": None,
"layer": None,
"dataset_type": None,
"stats": None,
},
]
assert_nodes_equal(response_data.pop("nodes"), expected_nodes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_add_input(self):
layer="model",
tags=set(),
dataset=kedro_dataset,
stats=None,
)
modular_pipelines.add_input("data_science", data_node)
assert data_node.id in data_science_pipeline.inputs
Expand All @@ -62,6 +63,7 @@ def test_add_output(self):
layer="model",
tags=set(),
dataset=kedro_dataset,
stats=None,
)
modular_pipelines.add_output("data_science", data_node)
assert data_node.id in data_science_pipeline.outputs
Expand Down
Loading