Skip to content

Commit

Permalink
Add traverse_graph / AGE as engine for export
Browse files Browse the repository at this point in the history
The export function now uses the traverse_graph function (with AGE as
the main engine) to collect the extra nodes that are needed to keep a
consistent provenance. This is performed, more specifically, by the
'retrieve_linked_nodes' function. Whereas previously a separate query
was performed for each new node added in the preceding step, this
new implementation performs a single query for all the nodes that
were added in the preceding step. These changes are therefore not only
an important first step toward homogenizing graph traversal throughout
the whole code base: an improvement in the export procedure is expected as well.
  • Loading branch information
lekah authored and ramirezfranciscof committed Dec 18, 2019
1 parent 2de10b3 commit 93afc88
Showing 1 changed file with 39 additions and 172 deletions.
211 changes: 39 additions & 172 deletions aiida/tools/importexport/dbexport/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,11 @@ def retrieve_linked_nodes(process_nodes, data_nodes, **kwargs): # pylint: disab
:raises `~aiida.tools.importexport.common.exceptions.ExportValidationError`: if wrong or too many kwargs are given.
"""
from aiida.common.links import LinkType, GraphTraversalRules
from aiida.orm import Data
from aiida.common.links import GraphTraversalRules
from aiida.orm import Node
from aiida.tools.graph.graph_traversers import traverse_graph

# Initialization and set flags according to rules
retrieved_nodes = set()
links_uuid_dict = {}
traversal_rules = {}

    # Create the dictionary with graph traversal rules to be used in determining complete node set to be exported
Expand All @@ -403,171 +402,39 @@ def retrieve_linked_nodes(process_nodes, data_nodes, **kwargs): # pylint: disab
# Use the rule value passed in the keyword arguments, or if not the case, use the default
traversal_rules[name] = kwargs.pop(name, rule.default)

# We repeat until there are no further nodes to be visited
while process_nodes or data_nodes:

        # If it is a ProcessNode
if process_nodes:
current_node_pk = process_nodes.pop()
# If it is already visited continue to the next node
if current_node_pk in retrieved_nodes:
continue
# Otherwise say that it is a node to be exported
else:
retrieved_nodes.add(current_node_pk)

# INPUT_CALC(Data, CalculationNode) - Backward
if traversal_rules['input_calc_backward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=Data,
output_type=ProcessNode,
direction='backward',
link_type_value=LinkType.INPUT_CALC.value
)
data_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# CREATE(CalculationNode, Data) - Forward
if traversal_rules['create_forward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=Data,
direction='forward',
link_type_value=LinkType.CREATE.value
)
data_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# RETURN(WorkflowNode, Data) - Forward
if traversal_rules['return_forward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=Data,
direction='forward',
link_type_value=LinkType.RETURN.value
)
data_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# INPUT_WORK(Data, WorkflowNode) - Backward
if traversal_rules['input_work_backward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=Data,
output_type=ProcessNode,
direction='backward',
link_type_value=LinkType.INPUT_WORK.value
)
data_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# CALL_CALC(WorkflowNode, CalculationNode) - Forward
if traversal_rules['call_calc_forward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=ProcessNode,
direction='forward',
link_type_value=LinkType.CALL_CALC.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# CALL_CALC(WorkflowNode, CalculationNode) - Backward
if traversal_rules['call_calc_backward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=ProcessNode,
direction='backward',
link_type_value=LinkType.CALL_CALC.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# CALL_WORK(WorkflowNode, WorkflowNode) - Forward
if traversal_rules['call_work_forward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=ProcessNode,
direction='forward',
link_type_value=LinkType.CALL_WORK.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# CALL_WORK(WorkflowNode, WorkflowNode) - Backward
if traversal_rules['call_work_backward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=ProcessNode,
direction='backward',
link_type_value=LinkType.CALL_WORK.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# If it is a Data
else:
current_node_pk = data_nodes.pop()
# If it is already visited continue to the next node
if current_node_pk in retrieved_nodes:
continue
# Otherwise say that it is a node to be exported
else:
retrieved_nodes.add(current_node_pk)

# INPUT_CALC(Data, CalculationNode) - Forward
if traversal_rules['input_calc_forward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=Data,
output_type=ProcessNode,
direction='forward',
link_type_value=LinkType.INPUT_CALC.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# CREATE(CalculationNode, Data) - Backward
if traversal_rules['create_backward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=Data,
direction='backward',
link_type_value=LinkType.CREATE.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# RETURN(WorkflowNode, Data) - Backward
if traversal_rules['return_backward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=ProcessNode,
output_type=Data,
direction='backward',
link_type_value=LinkType.RETURN.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

# INPUT_WORK(Data, WorkflowNode) - Forward
if traversal_rules['input_work_forward']:
links_uuids, found_nodes = _retrieve_linked_nodes_query(
current_node_pk,
input_type=Data,
output_type=ProcessNode,
direction='forward',
link_type_value=LinkType.INPUT_WORK.value
)
process_nodes.update(found_nodes - retrieved_nodes)
links_uuid_dict.update(links_uuids)

return retrieved_nodes, list(links_uuid_dict.values()), traversal_rules
# Creating a set of pks that are the process nodes and data nodes:
starting_nodes = process_nodes.union(data_nodes)

# Calling the traverser, which will traverse the graph exhaustively until it converges to
# a stable subgraph
traverse_results = traverse_graph(starting_nodes, get_links=True, **traversal_rules)

# The exporter wants a different form, a list of dictionaries for every link.
# I create a utility dictionary for pk to uuid. In the future, this should be
# ideally done in the interface or AGE.
if traverse_results['nodes']:
pk_2_uuid_dict = {
pk: uuid for pk, uuid in
QueryBuilder().append(Node, project=('id', 'uuid'), filters={
'id': {
'in': traverse_results['nodes']
}
}).all()
}
else:
pk_2_uuid_dict = {}

    # I transform the links from a set of tuples to a list of dictionaries, where every dictionary
    # contains the input_id, output_id, label and type under the keys 'input', 'output', 'label', and 'type',
    # respectively:
#if traverse_results['links']:
links_uuid_list = [{
'input': pk_2_uuid_dict[link.source_id],
'output': pk_2_uuid_dict[link.target_id],
'label': link.link_label,
'type': link.link_type
} for link in traverse_results['links']]
#else:
# links_uuid_list = []

return traverse_results['nodes'], links_uuid_list, traversal_rules

0 comments on commit 93afc88

Please sign in to comment.