From c9251db47cc00822149752f613ca0f825e19db31 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 6 Dec 2022 05:43:16 +0100 Subject: [PATCH 1/3] =?UTF-8?q?=F0=9F=94=A7=20Type=20annotate=20aiida/tool?= =?UTF-8?q?s/visualization/graph.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 1 - aiida/tools/visualization/graph.py | 499 +++++++++++++---------------- docs/source/nitpick-exceptions | 4 + 3 files changed, 231 insertions(+), 273 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 17290cd115..1cfb55e233 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -186,7 +186,6 @@ repos: aiida/tools/groups/paths.py| aiida/tools/query/calculation.py| aiida/tools/query/mapping.py| - aiida/tools/visualization/graph.py| aiida/transports/cli.py| aiida/transports/plugins/local.py| aiida/transports/plugins/ssh.py| diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index 785c387094..7cec0a1f67 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -10,9 +10,10 @@ """ provides functionality to create graphs of the AiiDa data providence, *via* graphviz. """ +from __future__ import annotations + import os -from types import MappingProxyType # pylint: disable=no-name-in-module,useless-suppression -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Protocol, Sequence from graphviz import Digraph @@ -27,18 +28,22 @@ __all__ = ('Graph', 'default_link_styles', 'default_node_styles', 'pstate_node_styles', 'default_node_sublabels') +LinkAnnotateType = Literal[None, False, 'label', 'type', 'both'] + + +class LinkStyleFunc(Protocol): + """Protocol for a link style function""" -def default_link_styles(link_pair, add_label, add_type): - # type: (LinkPair, bool, bool) -> dict - """map link_pair to a graphviz edge style + def __call__(self, link_pair: LinkPair, add_label: bool, add_type: bool) -> dict: + ... + + +def default_link_styles(link_pair: LinkPair, add_label: bool, add_type: bool) -> dict: + """Map link_pair to a graphviz edge style :param link_type: a LinkPair attribute - :type link_type: aiida.orm.utils.links.LinkPair :param add_label: include link label - :type add_label: bool :param add_type: include link type - :type add_type: bool - :rtype: dict """ style = { @@ -78,12 +83,10 @@ def default_link_styles(link_pair, add_label, add_type): return style -def default_node_styles(node): - """map a node to a graphviz node style +def default_node_styles(node: orm.Node) -> dict: + """Map a node to a graphviz node style :param node: the node to map - :type node: aiida.orm.nodes.node.Node - :rtype: dict """ class_node_type = node.class_node_type @@ -136,12 +139,10 @@ def default_node_styles(node): return node_style -def pstate_node_styles(node): - """map a process node to a graphviz node style +def pstate_node_styles(node: orm.Node) -> dict: + """Map a process node to a graphviz node style :param node: the node to map - :type node: aiida.orm.nodes.node.Node - :rtype: dict """ class_node_type = node.class_node_type @@ -205,13 +206,11 @@ def pstate_node_styles(node): } -def default_node_sublabels(node): - """function mapping nodes to a sublabel +def default_node_sublabels(node: orm.Node) -> str: + """Function mapping nodes to a sub-label (e.g. specifying some attribute values) :param node: the node to map - :type node: aiida.orm.nodes.node.Node - :rtype: str """ # pylint: disable=too-many-branches @@ -225,11 +224,12 @@ def default_node_sublabels(node): elif class_node_type == 'data.core.bool.Bool.': sublabel = f"{node.base.attributes.get('value', '')}" elif class_node_type == 'data.core.code.Code.': - sublabel = f'{os.path.basename(node.get_execname())}@{node.computer.label}' + label = '?' if node.computer is None else node.computer.label + sublabel = f'{os.path.basename(node.get_execname())}@{label}' elif class_node_type == 'data.core.singlefile.SinglefileData.': sublabel = node.filename elif class_node_type == 'data.core.remote.RemoteData.': - sublabel = f'@{node.computer.label}' + sublabel = f'@{node.computer.label}' if node.computer is not None else '@?' elif class_node_type == 'data.core.structure.StructureData.': sublabel = node.get_formula() elif class_node_type == 'data.core.cif.CifData.': @@ -244,22 +244,22 @@ def default_node_sublabels(node): elif class_node_type == 'data.core.upf.UpfData.': sublabel = f"{node.base.attributes.get('element', '')}" elif isinstance(node, orm.ProcessNode): - sublabel = [] + sublabel_list = [] if node.process_state is not None: - sublabel.append(f'State: {node.process_state.value}') + sublabel_list.append(f'State: {node.process_state.value}') if node.exit_status is not None: - sublabel.append(f'Exit Code: {node.exit_status}') - sublabel = '\n'.join(sublabel) + sublabel_list.append(f'Exit Code: {node.exit_status}') + sublabel = '\n'.join(sublabel_list) else: sublabel = node.get_description() return sublabel -def get_node_id_label(node, id_type): - """return an identifier str for the node """ +def get_node_id_label(node: orm.Node, id_type: Literal['pk', 'uuid', 'label']) -> str: + """Return an identifier str for the node """ if id_type == 'pk': - return node.pk + return str(node.pk) if id_type == 'uuid': return node.uuid.split('-')[0] if id_type == 'label': @@ -267,8 +267,8 @@ def get_node_id_label(node, id_type): raise ValueError(f'node_id_type not recognised: {id_type}') -def _get_node_label(node, id_type='pk'): - """return a label text of node and the return format is ' ()'.""" +def _get_node_label(node: orm.Node, id_type: Literal['pk', 'uuid', 'label'] = 'pk') -> str: + """Return a label text of node and the return format is ' ()'.""" if isinstance(node, orm.Data): label = f'{node.__class__.__name__} ({get_node_id_label(node, id_type)})' elif isinstance(node, orm.ProcessNode): @@ -283,9 +283,15 @@ def _get_node_label(node, id_type='pk'): def _add_graphviz_node( - graph, node, node_style_func, node_sublabel_func, style_override=None, include_sublabels=True, id_type='pk' + graph: Digraph, + node: orm.Node, + node_style_func, + node_sublabel_func, + style_override: None | dict = None, + include_sublabels: bool = True, + id_type: Literal['pk', 'uuid', 'label'] = 'pk' ): - """create a node in the graph + """Create a node in the graph The first line of the node text is always ' ()'. Then, if ``include_sublabels=True``, subsequent lines are added, @@ -293,17 +299,13 @@ def _add_graphviz_node( :param graph: the graphviz.Digraph to add the node to :param node: the node to add - :type node: aiida.orm.nodes.node.Node :param node_style_func: callable mapping a node instance to a dictionary defining the graphviz node style :param node_sublabel_func: callable mapping a node instance to a sub-label for the node text :param style_override: style dictionary, whose keys will override the final computed style - :type style_override: None or dict :param include_sublabels: whether to include the sublabels for nodes - :type include_sublabels: bool - :param id_type: the type of identifier to use for node labels ('pk' or 'uuid') - :type id_type: str + :param id_type: the type of identifier to use for node labels - nodes are styled based on the node type + Nodes are styled based on the node type For subclasses of Data, the ``class_node_type`` attribute is used for mapping to type specific styles @@ -331,14 +333,13 @@ def _add_graphviz_node( return graph.node(f'N{node.pk}', **node_style) -def _add_graphviz_edge(graph, in_node, out_node, style=None): - """add graphviz edge between two nodes +def _add_graphviz_edge(graph: Digraph, in_node: orm.Node, out_node: orm.Node, style: dict | None = None) -> dict: + """Add graphviz edge between two nodes :param graph: the graphviz.DiGraph to add the edge to :param in_node: the head node :param out_node: the tail node - :param style: the graphviz style (Default value = None) - :type style: dict or None + :param style: the graphviz style """ if style is None: @@ -351,51 +352,45 @@ def _add_graphviz_edge(graph, in_node, out_node, style=None): class Graph: - """a class to create graphviz graphs of the AiiDA node provenance""" + """A class to create graphviz graphs of the AiiDA node provenance.""" def __init__( self, - engine=None, - graph_attr=None, - global_node_style=None, - global_edge_style=None, - include_sublabels=True, - link_style_fn=None, - node_style_fn=None, - node_sublabel_fn=None, - node_id_type='pk', - backend: Optional['StorageBackend'] = None + engine: str | None = None, + graph_attr: dict | None = None, + global_node_style: dict | None = None, + global_edge_style: dict | None = None, + include_sublabels: bool = True, + link_style_fn: LinkStyleFunc | None = None, + node_style_fn: Callable[[orm.Node], dict] | None = None, + node_sublabel_fn: Callable[[orm.Node], str] | None = None, + node_id_type: Literal['pk', 'uuid', 'label'] = 'pk', + backend: StorageBackend | None = None ): - """a class to create graphviz graphs of the AiiDA node provenance + """A class to create graphviz graphs of the AiiDA node provenance Nodes and edges, are cached, so that they are only created once - :param engine: the graphviz engine, e.g. dot, circo (Default value = None) - :type engine: str or None - :param graph_attr: attributes for the graphviz graph (Default value = None) - :type graph_attr: dict or None + :param engine: the graphviz engine, e.g. dot, circo + :param graph_attr: attributes for the graphviz graph :param global_node_style: styles which will be added to all nodes. - Note this will override any builtin attributes (Default value = None) - :type global_node_style: dict or None + Note this will override any builtin attributes :param global_edge_style: styles which will be added to all edges. - Note this will override any builtin attributes (Default value = None) - :type global_edge_style: dict or None - :param include_sublabels: if True, the note text will include node dependant sub-labels (Default value = True) - :type include_sublabels: bool + Note this will override any builtin attributes + :param include_sublabels: if True, the note text will include node dependant sub-labels :param link_style_fn: callable mapping LinkType to graphviz style dict; - link_style_fn(link_type) -> dict (Default value = None) + link_style_fn(link_type, add_label, add_type) -> dict :param node_sublabel_fn: callable mapping nodes to a graphviz style dict; - node_sublabel_fn(node) -> dict (Default value = None) + node_sublabel_fn(node) -> dict :param node_sublabel_fn: callable mapping data node to a sublabel (e.g. specifying some attribute values) - node_sublabel_fn(node) -> str (Default value = None) - :param node_id_type: the type of identifier to within the node text ('pk', 'uuid' or 'label') - :type node_id_type: str + node_sublabel_fn(node) -> str + :param node_id_type: the type of identifier to within the node text """ # pylint: disable=too-many-arguments self._graph = Digraph(engine=engine, graph_attr=graph_attr) - self._nodes = set() - self._edges = set() + self._nodes: set[int] = set() + self._edges: set[tuple[int, int, None | LinkPair]] = set() self._global_node_style = global_node_style or {} self._global_edge_style = global_edge_style or {} self._include_sublabels = include_sublabels @@ -405,35 +400,33 @@ def __init__( self._node_id_type = node_id_type self._backend = backend or get_manager().get_profile_storage() - self._ignore_node_style = _OVERRIDE_STYLES_DICT['ignore_node'] - self._origin_node_style = _OVERRIDE_STYLES_DICT['origin_node'] + self._ignore_node_style = _OVERRIDE_STYLES_DICT['ignore_node'].copy() + self._origin_node_style = _OVERRIDE_STYLES_DICT['origin_node'].copy() @property - def backend(self) -> 'StorageBackend': + def backend(self) -> StorageBackend: """The backend used to create the graph""" return self._backend @property - def graphviz(self): - """return a copy of the graphviz.Digraph""" + def graphviz(self) -> Digraph: + """Return a copy of the graphviz.Digraph""" return self._graph.copy() @property - def nodes(self): - """return a copy of the nodes""" + def nodes(self) -> set[int]: + """Return a copy of the nodes""" return self._nodes.copy() @property - def edges(self): - """return a copy of the edges""" + def edges(self) -> set[tuple[int, int, None | LinkPair]]: + """Return a copy of the edges""" return self._edges.copy() - def _load_node(self, node): - """ load a node (if not already loaded) + def _load_node(self, node: int | str | orm.Node) -> orm.Node: + """Load a node (if not already loaded) :param node: node or node pk/uuid - :type node: int or str or aiida.orm.nodes.node.Node - :returns: aiida.orm.nodes.node.Node """ if isinstance(node, int): return orm.Node.collection(self._backend).get(pk=node) @@ -441,33 +434,14 @@ def _load_node(self, node): return orm.Node.collection(self._backend).get(uuid=node) return node - @staticmethod - def _default_link_types(link_types): - """If link_types is empty, it will return all the links_types - - :param links: iterable with the link_types () - :returns: list of :py:class:`aiida.common.links.LinkType` - """ - if not link_types: - all_link_types = [LinkType.CREATE] - all_link_types.append(LinkType.RETURN) - all_link_types.append(LinkType.INPUT_CALC) - all_link_types.append(LinkType.INPUT_WORK) - all_link_types.append(LinkType.CALL_CALC) - all_link_types.append(LinkType.CALL_WORK) - return all_link_types - - return link_types - - def add_node(self, node, style_override=None, overwrite=False): - """add single node to the graph + def add_node( + self, node: int | str | orm.Node, style_override: dict | None = None, overwrite: bool = False + ) -> orm.Node: + """Add single node to the graph :param node: node or node pk/uuid - :type node: int or str or aiida.orm.nodes.node.Node :param style_override: graphviz style parameters that will override default values - :type style_override: dict or None - :param overwrite: whether to overrite an existing node (Default value = False) - :type overwrite: bool + :param overwrite: whether to overwrite an existing node """ node = self._load_node(node) style = {} if style_override is None else dict(style_override) @@ -485,19 +459,21 @@ def add_node(self, node, style_override=None, overwrite=False): self._nodes.add(node.pk) return node - def add_edge(self, in_node, out_node, link_pair=None, style=None, overwrite=False): - """add single node to the graph + def add_edge( + self, + in_node: int | str | orm.Node, + out_node: int | str | orm.Node, + link_pair: LinkPair | None = None, + style: dict | None = None, + overwrite: bool = False + ) -> None: + """Add single node to the graph :param in_node: node or node pk/uuid - :type in_node: int or aiida.orm.nodes.node.Node :param out_node: node or node pk/uuid - :type out_node: int or str or aiida.orm.nodes.node.Node :param link_pair: defining the relationship between the nodes - :type link_pair: None or aiida.orm.utils.links.LinkPair - :param style: graphviz style parameters (Default value = None) - :type style: dict or None - :param overwrite: whether to overrite existing edge (Default value = False) - :type overwrite: bool + :param style: graphviz style parameters + :param overwrite: whether to overwrite existing edge """ in_node = self._load_node(in_node) if in_node.pk not in self._nodes: @@ -516,26 +492,35 @@ def add_edge(self, in_node, out_node, link_pair=None, style=None, overwrite=Fals _add_graphviz_edge(self._graph, in_node, out_node, style) @staticmethod - def _convert_link_types(link_types): - """convert link types, which may be strings, to a member of LinkType""" - if link_types is None: - return None - if isinstance(link_types, str): - link_types = [link_types] - link_types = tuple(getattr(LinkType, l.upper()) if isinstance(l, str) else l for l in link_types) - return link_types - - def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True): - """add nodes and edges for incoming links to a node + def _convert_link_types( + link_types: None | str | LinkType | Sequence[str] | Sequence[LinkType] + ) -> tuple[LinkType, ...]: + """Convert link types, which may be strings, to a member of LinkType""" + link_types_list: Sequence[LinkType] | Sequence[str] + if not link_types: + link_types_list = [ + LinkType.CREATE, LinkType.RETURN, LinkType.INPUT_CALC, LinkType.INPUT_WORK, LinkType.CALL_CALC, + LinkType.CALL_WORK + ] + elif isinstance(link_types, (str, LinkType)): + link_types_list = [link_types] # type: ignore + else: + link_types_list = link_types + return tuple(getattr(LinkType, l.upper()) if isinstance(l, str) else l for l in link_types_list) + + def add_incoming( + self, + node: int | str | orm.Node, + link_types: None | str | Sequence[str] | LinkType | Sequence[LinkType] = None, + annotate_links: LinkAnnotateType = None, + return_pks: bool = True + ) -> list[int] | list[orm.Node]: + """Add nodes and edges for incoming links to a node :param node: node or node pk/uuid - :type node: aiida.orm.nodes.node.Node or int - :param link_types: filter by link types (Default value = ()) - :type link_types: str or tuple[str] or aiida.common.links.LinkType or tuple[aiida.common.links.LinkType] - :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = None) - :type annotate_links: bool or str - :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True) - :type return_pks: bool + :param link_types: filter by link types + :param annotate_links: label edges with the link 'label', 'type' or 'both' + :param return_pks: whether to return a list of nodes, or list of node pks :returns: list of nodes or node pks """ if annotate_links not in [None, False, 'label', 'type', 'both']: @@ -545,8 +530,7 @@ def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True # incoming nodes are found traversing backwards node_pk = self._load_node(node).pk - valid_link_types = self._default_link_types(link_types) - valid_link_types = self._convert_link_types(valid_link_types) + valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (node_pk,), max_iterations=1, @@ -555,7 +539,7 @@ def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True links_backward=valid_link_types, ) - traversed_nodes = orm.QueryBuilder(backend=self.backend).append( + query = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -563,12 +547,12 @@ def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True project=['id', '*'], tag='node', ) - traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} + traversed_nodes = {query_result[0]: query_result[1] for query_result in query.all()} for _, traversed_node in traversed_nodes.items(): self.add_node(traversed_node, style_override=None) - for link in traversed_graph['links']: + for link in (traversed_graph['links'] or []): source_node = traversed_nodes[link.source_id] target_node = traversed_nodes[link.target_id] link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) @@ -582,17 +566,19 @@ def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True # else: return list(traversed_nodes.values()) - def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True): - """add nodes and edges for outgoing links to a node + def add_outgoing( + self, + node: int | str | orm.Node, + link_types: None | str | Sequence[str] | LinkType | Sequence[LinkType] = None, + annotate_links: LinkAnnotateType = None, + return_pks: bool = True + ) -> list[int] | list[orm.Node]: + """Add nodes and edges for outgoing links to a node :param node: node or node pk - :type node: aiida.orm.nodes.node.Node or int - :param link_types: filter by link types (Default value = ()) - :type link_types: str or tuple[str] or aiida.common.links.LinkType or tuple[aiida.common.links.LinkType] - :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = None) - :type annotate_links: bool or str - :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True) - :type return_pks: bool + :param link_types: filter by link types + :param annotate_links: label edges with the link 'label', 'type' or 'both' + :param return_pks: whether to return a list of nodes, or list of node pks :returns: list of nodes or node pks """ if annotate_links not in [None, False, 'label', 'type', 'both']: @@ -602,8 +588,7 @@ def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True # outgoing nodes are found traversing forwards node_pk = self._load_node(node).pk - valid_link_types = self._default_link_types(link_types) - valid_link_types = self._convert_link_types(valid_link_types) + valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (node_pk,), max_iterations=1, @@ -612,7 +597,7 @@ def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True links_forward=valid_link_types, ) - traversed_nodes = orm.QueryBuilder(backend=self.backend).append( + query = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -620,12 +605,12 @@ def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True project=['id', '*'], tag='node', ) - traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} + traversed_nodes = {query_result[0]: query_result[1] for query_result in query.all()} for _, traversed_node in traversed_nodes.items(): self.add_node(traversed_node, style_override=None) - for link in traversed_graph['links']: + for link in (traversed_graph['links'] or []): source_node = traversed_nodes[link.source_id] target_node = traversed_nodes[link.target_id] link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) @@ -641,39 +626,31 @@ def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True def recurse_descendants( self, - origin, - depth=None, - link_types=(), - annotate_links=False, - origin_style=MappingProxyType(_OVERRIDE_STYLES_DICT['origin_node']), - include_process_inputs=False, - highlight_classes=None, - ): - """add nodes and edges from an origin recursively, + origin: int | str | orm.Node, + depth: int | None = None, + link_types: None | str | Sequence[str] | LinkType | Sequence[LinkType] = None, + annotate_links: LinkAnnotateType = False, + origin_style: dict | None = None, + include_process_inputs: bool = False, + highlight_classes: None | Sequence[str] = None, + ) -> None: + """Add nodes and edges from an origin recursively, following outgoing links :param origin: node or node pk/uuid - :type origin: aiida.orm.nodes.node.Node or int - :param depth: if not None, stop after travelling a certain depth into the graph (Default value = None) - :type depth: None or int - :param link_types: filter by subset of link types (Default value = ()) - :type link_types: tuple or str - :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) - :type annotate_links: bool or str - :param origin_style: node style map for origin node (Default value = None) - :type origin_style: None or dict - :param include_calculation_inputs: include incoming links for all processes (Default value = False) - :type include_calculation_inputs: bool + :param depth: if not None, stop after travelling a certain depth into the graph + :param link_types: filter by subset of link types + :param annotate_links: label edges with the link 'label', 'type' or 'both' + :param origin_style: node style map for origin node + :param include_calculation_inputs: include incoming links for all processes :param highlight_classes: target class in exported graph expected to be highlight and - other nodes are decolorized (Default value = None) - :typle highlight_classes: tuple of class or class + other nodes are decolorized """ # pylint: disable=too-many-arguments,too-many-locals # Get graph traversal rules where the given link types and direction are all set to True, # and all others are set to False origin_pk = self._load_node(origin).pk - valid_link_types = self._default_link_types(link_types) - valid_link_types = self._convert_link_types(valid_link_types) + valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (origin_pk,), max_iterations=depth, @@ -693,10 +670,10 @@ def recurse_descendants( links_backward=[LinkType.INPUT_WORK, LinkType.INPUT_CALC] ) traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes']) - traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) + traversed_graph['links'] = (traversed_graph['links'] or set()).union(traversed_outputs['links'] or set()) # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary - traversed_nodes = orm.QueryBuilder(backend=self.backend).append( + query = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -704,11 +681,11 @@ def recurse_descendants( project=['id', '*'], tag='node', ) - traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} + traversed_nodes = {query_result[0]: query_result[1] for query_result in query.all()} # Pop the origin node and add it to the graph, applying custom styling origin_node = traversed_nodes.pop(origin_pk) - self.add_node(origin_node, style_override=origin_style) + self.add_node(origin_node, style_override=origin_style or _OVERRIDE_STYLES_DICT['origin_node'].copy()) # Add all traversed nodes to the graph with default styling for _, traversed_node in traversed_nodes.items(): @@ -723,7 +700,7 @@ def recurse_descendants( # Add all links to the Graph, using the {id: Node} dictionary for queryless Node retrieval, applying # appropriate styling - for link in traversed_graph['links']: + for link in (traversed_graph['links'] or []): source_node = traversed_nodes[link.source_id] target_node = traversed_nodes[link.target_id] link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) @@ -734,39 +711,31 @@ def recurse_descendants( def recurse_ancestors( self, - origin, - depth=None, - link_types=(), - annotate_links=False, - origin_style=MappingProxyType(_OVERRIDE_STYLES_DICT['origin_node']), - include_process_outputs=False, - highlight_classes=None, - ): - """add nodes and edges from an origin recursively, + origin: int | str | orm.Node, + depth: int | None = None, + link_types: None | str | Sequence[str] | LinkType | Sequence[LinkType] = None, + annotate_links: LinkAnnotateType = False, + origin_style: dict | None = None, + include_process_outputs: bool = False, + highlight_classes: None | Sequence[str] = None, + ) -> None: + """Add nodes and edges from an origin recursively, following incoming links :param origin: node or node pk/uuid - :type origin: aiida.orm.nodes.node.Node or int - :param depth: if not None, stop after travelling a certain depth into the graph (Default value = None) - :type depth: None or int - :param link_types: filter by subset of link types (Default value = ()) - :type link_types: tuple or str - :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) - :type annotate_links: bool - :param origin_style: node style map for origin node (Default value = None) - :type origin_style: None or dict - :param include_process_outputs: include outgoing links for all processes (Default value = False) - :type include_process_outputs: bool + :param depth: if not None, stop after travelling a certain depth into the graph + :param link_types: filter by subset of link types + :param annotate_links: label edges with the link 'label', 'type' or 'both' + :param origin_style: node style map for origin node + :param include_process_outputs: include outgoing links for all processes :param highlight_classes: class label (as displayed in the graph, e.g. 'StructureData', 'FolderData', etc.) - to be highlight and other nodes are decolorized (Default value = None) - :typle highlight_classes: list or tuple of str + to be highlight and other nodes are decolorized """ # pylint: disable=too-many-arguments,too-many-locals # Get graph traversal rules where the given link types and direction are all set to True, # and all others are set to False origin_pk = self._load_node(origin).pk - valid_link_types = self._default_link_types(link_types) - valid_link_types = self._convert_link_types(valid_link_types) + valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (origin_pk,), max_iterations=depth, @@ -786,10 +755,10 @@ def recurse_ancestors( links_forward=[LinkType.CREATE, LinkType.RETURN] ) traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes']) - traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) + traversed_graph['links'] = (traversed_graph['links'] or set()).union(traversed_outputs['links'] or set()) # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary - traversed_nodes = orm.QueryBuilder(backend=self.backend).append( + query = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -797,11 +766,11 @@ def recurse_ancestors( project=['id', '*'], tag='node', ) - traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} + traversed_nodes = {query_result[0]: query_result[1] for query_result in query.all()} # Pop the origin node and add it to the graph, applying custom styling origin_node = traversed_nodes.pop(origin_pk) - self.add_node(origin_node, style_override=origin_style) + self.add_node(origin_node, style_override=(origin_style or _OVERRIDE_STYLES_DICT['origin_node'].copy())) # Add all traversed nodes to the graph with default styling for _, traversed_node in traversed_nodes.items(): @@ -816,7 +785,7 @@ def recurse_ancestors( # Add all links to the Graph, using the {id: Node} dictionary for queryless Node retrieval, applying # appropriate styling - for link in traversed_graph['links']: + for link in (traversed_graph['links'] or []): source_node = traversed_nodes[link.source_id] target_node = traversed_nodes[link.target_id] link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) @@ -827,29 +796,23 @@ def recurse_ancestors( def add_origin_to_targets( self, - origin, - target_cls, - target_filters=None, - include_target_inputs=False, - include_target_outputs=False, - origin_style=(), - annotate_links=False - ): + origin: int | str | orm.Node, + target_cls: type[orm.Node], + target_filters: dict | None = None, + include_target_inputs: bool = False, + include_target_outputs: bool = False, + origin_style: Mapping[str, Any] | None = None, + annotate_links: LinkAnnotateType = False + ) -> None: """Add nodes and edges from an origin node to all nodes of a target node class. :param origin: node or node pk/uuid - :type origin: aiida.orm.nodes.node.Node or int :param target_cls: target node class - :param target_filters: (Default value = None) - :type target_filters: dict or None - :param include_target_inputs: (Default value = False) - :type include_target_inputs: bool - :param include_target_outputs: (Default value = False) - :type include_target_outputs: bool - :param origin_style: node style map for origin node (Default value = ()) - :type origin_style: dict or tuple - :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) - :type annotate_links: bool + :param target_filters: filters for query of target nodes + :param include_target_inputs: Include incoming links for all target nodes + :param include_target_outputs: Include outgoing links for all target nodes + :param origin_style: node style map for origin node + :param annotate_links: label edges with the link 'label', 'type' or 'both' """ # pylint: disable=too-many-arguments origin_node = self._load_node(origin) @@ -857,25 +820,23 @@ def add_origin_to_targets( if target_filters is None: target_filters = {} - self.add_node(origin_node, style_override=dict(origin_style)) + self.add_node(origin_node, style_override=dict(origin_style or {})) query = orm.QueryBuilder( backend=self.backend, - **{ - 'path': [{ - 'cls': origin_node.__class__, - 'filters': { - 'id': origin_node.pk - }, - 'tag': 'origin' - }, { - 'cls': target_cls, - 'filters': target_filters, - 'with_ancestors': 'origin', - 'tag': 'target', - 'project': '*' - }] - } + path=[{ + 'cls': origin_node.__class__, + 'filters': { + 'id': origin_node.pk + }, + 'tag': 'origin' + }, { + 'cls': target_cls, + 'filters': target_filters, + 'with_ancestors': 'origin', + 'tag': 'target', + 'project': '*' + }] ) for (target_node,) in query.iterall(): @@ -890,31 +851,25 @@ def add_origin_to_targets( def add_origins_to_targets( self, - origin_cls, - target_cls, - origin_filters=None, - target_filters=None, - include_target_inputs=False, - include_target_outputs=False, - origin_style=(), - annotate_links=False - ): + origin_cls: type[orm.Node], + target_cls: type[orm.Node], + origin_filters: dict | None = None, + target_filters: dict | None = None, + include_target_inputs: bool = False, + include_target_outputs: bool = False, + origin_style: Mapping[str, Any] | None = None, + annotate_links: LinkAnnotateType = False + ) -> None: """Add nodes and edges from all nodes of an origin class to all node of a target node class. :param origin_cls: origin node class :param target_cls: target node class - :param origin_filters: (Default value = None) - :type origin_filters: dict or None - :param target_filters: (Default value = None) - :type target_filters: dict or None - :param include_target_inputs: (Default value = False) - :type include_target_inputs: bool - :param include_target_outputs: (Default value = False) - :type include_target_outputs: bool - :param origin_style: node style map for origin node (Default value = ()) - :type origin_style: dict or tuple - :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) - :type annotate_links: bool + :param origin_filters: filters for origin nodes + :param target_filters: filters for target nodes + :param include_target_inputs: Include incoming links for all target nodes + :param include_target_outputs: Include outgoing links for all target nodes + :param origin_style: node style map for origin node + :param annotate_links: label edges with the link 'label', 'type' or 'both' """ # pylint: disable=too-many-arguments if origin_filters is None: @@ -922,12 +877,12 @@ def add_origins_to_targets( query = orm.QueryBuilder( backend=self.backend, - **{'path': [{ + path=[{ 'cls': origin_cls, 'filters': origin_filters, 'tag': 'origin', 'project': '*' - }]} + }] ) for (node,) in query.iterall(): diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index d665f3238a..b3688a1b9a 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -39,6 +39,7 @@ py:class Logger py:class ModuleType py:class ExternalType py:class UserDefinedType +py:class LinkAnnotateType py:obj ReturnType py:obj SelfType py:obj CollectionType @@ -206,3 +207,6 @@ py:class sqlalchemy.sql.elements.ColumnElement py:class packaging.version.Version py:exc seekpath.hpkot.EdgeCaseWarning + +py:class graphviz.graphs.Digraph +py:class Digraph From 16f3273712139763dcf2476a3c574bfa502261f1 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 13 Dec 2022 14:59:46 +0100 Subject: [PATCH 2/3] remove False from LinkAnnotateType --- aiida/tools/visualization/graph.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index 7cec0a1f67..14f0696927 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -28,7 +28,7 @@ __all__ = ('Graph', 'default_link_styles', 'default_node_styles', 'pstate_node_styles', 'default_node_sublabels') -LinkAnnotateType = Literal[None, False, 'label', 'type', 'both'] +LinkAnnotateType = Literal[None, 'label', 'type', 'both'] class LinkStyleFunc(Protocol): @@ -629,7 +629,7 @@ def recurse_descendants( origin: int | str | orm.Node, depth: int | None = None, link_types: None | str | Sequence[str] | LinkType | Sequence[LinkType] = None, - annotate_links: LinkAnnotateType = False, + annotate_links: LinkAnnotateType = None, origin_style: dict | None = None, include_process_inputs: bool = False, highlight_classes: None | Sequence[str] = None, @@ -714,7 +714,7 @@ def recurse_ancestors( origin: int | str | orm.Node, depth: int | None = None, link_types: None | str | Sequence[str] | LinkType | Sequence[LinkType] = None, - annotate_links: LinkAnnotateType = False, + annotate_links: LinkAnnotateType = None, origin_style: dict | None = None, include_process_outputs: bool = False, highlight_classes: None | Sequence[str] = None, @@ -802,7 +802,7 @@ def add_origin_to_targets( include_target_inputs: bool = False, include_target_outputs: bool = False, origin_style: Mapping[str, Any] | None = None, - annotate_links: LinkAnnotateType = False + annotate_links: LinkAnnotateType = None ) -> None: """Add nodes and edges from an origin node to all nodes of a target node class. @@ -858,7 +858,7 @@ def add_origins_to_targets( include_target_inputs: bool = False, include_target_outputs: bool = False, origin_style: Mapping[str, Any] | None = None, - annotate_links: LinkAnnotateType = False + annotate_links: LinkAnnotateType = None ) -> None: """Add nodes and edges from all nodes of an origin class to all node of a target node class. From 032e274a27a66989c22e61948c8199db85985eeb Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 13 Dec 2022 15:06:41 +0100 Subject: [PATCH 3/3] _OVERRIDE_STYLES_DICT -> function --- aiida/tools/visualization/graph.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index 14f0696927..cecd10ee5e 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -193,17 +193,21 @@ def pstate_node_styles(node: orm.Node) -> dict: return node_style -_OVERRIDE_STYLES_DICT = { - 'ignore_node': { +def _default_ignore_node_styles() -> dict: + """Return the default style for ignored nodes.""" + return { 'color': 'lightgray', 'fillcolor': 'white', 'penwidth': 2, - }, - 'origin_node': { + } + + +def _default_origin_node_styles() -> dict: + """Return the default style for origin nodes.""" + return { 'color': 'red', 'penwidth': 6, - }, -} + } def default_node_sublabels(node: orm.Node) -> str: @@ -400,8 +404,8 @@ def __init__( self._node_id_type = node_id_type self._backend = backend or get_manager().get_profile_storage() - self._ignore_node_style = _OVERRIDE_STYLES_DICT['ignore_node'].copy() - self._origin_node_style = _OVERRIDE_STYLES_DICT['origin_node'].copy() + self._ignore_node_style = _default_ignore_node_styles() + self._origin_node_style = _default_origin_node_styles() @property def backend(self) -> StorageBackend: @@ -685,7 +689,7 @@ def recurse_descendants( # Pop the origin node and add it to the graph, applying custom styling origin_node = traversed_nodes.pop(origin_pk) - self.add_node(origin_node, style_override=origin_style or _OVERRIDE_STYLES_DICT['origin_node'].copy()) + self.add_node(origin_node, style_override=origin_style or _default_origin_node_styles()) # Add all traversed nodes to the graph with default styling for _, traversed_node in traversed_nodes.items(): @@ -770,7 +774,7 @@ def recurse_ancestors( # Pop the origin node and add it to the graph, applying custom styling origin_node = traversed_nodes.pop(origin_pk) - self.add_node(origin_node, style_override=(origin_style or _OVERRIDE_STYLES_DICT['origin_node'].copy())) + self.add_node(origin_node, style_override=(origin_style or _default_origin_node_styles())) # Add all traversed nodes to the graph with default styling for _, traversed_node in traversed_nodes.items():