Skip to content

Commit a399187

Browse files
authored
black formatting (#64)
1 parent 79f86f9 commit a399187

File tree

7 files changed

+239
-90
lines changed

7 files changed

+239
-90
lines changed

python_workflow_definition/src/python_workflow_definition/aiida.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def load_workflow_json(file_name):
2525
wg = WorkGraph()
2626
task_name_mapping = {}
2727

28-
for id, identifier in convert_nodes_list_to_dict(nodes_list=data[NODES_LABEL]).items():
28+
for id, identifier in convert_nodes_list_to_dict(
29+
nodes_list=data[NODES_LABEL]
30+
).items():
2931
if isinstance(identifier, str) and "." in identifier:
3032
p, m = identifier.rsplit(".", 1)
3133
mod = import_module(p)
@@ -45,7 +47,7 @@ def load_workflow_json(file_name):
4547
# if the input is not exit, it means we pass the data into to the kwargs
4648
# in this case, we add the input socket
4749
if link[TARGET_PORT_LABEL] not in to_task.inputs:
48-
to_socket = to_task.add_input( "workgraph.any", name=link[TARGET_PORT_LABEL])
50+
to_socket = to_task.add_input("workgraph.any", name=link[TARGET_PORT_LABEL])
4951
else:
5052
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
5153
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
@@ -58,7 +60,7 @@ def load_workflow_json(file_name):
5860
# because we are not define the outputs explicitly during the pythonjob creation
5961
# we add it here, and assume the output exit
6062
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
61-
# if str(link["sourcePort"]) not in from_task.outputs:
63+
# if str(link["sourcePort"]) not in from_task.outputs:
6264
from_socket = from_task.add_output(
6365
"workgraph.any",
6466
name=link[SOURCE_PORT_LABEL],
@@ -99,7 +101,7 @@ def write_workflow_json(wg, file_name):
99101
link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
100102
link_data[SOURCE_PORT_LABEL] = link_data.pop("from_socket")
101103
data[EDGES_LABEL].append(link_data)
102-
104+
103105
for node in wg.tasks:
104106
for input in node.inputs:
105107
# assume namespace is not used as input
@@ -121,12 +123,14 @@ def write_workflow_json(wg, file_name):
121123
i += 1
122124
else:
123125
input_node_name = data_node_name_mapping[input.value.uuid]
124-
data[EDGES_LABEL].append({
125-
TARGET_LABEL: node_name_mapping[node.name],
126-
TARGET_PORT_LABEL: input._name,
127-
SOURCE_LABEL: input_node_name,
128-
SOURCE_PORT_LABEL: None
129-
})
126+
data[EDGES_LABEL].append(
127+
{
128+
TARGET_LABEL: node_name_mapping[node.name],
129+
TARGET_PORT_LABEL: input._name,
130+
SOURCE_LABEL: input_node_name,
131+
SOURCE_PORT_LABEL: None,
132+
}
133+
)
130134
with open(file_name, "w") as f:
131135
# json.dump({"nodes": data[], "edges": edges_new_lst}, f)
132136
json.dump(data, f, indent=2)

python_workflow_definition/src/python_workflow_definition/executorlib.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def load_workflow_json(file_name, exe):
4444

4545
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
4646
if isinstance(v, str) and "." in v:
47-
p, m = v.rsplit('.', 1)
47+
p, m = v.rsplit(".", 1)
4848
mod = import_module(p)
4949
nodes_new_dict[int(k)] = getattr(mod, m)
5050
else:
@@ -59,7 +59,12 @@ def load_workflow_json(file_name, exe):
5959
node = nodes_new_dict[lst[0]]
6060
if isfunction(node):
6161
kwargs = {
62-
k: _get_value(result_dict=result_dict, nodes_new_dict=nodes_new_dict, link_dict=v, exe=exe)
62+
k: _get_value(
63+
result_dict=result_dict,
64+
nodes_new_dict=nodes_new_dict,
65+
link_dict=v,
66+
exe=exe,
67+
)
6368
for k, v in lst[1].items()
6469
}
6570
result_dict[lst[0]] = exe.submit(node, **kwargs)

python_workflow_definition/src/python_workflow_definition/jobflow.py

Lines changed: 131 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@
2121

2222

2323
def _get_function_dict(flow):
24-
return {
25-
job.uuid: job.function
26-
for job in flow.jobs
27-
}
24+
return {job.uuid: job.function for job in flow.jobs}
2825

2926

3027
def _get_nodes_dict(function_dict):
@@ -37,7 +34,7 @@ def _get_nodes_dict(function_dict):
3734

3835

3936
def _get_edge_from_dict(target, key, value_dict, nodes_mapping_dict):
40-
if len(value_dict['attributes']) == 1:
37+
if len(value_dict["attributes"]) == 1:
4138
return {
4239
TARGET_LABEL: target,
4340
TARGET_PORT_LABEL: key,
@@ -57,72 +54,152 @@ def _get_edges_and_extend_nodes(flow_dict, nodes_mapping_dict, nodes_dict):
5754
edges_lst = []
5855
for job in flow_dict["jobs"]:
5956
for k, v in job["function_kwargs"].items():
60-
if isinstance(v, dict) and "@module" in v and "@class" in v and "@version" in v:
61-
edges_lst.append(_get_edge_from_dict(
62-
target=nodes_mapping_dict[job["uuid"]],
63-
key=k,
64-
value_dict=v,
65-
nodes_mapping_dict=nodes_mapping_dict,
66-
))
67-
elif isinstance(v, dict) and any([isinstance(el, dict) and "@module" in el and "@class" in el and "@version" in el for el in v.values()]):
57+
if (
58+
isinstance(v, dict)
59+
and "@module" in v
60+
and "@class" in v
61+
and "@version" in v
62+
):
63+
edges_lst.append(
64+
_get_edge_from_dict(
65+
target=nodes_mapping_dict[job["uuid"]],
66+
key=k,
67+
value_dict=v,
68+
nodes_mapping_dict=nodes_mapping_dict,
69+
)
70+
)
71+
elif isinstance(v, dict) and any(
72+
[
73+
isinstance(el, dict)
74+
and "@module" in el
75+
and "@class" in el
76+
and "@version" in el
77+
for el in v.values()
78+
]
79+
):
6880
node_dict_index = len(nodes_dict)
6981
nodes_dict[node_dict_index] = get_dict
7082
for kt, vt in v.items():
71-
if isinstance(vt, dict) and "@module" in vt and "@class" in vt and "@version" in vt:
72-
edges_lst.append(_get_edge_from_dict(
73-
target=node_dict_index,
74-
key=kt,
75-
value_dict=vt,
76-
nodes_mapping_dict=nodes_mapping_dict,
77-
))
83+
if (
84+
isinstance(vt, dict)
85+
and "@module" in vt
86+
and "@class" in vt
87+
and "@version" in vt
88+
):
89+
edges_lst.append(
90+
_get_edge_from_dict(
91+
target=node_dict_index,
92+
key=kt,
93+
value_dict=vt,
94+
nodes_mapping_dict=nodes_mapping_dict,
95+
)
96+
)
7897
else:
7998
if vt not in nodes_dict.values():
8099
node_index = len(nodes_dict)
81100
nodes_dict[node_index] = vt
82101
else:
83-
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
84-
edges_lst.append({TARGET_LABEL: node_dict_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
85-
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_dict_index, SOURCE_PORT_LABEL: None})
86-
elif isinstance(v, list) and any([isinstance(el, dict) and "@module" in el and "@class" in el and "@version" in el for el in v]):
102+
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[
103+
str(vt)
104+
]
105+
edges_lst.append(
106+
{
107+
TARGET_LABEL: node_dict_index,
108+
TARGET_PORT_LABEL: kt,
109+
SOURCE_LABEL: node_index,
110+
SOURCE_PORT_LABEL: None,
111+
}
112+
)
113+
edges_lst.append(
114+
{
115+
TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
116+
TARGET_PORT_LABEL: k,
117+
SOURCE_LABEL: node_dict_index,
118+
SOURCE_PORT_LABEL: None,
119+
}
120+
)
121+
elif isinstance(v, list) and any(
122+
[
123+
isinstance(el, dict)
124+
and "@module" in el
125+
and "@class" in el
126+
and "@version" in el
127+
for el in v
128+
]
129+
):
87130
node_list_index = len(nodes_dict)
88131
nodes_dict[node_list_index] = get_list
89132
for kt, vt in enumerate(v):
90-
if isinstance(vt, dict) and "@module" in vt and "@class" in vt and "@version" in vt:
91-
edges_lst.append(_get_edge_from_dict(
92-
target=node_list_index,
93-
key=str(kt),
94-
value_dict=vt,
95-
nodes_mapping_dict=nodes_mapping_dict,
96-
))
133+
if (
134+
isinstance(vt, dict)
135+
and "@module" in vt
136+
and "@class" in vt
137+
and "@version" in vt
138+
):
139+
edges_lst.append(
140+
_get_edge_from_dict(
141+
target=node_list_index,
142+
key=str(kt),
143+
value_dict=vt,
144+
nodes_mapping_dict=nodes_mapping_dict,
145+
)
146+
)
97147
else:
98148
if vt not in nodes_dict.values():
99149
node_index = len(nodes_dict)
100150
nodes_dict[node_index] = vt
101151
else:
102-
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[str(vt)]
103-
edges_lst.append({TARGET_LABEL: node_list_index, TARGET_PORT_LABEL: kt, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
104-
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_list_index, SOURCE_PORT_LABEL: None})
152+
node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[
153+
str(vt)
154+
]
155+
edges_lst.append(
156+
{
157+
TARGET_LABEL: node_list_index,
158+
TARGET_PORT_LABEL: kt,
159+
SOURCE_LABEL: node_index,
160+
SOURCE_PORT_LABEL: None,
161+
}
162+
)
163+
edges_lst.append(
164+
{
165+
TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
166+
TARGET_PORT_LABEL: k,
167+
SOURCE_LABEL: node_list_index,
168+
SOURCE_PORT_LABEL: None,
169+
}
170+
)
105171
else:
106172
if v not in nodes_dict.values():
107173
node_index = len(nodes_dict)
108174
nodes_dict[node_index] = v
109175
else:
110176
node_index = {tv: tk for tk, tv in nodes_dict.items()}[v]
111-
edges_lst.append({TARGET_LABEL: nodes_mapping_dict[job["uuid"]], TARGET_PORT_LABEL: k, SOURCE_LABEL: node_index, SOURCE_PORT_LABEL: None})
177+
edges_lst.append(
178+
{
179+
TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
180+
TARGET_PORT_LABEL: k,
181+
SOURCE_LABEL: node_index,
182+
SOURCE_PORT_LABEL: None,
183+
}
184+
)
112185
return edges_lst, nodes_dict
113186

114187

115188
def _resort_total_lst(total_dict, nodes_dict):
116189
nodes_with_dep_lst = list(sorted(total_dict.keys()))
117-
nodes_without_dep_lst = [k for k in nodes_dict.keys() if k not in nodes_with_dep_lst]
190+
nodes_without_dep_lst = [
191+
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
192+
]
118193
ordered_lst = []
119194
total_new_dict = {}
120195
while len(total_new_dict) < len(total_dict):
121196
for ind in sorted(total_dict.keys()):
122197
connect = total_dict[ind]
123198
if ind not in ordered_lst:
124199
source_lst = [sd[SOURCE_LABEL] for sd in connect.values()]
125-
if all([s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]):
200+
if all(
201+
[s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]
202+
):
126203
ordered_lst.append(ind)
127204
total_new_dict[ind] = connect
128205
return total_new_dict
@@ -142,7 +219,7 @@ def _group_edges(edges_lst):
142219

143220

144221
def _get_input_dict(nodes_dict):
145-
return {k:v for k, v in nodes_dict.items() if not isfunction(v)}
222+
return {k: v for k, v in nodes_dict.items() if not isfunction(v)}
146223

147224

148225
def _get_workflow(nodes_dict, input_dict, total_dict, source_handles_dict):
@@ -157,12 +234,21 @@ def get_attr_helper(obj, source_handle):
157234
v = nodes_dict[k]
158235
if isfunction(v):
159236
if k in source_handles_dict.keys():
160-
fn = job(method=v, data=[el for el in source_handles_dict[k] if el is not None])
237+
fn = job(
238+
method=v,
239+
data=[el for el in source_handles_dict[k] if el is not None],
240+
)
161241
else:
162242
fn = job(method=v)
163243
kwargs = {
164-
kw: input_dict[vw[SOURCE_LABEL]] if vw[SOURCE_LABEL] in input_dict else get_attr_helper(
165-
obj=memory_dict[vw[SOURCE_LABEL]], source_handle=vw[SOURCE_PORT_LABEL])
244+
kw: (
245+
input_dict[vw[SOURCE_LABEL]]
246+
if vw[SOURCE_LABEL] in input_dict
247+
else get_attr_helper(
248+
obj=memory_dict[vw[SOURCE_LABEL]],
249+
source_handle=vw[SOURCE_PORT_LABEL],
250+
)
251+
)
166252
for kw, vw in total_dict[k].items()
167253
}
168254
memory_dict[k] = fn(**kwargs)
@@ -197,7 +283,7 @@ def load_workflow_json(file_name):
197283
nodes_new_dict = {}
198284
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
199285
if isinstance(v, str) and "." in v:
200-
p, m = v.rsplit('.', 1)
286+
p, m = v.rsplit(".", 1)
201287
mod = import_module(p)
202288
nodes_new_dict[int(k)] = getattr(mod, m)
203289
else:
@@ -229,7 +315,9 @@ def write_workflow_json(flow, file_name="workflow.json"):
229315
nodes_store_lst = []
230316
for k, v in nodes_dict.items():
231317
if isfunction(v):
232-
nodes_store_lst.append({"id": k, "function": v.__module__ + "." + v.__name__})
318+
nodes_store_lst.append(
319+
{"id": k, "function": v.__module__ + "." + v.__name__}
320+
)
233321
elif isinstance(v, np.ndarray):
234322
nodes_store_lst.append({"id": k, "value": v.tolist()})
235323
else:

python_workflow_definition/src/python_workflow_definition/plot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def plot(file_name):
3535
if v[SOURCE_PORT_LABEL] is None:
3636
edge_label_dict[v[SOURCE_LABEL]].append(k)
3737
else:
38-
edge_label_dict[v[SOURCE_LABEL]].append(k + "=result[" + v[SOURCE_PORT_LABEL] + "]")
38+
edge_label_dict[v[SOURCE_LABEL]].append(
39+
k + "=result[" + v[SOURCE_PORT_LABEL] + "]"
40+
)
3941
for k, v in edge_label_dict.items():
4042
graph.add_edge(str(k), str(target_node), label=", ".join(v))
4143

python_workflow_definition/src/python_workflow_definition/purepython.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@
2020

2121
def resort_total_lst(total_lst, nodes_dict):
2222
nodes_with_dep_lst = list(sorted([v[0] for v in total_lst]))
23-
nodes_without_dep_lst = [k for k in nodes_dict.keys() if k not in nodes_with_dep_lst]
23+
nodes_without_dep_lst = [
24+
k for k in nodes_dict.keys() if k not in nodes_with_dep_lst
25+
]
2426
ordered_lst, total_new_lst = [], []
2527
while len(total_new_lst) < len(total_lst):
2628
for ind, connect in total_lst:
2729
if ind not in ordered_lst:
2830
source_lst = [sd[SOURCE_LABEL] for sd in connect.values()]
29-
if all([s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]):
31+
if all(
32+
[s in ordered_lst or s in nodes_without_dep_lst for s in source_lst]
33+
):
3034
ordered_lst.append(ind)
3135
total_new_lst.append([ind, connect])
3236
return total_new_lst
@@ -69,7 +73,7 @@ def load_workflow_json(file_name):
6973
nodes_new_dict = {}
7074
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
7175
if isinstance(v, str) and "." in v:
72-
p, m = v.rsplit('.', 1)
76+
p, m = v.rsplit(".", 1)
7377
mod = import_module(p)
7478
nodes_new_dict[int(k)] = getattr(mod, m)
7579
else:
@@ -84,7 +88,9 @@ def load_workflow_json(file_name):
8488
node = nodes_new_dict[lst[0]]
8589
if isfunction(node):
8690
kwargs = {
87-
k: _get_value(result_dict=result_dict, nodes_new_dict=nodes_new_dict, link_dict=v)
91+
k: _get_value(
92+
result_dict=result_dict, nodes_new_dict=nodes_new_dict, link_dict=v
93+
)
8894
for k, v in lst[1].items()
8995
}
9096
result_dict[lst[0]] = node(**kwargs)

0 commit comments

Comments
 (0)