Skip to content

Simplify updating the format by defining the labels only once #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions python_workflow_definition/src/python_workflow_definition/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
from aiida_workgraph import WorkGraph, task
from aiida_workgraph.socket import TaskSocketNamespace

from python_workflow_definition.shared import convert_nodes_list_to_dict
from python_workflow_definition.shared import (
convert_nodes_list_to_dict,
NODES_LABEL,
EDGES_LABEL,
SOURCE_LABEL,
SOURCE_PORT_LABEL,
TARGET_LABEL,
TARGET_PORT_LABEL,
)


def load_workflow_json(file_name):
Expand All @@ -17,7 +25,7 @@ def load_workflow_json(file_name):
wg = WorkGraph()
task_name_mapping = {}

for id, identifier in convert_nodes_list_to_dict(nodes_list=data["nodes"]).items():
for id, identifier in convert_nodes_list_to_dict(nodes_list=data[NODES_LABEL]).items():
if isinstance(identifier, str) and "." in identifier:
p, m = identifier.rsplit(".", 1)
mod = import_module(p)
Expand All @@ -32,33 +40,33 @@ def load_workflow_json(file_name):
task_name_mapping[id] = data_node

# add links
for link in data["edges"]:
to_task = task_name_mapping[str(link["target"])]
for link in data[EDGES_LABEL]:
to_task = task_name_mapping[str(link[TARGET_LABEL])]
# if the input does not exist, it means we pass the data into the kwargs
# in this case, we add the input socket
if link["targetPort"] not in to_task.inputs:
to_socket = to_task.add_input( "workgraph.any", name=link["targetPort"])
if link[TARGET_PORT_LABEL] not in to_task.inputs:
to_socket = to_task.add_input( "workgraph.any", name=link[TARGET_PORT_LABEL])
else:
to_socket = to_task.inputs[link["targetPort"]]
from_task = task_name_mapping[str(link["source"])]
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
if isinstance(from_task, orm.Data):
to_socket.value = from_task
else:
try:
if link["sourcePort"] is None:
link["sourcePort"] = "result"
if link[SOURCE_PORT_LABEL] is None:
link[SOURCE_PORT_LABEL] = "result"
# because we do not define the outputs explicitly during the pythonjob creation
# we add it here, and assume the output exists
if link["sourcePort"] not in from_task.outputs:
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
# if str(link["sourcePort"]) not in from_task.outputs:
from_socket = from_task.add_output(
"workgraph.any",
name=link["sourcePort"],
name=link[SOURCE_PORT_LABEL],
# name=str(link["sourcePort"]),
metadata={"is_function_output": True},
)
else:
from_socket = from_task.outputs[link["sourcePort"]]
from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]

wg.add_link(from_socket, to_socket)
except Exception as e:
Expand All @@ -68,7 +76,7 @@ def load_workflow_json(file_name):


def write_workflow_json(wg, file_name):
data = {"nodes": [], "edges": []}
data = {NODES_LABEL: [], EDGES_LABEL: []}
node_name_mapping = {}
data_node_name_mapping = {}
i = 0
Expand All @@ -78,19 +86,19 @@ def write_workflow_json(wg, file_name):

callable_name = executor["callable_name"]
callable_name = f"{executor['module_path']}.{callable_name}"
data["nodes"].append({"id": i, "function": callable_name})
data[NODES_LABEL].append({"id": i, "function": callable_name})
i += 1

for link in wg.links:
link_data = link.to_dict()
# if the from socket is the default result, we set it to None
if link_data["from_socket"] == "result":
link_data["from_socket"] = None
link_data["target"] = node_name_mapping[link_data.pop("to_node")]
link_data["targetPort"] = link_data.pop("to_socket")
link_data["source"] = node_name_mapping[link_data.pop("from_node")]
link_data["sourcePort"] = link_data.pop("from_socket")
data["edges"].append(link_data)
link_data[TARGET_LABEL] = node_name_mapping[link_data.pop("to_node")]
link_data[TARGET_PORT_LABEL] = link_data.pop("to_socket")
link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
link_data[SOURCE_PORT_LABEL] = link_data.pop("from_socket")
data[EDGES_LABEL].append(link_data)

