Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tag being undefined bug from the backend. #2162

Merged
merged 9 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions package/kedro_viz/data_access/repositories/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ def __init__(self):
self.nodes_dict: Dict[str, GraphNode] = {}
self.nodes_list: List[GraphNode] = []

def has_node(self, node: GraphNode) -> bool:
return node.id in self.nodes_dict

def add_node(self, node: GraphNode) -> GraphNode:
if not self.has_node(node):
existing_node = self.nodes_dict.get(node.id)
if existing_node:
# Update tags or other attributes if the node already exists
existing_node.tags.update(node.tags)
else:
self.nodes_dict[node.id] = node
self.nodes_list.append(node)
return self.nodes_dict[node.id]
Expand Down
53 changes: 53 additions & 0 deletions package/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def example_pipeline_with_node_namespaces():
inputs=["raw_transaction_data", "cleaned_transaction_data"],
outputs="validated_transaction_data",
name="validation_node",
tags=["validation"],
),
node(
func=lambda validated_data, enrichment_data: (
Expand Down Expand Up @@ -381,6 +382,23 @@ def edge_case_example_pipelines(
}


@pytest.fixture
def example_pipelines_with_additional_tags(example_pipeline_with_node_namespaces):
"""
Fixture to mock the use cases mentioned in
https://github.com/kedro-org/kedro-viz/issues/2106
"""

pipelines_dict = {
"pipeline": example_pipeline_with_node_namespaces,
"pipeline_with_tags": pipeline(
example_pipeline_with_node_namespaces, tags=["tag1", "tag2"]
),
}

yield pipelines_dict


@pytest.fixture
def expected_modular_pipeline_tree_for_edge_cases():
expected_tree_for_edge_cases_file_path = (
Expand Down Expand Up @@ -554,6 +572,41 @@ def example_api_for_edge_case_pipelines(
yield api


@pytest.fixture
def example_api_for_pipelines_with_additional_tags(
data_access_manager: DataAccessManager,
example_pipelines_with_additional_tags: Dict[str, Pipeline],
example_catalog: DataCatalog,
session_store: BaseSessionStore,
mocker,
):
api = apps.create_api_app_from_project(mock.MagicMock())

# For readability we are not hashing the node id
mocker.patch("kedro_viz.utils._hash", side_effect=lambda value: value)
mocker.patch(
"kedro_viz.data_access.repositories.modular_pipelines._hash",
side_effect=lambda value: value,
)

populate_data(
data_access_manager,
example_catalog,
example_pipelines_with_additional_tags,
session_store,
{},
)
mocker.patch(
"kedro_viz.api.rest.responses.pipelines.data_access_manager",
new=data_access_manager,
)
mocker.patch(
"kedro_viz.api.rest.responses.nodes.data_access_manager",
new=data_access_manager,
)
yield api


@pytest.fixture
def example_transcoded_api(
data_access_manager: DataAccessManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ def test_endpoint_main_no_default_pipeline(self, example_api_no_default_pipeline
{"id": "data_processing", "name": "data_processing"},
]

def test_endpoint_main_for_pipelines_with_additional_tags(
self,
example_api_for_pipelines_with_additional_tags,
):
expected_tags = [
{"id": "tag1", "name": "tag1"},
{"id": "tag2", "name": "tag2"},
{"id": "validation", "name": "validation"},
]
client = TestClient(example_api_for_pipelines_with_additional_tags)
response = client.get("/api/main")
actual_tags = response.json()["tags"]
assert actual_tags == expected_tags

def test_endpoint_main_for_edge_case_pipelines(
self,
example_api_for_edge_case_pipelines,
Expand Down
Loading