Skip to content

Commit

Permalink
Merge pull request #5244 from opsmill/ajtm-12112024-gql-single-data-l…
Browse files Browse the repository at this point in the history
…oader-2

single relationship data loader for graphql
  • Loading branch information
dgarros authored Dec 19, 2024
2 parents e2b7cf2 + 4e9244e commit 9fffe68
Show file tree
Hide file tree
Showing 18 changed files with 669 additions and 132 deletions.
189 changes: 144 additions & 45 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from functools import reduce
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Iterable, Literal, Optional, TypeVar, Union, overload

from infrahub_sdk.utils import deep_merge_dict, is_valid_uuid

from infrahub.core.constants import RelationshipCardinality, RelationshipDirection
from infrahub.core.node import Node
from infrahub.core.node.delete_validator import NodeDeleteValidator
from infrahub.core.query.node import (
Expand Down Expand Up @@ -1077,7 +1078,7 @@ async def get_one(
return node

@classmethod
async def get_many( # pylint: disable=too-many-branches,too-many-statements
async def get_many(
cls,
db: InfrahubDatabase,
ids: list[str],
Expand Down Expand Up @@ -1137,29 +1138,6 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements
profile_attributes_id_map=profile_attributes, profile_ids_by_node_id=profile_ids_by_node_id
)

# if prefetch_relationships is enabled
# Query all the peers associated with all nodes at once.
grouped_peer_nodes: GroupedPeerNodes | None = None
peers: dict[str, Node] = {}
if prefetch_relationships:
query = await NodeListGetRelationshipsQuery.init(
db=db, ids=ids, branch=branch, at=at, branch_agnostic=branch_agnostic
)
await query.execute(db=db)
grouped_peer_nodes = query.get_peers_group_by_node()
peer_ids = grouped_peer_nodes.get_all_peers()

# only query the peers that are not already part of the main list
peer_ids -= set(ids)
peers = await cls.get_many(
ids=list(peer_ids),
branch=branch,
at=at,
db=db,
include_owner=include_owner,
include_source=include_source,
)

nodes: dict[str, Node] = {}

for node_id in ids: # pylint: disable=too-many-nested-blocks
Expand Down Expand Up @@ -1195,29 +1173,127 @@ async def get_many( # pylint: disable=too-many-branches,too-many-statements

nodes[node_id] = item

# --------------------------------------------------------
# Relationships
# --------------------------------------------------------
if prefetch_relationships and grouped_peer_nodes:
for node_id, node in nodes.items():
if not grouped_peer_nodes.has_node(node_id=node_id):
continue
await cls._enrich_node_dicts_with_relationships(
db=db,
branch=branch,
at=at,
nodes_by_id=nodes,
branch_agnostic=branch_agnostic,
include_owner=include_owner,
include_source=include_source,
prefetch_relationships=prefetch_relationships,
fields=fields,
)

for rel_schema in node._schema.relationships:
peer_ids = grouped_peer_nodes.get_peer_ids(
node_id=node_id, rel_name=rel_schema.identifier, direction=rel_schema.direction
)
if not peer_ids:
continue
rel_peers = [peers.get(peer_id, None) or nodes.get(peer_id) for peer_id in peer_ids]
rel_manager: RelationshipManager = getattr(node, rel_schema.name)
if rel_schema.cardinality == "one" and len(rel_peers) > 1:
raise ValueError("Only one relationship expected")
return nodes

rel_manager.has_fetched_relationships = True
await rel_manager.update(db=db, data=rel_peers)
@classmethod
async def _enrich_node_dicts_with_relationships(
    cls,
    db: InfrahubDatabase,
    branch: Branch,
    at: Timestamp,
    nodes_by_id: dict[str, Node],
    branch_agnostic: bool,
    include_owner: bool,
    include_source: bool,
    prefetch_relationships: bool,
    fields: dict[str, Any] | None,
) -> None:
    """Populate, in place, the relationship managers of the nodes in ``nodes_by_id``.

    Two modes are supported:

    * ``prefetch_relationships`` is true: relationships are queried without any identifier
      filter, peers that are not already in ``nodes_by_id`` are loaded with ``get_many``,
      and the relationship managers are updated with the peer ``Node`` objects.
    * ``prefetch_relationships`` is false and ``fields`` is non-empty: only the
      cardinality-one relationships whose name appears in ``fields`` are queried, and the
      relationship managers are updated with peer IDs only (no peer node is loaded).

    When neither applies, or when the query returns no peer, nothing is modified.

    Args:
        db: Database handle used for the relationship query and the manager updates.
        branch: Branch the relationships are read from.
        at: Point in time the relationships are read at.
        nodes_by_id: Nodes to enrich, keyed by node ID.
        branch_agnostic: Forwarded to the relationship query.
        include_owner: Forwarded to ``get_many`` when loading missing peers.
        include_source: Forwarded to ``get_many`` when loading missing peers.
        prefetch_relationships: Selects the "load every peer node" mode described above.
        fields: Requested fields; only its keys are used, to select relationships by name.
    """
    if not prefetch_relationships and not fields:
        return

    cardinality_one_identifiers_by_kind: dict[str, dict[str, RelationshipDirection]] | None = None
    all_identifiers: list[str] | None = None
    if not prefetch_relationships:
        # Restrict the query to the cardinality-one relationships that were actually requested.
        cardinality_one_identifiers_by_kind = _get_cardinality_one_identifiers_by_kind(
            nodes=nodes_by_id.values(), fields=fields or {}
        )
        all_identifiers = list(
            {
                identifier
                for identifier_direction_map in cardinality_one_identifiers_by_kind.values()
                for identifier in identifier_direction_map
            }
        )

    query = await NodeListGetRelationshipsQuery.init(
        db=db,
        ids=list(nodes_by_id),
        relationship_identifiers=all_identifiers,
        branch=branch,
        at=at,
        branch_agnostic=branch_agnostic,
    )
    await query.execute(db=db)
    grouped_peer_nodes = query.get_peers_group_by_node()
    peer_ids = grouped_peer_nodes.get_all_peers()
    # there are no peers to enrich the nodes
    if not peer_ids:
        return

    missing_peers: dict[str, Node] = {}
    if prefetch_relationships:
        # only query the peers that are not already part of the main list
        missing_peer_ids = peer_ids - set(nodes_by_id)
        missing_peers = await cls.get_many(
            ids=list(missing_peer_ids),
            branch=branch,
            at=at,
            db=db,
            include_owner=include_owner,
            include_source=include_source,
        )

    # The lookup table is the same for every node: build it once rather than once per node.
    known_nodes_by_id = nodes_by_id | missing_peers
    for node in nodes_by_id.values():
        await cls._enrich_one_node_with_relationships(
            db=db,
            node=node,
            grouped_peer_nodes=grouped_peer_nodes,
            nodes_by_id=known_nodes_by_id,
            cardinality_one_identifiers_by_kind=cardinality_one_identifiers_by_kind,
            insert_peer_node=prefetch_relationships,
        )

@classmethod
async def _enrich_one_node_with_relationships(
    cls,
    db: InfrahubDatabase,
    node: Node,
    grouped_peer_nodes: GroupedPeerNodes,
    nodes_by_id: dict[str, Node],
    cardinality_one_identifiers_by_kind: dict[str, dict[str, RelationshipDirection]] | None,
    insert_peer_node: bool,
) -> None:
    """Update the relationship managers of a single node from pre-queried peer data.

    Args:
        db: Database handle passed to each relationship manager update.
        node: Node whose relationship managers are updated in place.
        grouped_peer_nodes: Query result used to look up peer IDs per node, identifier and direction.
        nodes_by_id: Lookup table used to resolve peer IDs to nodes when ``insert_peer_node`` is true.
        cardinality_one_identifiers_by_kind: Per node kind, the relationship identifiers (mapped to
            their direction) that should be populated with peer IDs when ``insert_peer_node`` is false.
        insert_peer_node: When true, managers receive the resolved peer nodes; otherwise peer IDs.

    Raises:
        ValueError: If a cardinality-one relationship resolves to more than one peer.
    """
    current_id = node.get_id()
    if not grouped_peer_nodes.has_node(node_id=current_id):
        return

    schema = node.get_schema()
    for rel_schema in schema.relationships:
        identifier = rel_schema.get_identifier()
        found_peer_ids = grouped_peer_nodes.get_peer_ids(
            node_id=current_id, rel_name=identifier, direction=rel_schema.direction
        )
        if not found_peer_ids:
            continue

        manager: RelationshipManager = getattr(node, rel_schema.name)
        if insert_peer_node:
            # Peers that cannot be resolved to a loaded node are silently left out.
            rel_peers: list[Node | str] = [
                peer for peer_id in found_peer_ids if (peer := nodes_by_id.get(peer_id))
            ]
        elif not cardinality_one_identifiers_by_kind:
            continue
        else:
            # if only getting some relationships, make sure we want THIS relationship for THIS node schema
            wanted_direction = cardinality_one_identifiers_by_kind.get(schema.kind, {}).get(identifier)
            if wanted_direction is not rel_schema.direction:
                continue
            rel_peers = list(found_peer_ids)

        too_many_peers = len(rel_peers) > 1
        if rel_schema.cardinality is RelationshipCardinality.ONE and too_many_peers:
            raise ValueError("At most, one relationship expected")

        manager.has_fetched_relationships = True
        await manager.update(db=db, data=rel_peers)

@classmethod
async def delete(
Expand All @@ -1244,4 +1320,27 @@ async def delete(
return deleted_nodes


def _get_cardinality_one_identifiers_by_kind(
nodes: Iterable[Node],
fields: dict[str, Any],
) -> dict[str, dict[str, RelationshipDirection]]:
# {kind: {relationship_identifier, ...}}
cardinality_one_fields_by_kind = {}
field_names_set = set(fields.keys())
for node in nodes:
node_schema = node.get_schema()
if not node_schema:
continue
# already handled this schema
if node_schema.kind in cardinality_one_fields_by_kind:
continue
cardinality_one_rel_identifiers_in_fields = {
rel_schema.identifier: rel_schema.direction
for rel_schema in node_schema.relationships
if rel_schema.cardinality is RelationshipCardinality.ONE and rel_schema.name in field_names_set
}
cardinality_one_fields_by_kind[node_schema.kind] = cardinality_one_rel_identifiers_in_fields
return cardinality_one_fields_by_kind


registry.manager = NodeManager
27 changes: 24 additions & 3 deletions backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,10 +623,10 @@ async def delete(self, db: InfrahubDatabase, at: Optional[Timestamp] = None) ->
async def to_graphql(
self,
db: InfrahubDatabase,
fields: Optional[dict] = None,
related_node_ids: Optional[set] = None,
fields: dict | None = None,
related_node_ids: set | None = None,
filter_sensitive: bool = False,
permissions: Optional[dict] = None,
permissions: dict | None = None,
) -> dict:
"""Generate GraphQL Payload for all attributes
Expand Down Expand Up @@ -686,6 +686,27 @@ async def to_graphql(
db=db, filter_sensitive=filter_sensitive, permissions=permissions
)

for relationship_schema in self.get_schema().relationships:
peer_rels = []
if not fields or relationship_schema.name not in fields:
continue
rel_manager = getattr(self, relationship_schema.name, None)
if rel_manager is None:
continue
try:
if relationship_schema.cardinality is RelationshipCardinality.ONE:
rel = rel_manager.get_one()
if rel:
peer_rels = [rel]
else:
peer_rels = list(rel_manager)
if peer_rels:
response[relationship_schema.name] = [
{"node": {"id": relationship.peer_id}} for relationship in peer_rels if relationship.peer_id
]
except LookupError:
continue

return response

async def from_graphql(self, data: dict, db: InfrahubDatabase) -> bool:
Expand Down
14 changes: 9 additions & 5 deletions backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,31 +589,35 @@ class NodeListGetRelationshipsQuery(Query):
type: QueryType = QueryType.READ
insert_return: bool = False

def __init__(self, ids: list[str], **kwargs):
def __init__(self, ids: list[str], relationship_identifiers: list[str] | None = None, **kwargs):
self.ids = ids

self.relationship_identifiers = relationship_identifiers
super().__init__(**kwargs)

async def query_init(self, db: InfrahubDatabase, **kwargs) -> None:
self.params["ids"] = self.ids
self.params["relationship_identifiers"] = self.relationship_identifiers

rels_filter, rels_params = self.branch.get_query_filter_path(at=self.at, branch_agnostic=self.branch_agnostic)
self.params.update(rels_params)

query = """
MATCH (n:Node) WHERE n.uuid IN $ids
MATCH paths_in = ((n)<-[r1:IS_RELATED]-(rel:Relationship)<-[r2:IS_RELATED]-(peer))
WHERE all(r IN relationships(paths_in) WHERE (%(filters)s))
WHERE ($relationship_identifiers IS NULL OR rel.name in $relationship_identifiers)
AND all(r IN relationships(paths_in) WHERE (%(filters)s))
RETURN n, rel, peer, r1, r2, "inbound" as direction
UNION
MATCH (n:Node) WHERE n.uuid IN $ids
MATCH paths_out = ((n)-[r1:IS_RELATED]->(rel:Relationship)-[r2:IS_RELATED]->(peer))
WHERE all(r IN relationships(paths_out) WHERE (%(filters)s))
WHERE ($relationship_identifiers IS NULL OR rel.name in $relationship_identifiers)
AND all(r IN relationships(paths_out) WHERE (%(filters)s))
RETURN n, rel, peer, r1, r2, "outbound" as direction
UNION
MATCH (n:Node) WHERE n.uuid IN $ids
MATCH paths_bidir = ((n)-[r1:IS_RELATED]->(rel:Relationship)<-[r2:IS_RELATED]-(peer))
WHERE all(r IN relationships(paths_bidir) WHERE (%(filters)s))
WHERE ($relationship_identifiers IS NULL OR rel.name in $relationship_identifiers)
AND all(r IN relationships(paths_bidir) WHERE (%(filters)s))
RETURN n, rel, peer, r1, r2, "bidirectional" as direction
""" % {"filters": rels_filter}

Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/core/query/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def get_peer_ids(self) -> list[str]:
return [peer.peer_id for peer in self.get_peers()]

def get_peers(self) -> Generator[RelationshipPeerData, None, None]:
for result in self.get_results_group_by(("peer", "uuid")):
for result in self.get_results_group_by(("peer", "uuid"), ("source_node", "uuid")):
rels = result.get("rels")
data = RelationshipPeerData(
source_id=result.get_node("source_node").get("uuid"),
Expand Down
6 changes: 6 additions & 0 deletions backend/infrahub/core/relationship/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,12 @@ def __iter__(self) -> Iterator[Relationship]:

return iter(self._relationships)

def get_one(self) -> Relationship | None:
    """Return the first cached relationship, or None when the cache holds none.

    Raises:
        LookupError: If the relationships have not been fetched yet.
    """
    if self.has_fetched_relationships:
        if self._relationships:
            return self._relationships[0]
        return None

    raise LookupError("you can't get a relationship before the cache has been populated.")

def __len__(self) -> int:
if not self.has_fetched_relationships:
raise LookupError("you can't count relationships before the cache has been populated.")
Expand Down
3 changes: 3 additions & 0 deletions backend/infrahub/graphql/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from infrahub.core import registry
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import InitializationError
from infrahub.graphql.resolvers.single_relationship import SingleRelationshipResolver

from .manager import GraphQLSchemaManager

Expand All @@ -32,6 +33,7 @@ class GraphqlContext:
db: InfrahubDatabase
branch: Branch
types: dict
single_relationship_resolver: SingleRelationshipResolver
at: Optional[Timestamp] = None
related_node_ids: Optional[set] = None
service: Optional[InfrahubServices] = None
Expand Down Expand Up @@ -86,6 +88,7 @@ def prepare_graphql_params(
context=GraphqlContext(
db=db,
branch=branch,
single_relationship_resolver=SingleRelationshipResolver(),
at=Timestamp(at),
types=gqlm.get_graphql_types(),
related_node_ids=set(),
Expand Down
Empty file.
Loading

0 comments on commit 9fffe68

Please sign in to comment.