for node in wg.tasks:
for input in node.inputs:
Expand All @@ -107,17 +115,17 @@ def write_workflow_json(wg, file_name):
raw_value.pop("node_type", None)
else:
raw_value = input.value.value
data["nodes"].append({"id": i, "value": raw_value})
data[NODES_LABEL].append({"id": i, "value": raw_value})
input_node_name = i
data_node_name_mapping[input.value.uuid] = input_node_name
i += 1
else:
input_node_name = data_node_name_mapping[input.value.uuid]
data["edges"].append({
"target": node_name_mapping[node.name],
"targetPort": input._name,
"source": input_node_name,
"sourcePort": None
data[EDGES_LABEL].append({
TARGET_LABEL: node_name_mapping[node.name],
TARGET_PORT_LABEL: input._name,
SOURCE_LABEL: input_node_name,
SOURCE_PORT_LABEL: None
})
with open(file_name, "w") as f:
# json.dump({"nodes": data[], "edges": edges_new_lst}, f)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,17 @@
from inspect import isfunction


from python_workflow_definition.shared import get_dict, get_list, get_kwargs, get_source_handles, convert_nodes_list_to_dict
from python_workflow_definition.shared import (
get_dict,
get_list,
get_kwargs,
get_source_handles,
convert_nodes_list_to_dict,
NODES_LABEL,
EDGES_LABEL,
SOURCE_LABEL,
SOURCE_PORT_LABEL,
)
from python_workflow_definition.purepython import resort_total_lst, group_edges


Expand All @@ -12,7 +22,7 @@ def get_item(obj, key):


def _get_value(result_dict, nodes_new_dict, link_dict, exe):
source, source_handle = link_dict["source"], link_dict["sourcePort"]
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
if source in result_dict.keys():
result = result_dict[source]
elif source in nodes_new_dict.keys():
Expand All @@ -29,10 +39,10 @@ def load_workflow_json(file_name, exe):
with open(file_name, "r") as f:
content = json.load(f)

edges_new_lst = content["edges"]
edges_new_lst = content[EDGES_LABEL]
nodes_new_dict = {}

for k, v in convert_nodes_list_to_dict(nodes_list=content["nodes"]).items():
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
if isinstance(v, str) and "." in v:
p, m = v.rsplit('.', 1)
mod = import_module(p)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,19 @@
import numpy as np
from jobflow import job, Flow

from python_workflow_definition.shared import get_dict, get_list, get_kwargs, get_source_handles, convert_nodes_list_to_dict
from python_workflow_definition.shared import (
get_dict,
get_list,
get_kwargs,
get_source_handles,
convert_nodes_list_to_dict,
NODES_LABEL,
EDGES_LABEL,
SOURCE_LABEL,
SOURCE_PORT_LABEL,
TARGET_LABEL,
TARGET_PORT_LABEL,
)


def _get_function_dict(flow):
Expand All @@ -26,9 +38,19 @@ def _get_nodes_dict(function_dict):

def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
if len(value_dict['attributes']) == 1:
return {"target": target, "targetPort": key, "source": nodes_mapping_dict[value_dict["uuid"]], "sourcePort": value_dict["attributes"][0][1]}
return {
TARGET_LABEL: target,
TARGET_PORT_LABEL: key,
SOURCE_LABEL: nodes_mapping_dict[value_dict["uuid"]],
SOURCE_PORT_LABEL: value_dict["attributes"][0][1],
}
else:
return {"target": target, "targetPort": key, "source": nodes_mapping_dict[value_dict["uuid"]], "sourcePort": None}
return {
TARGET_LABEL: target,
TARGET_PORT_LABEL: key,
SOURCE_LABEL: nodes_mapping_dict[value_dict["uuid"]],
SOURCE_PORT_LABEL: None,
}


def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
Expand Down Expand Up @@ -59,8 +81,8 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
nodes_dict[node_index] = vt
else:
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
edges_lst.append({"target": node_dict_index, "targetPort": kt, "source": node_index, "sourcePort": None})
edges_lst.append({"target": nodes_mapping_dict[job["uuid"]], "targetPort": k, "source": node_dict_index, "sourcePort": None})
edges_lst.append({TARGET_LABEL: node_dict_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_dict_index, SOURCE_PORT_LABEL: None})
elif isinstance(v, list) and any([isinstance(el, dict) and "@module" in el and "@class" in el and "@version" in el for el in v]):
node_list_index = len(nodes_dict)
nodes_dict[node_list_index] = get_list
Expand All @@ -78,15 +100,15 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
nodes_dict[node_index] = vt
else:
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
edges_lst.append({"target": node_list_index, "targetPort": kt, "source": node_index, "sourcePort": None})
edges_lst.append({"target": nodes_mapping_dict[job["uuid"]], "targetPort": k, "source": node_list_index, "sourcePort": None})
edges_lst.append({TARGET_LABEL: node_list_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_list_index, SOURCE_PORT_LABEL: None})
else:
if v not in nodes_dict.values():
node_index = len(nodes_dict)
nodes_dict[node_index] = v
else:
node_index = {tv: tk for tk, tv in nodes_dict.items()}[v]
edges_lst.append({"target": nodes_mapping_dict[job["uuid"]], "targetPort": k, "source": node_index, "sourcePort": None})
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
return edges_lst, nodes_dict


Expand All @@ -99,7 +121,7 @@ def _resort_total_lst(total_dict, nodes_dict):
for ind in sorted(total_dict.keys()):
connect = total_dict[ind]
if ind not in ordered_lst:
source_lst = [sd["source"] for sd in connect.values()]
source_lst = [sd[SOURCE_LABEL] for sd in connect.values()]
if all([s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]):
ordered_lst.append(ind)
total_new_dict[ind] = connect
Expand All @@ -109,11 +131,11 @@ def _resort_total_lst(total_dict, nodes_dict):
def _group_edges(edges_lst):
total_dict = {}
for ed_major in edges_lst:
target_id = ed_major["target"]
target_id = ed_major[TARGET_LABEL]
tmp_lst = []
if target_id not in total_dict.keys():
for ed in edges_lst:
if target_id == ed["target"]:
if target_id == ed[TARGET_LABEL]:
tmp_lst.append(ed)
total_dict[target_id] = get_kwargs(lst=tmp_lst)
return total_dict
Expand All @@ -139,8 +161,8 @@ def get_attr_helper(obj, source_handle):
else:
fn = job(method=v)
kwargs = {
kw: input_dict[vw["source"]] if vw["source"] in input_dict else get_attr_helper(
obj=memory_dict[vw["source"]], source_handle=vw["sourcePort"])
kw: input_dict[vw[SOURCE_LABEL]] if vw[SOURCE_LABEL] in input_dict else get_attr_helper(
obj=memory_dict[vw[SOURCE_LABEL]], source_handle=vw[SOURCE_PORT_LABEL])
for kw, vw in total_dict[k].items()
}
memory_dict[k] = fn(**kwargs)
Expand All @@ -159,21 +181,21 @@ def load_workflow_json(file_name):
content = json.load(f)

edges_new_lst = []
for edge in content["edges"]:
if edge["sourcePort"] is None:
for edge in content[EDGES_LABEL]:
if edge[SOURCE_PORT_LABEL] is None:
edges_new_lst.append(edge)
else:
edges_new_lst.append(
{
"target": edge["target"],
"targetPort": edge["targetPort"],
"source": edge["source"],
"sourcePort": str(edge["sourcePort"]),
TARGET_LABEL: edge[TARGET_LABEL],
TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
SOURCE_LABEL: edge[SOURCE_LABEL],
SOURCE_PORT_LABEL: str(edge[SOURCE_PORT_LABEL]),
}
)

nodes_new_dict = {}
for k, v in convert_nodes_list_to_dict(nodes_list=content["nodes"]).items():
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
if isinstance(v, str) and "." in v:
p, m = v.rsplit('.', 1)
mod = import_module(p)
Expand Down Expand Up @@ -214,4 +236,4 @@ def write_workflow_json(flow, file_name="workflow.json"):
nodes_store_lst.append({"id": k, "value": v})

with open(file_name, "w") as f:
json.dump({"nodes": nodes_store_lst, "edges": edges_lst}, f)
json.dump({NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_lst}, f)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@
from inspect import isfunction


from python_workflow_definition.shared import get_dict, get_list, get_kwargs, get_source_handles, convert_nodes_list_to_dict
from python_workflow_definition.shared import (
get_dict,
get_list,
get_kwargs,
get_source_handles,
convert_nodes_list_to_dict,
NODES_LABEL,
EDGES_LABEL,
SOURCE_LABEL,
SOURCE_PORT_LABEL,
TARGET_LABEL,
TARGET_PORT_LABEL,
)


def resort_total_lst(total_lst, nodes_dict):
Expand All @@ -13,30 +25,30 @@ def resort_total_lst(total_lst, nodes_dict):
while len(total_new_lst) < len(total_lst):
for ind, connect in total_lst:
if ind not in ordered_lst:
source_lst = [sd["source"] for sd in connect.values()]
source_lst = [sd[SOURCE_LABEL] for sd in connect.values()]
if all([s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]):
ordered_lst.append(ind)
total_new_lst.append([ind, connect])
return total_new_lst


def group_edges(edges_lst):
edges_sorted_lst = sorted(edges_lst, key=lambda x: x["target"], reverse=True)
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
total_lst, tmp_lst = [], []
target_id = edges_sorted_lst[0]["target"]
target_id = edges_sorted_lst[0][TARGET_LABEL]
for ed in edges_sorted_lst:
if target_id == ed["target"]:
if target_id == ed[TARGET_LABEL]:
tmp_lst.append(ed)
else:
total_lst.append((target_id, get_kwargs(lst=tmp_lst)))
target_id = ed["target"]
target_id = ed[TARGET_LABEL]
tmp_lst = [ed]
total_lst.append((target_id, get_kwargs(lst=tmp_lst)))
return total_lst


def _get_value(result_dict, nodes_new_dict, link_dict):
source, source_handle = link_dict["source"], link_dict["sourcePort"]
source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
if source in result_dict.keys():
result = result_dict[source]
elif source in nodes_new_dict.keys():
Expand All @@ -53,9 +65,9 @@ def load_workflow_json(file_name):
with open(file_name, "r") as f:
content = json.load(f)

edges_new_lst = content["edges"]
edges_new_lst = content[EDGES_LABEL]
nodes_new_dict = {}
for k, v in convert_nodes_list_to_dict(nodes_list=content["nodes"]).items():
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
if isinstance(v, str) and "." in v:
p, m = v.rsplit('.', 1)
mod = import_module(p)
Expand Down
Loading
Loading