Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

engine: implement sub-dicts for context in workchains #4871

Merged
merged 2 commits into from
Jun 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions aiida/engine/processes/workchains/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import collections.abc
import functools
import logging
from typing import Any, List, Optional, Sequence, Union, TYPE_CHECKING
from typing import Any, List, Optional, Sequence, Union, Tuple, TYPE_CHECKING

from plumpy.persistence import auto_persist
from plumpy.process_states import Wait, Continue
Expand Down Expand Up @@ -121,25 +121,59 @@ def on_run(self):
super().on_run()
self.node.set_stepper_state_info(str(self._stepper))

def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]:
"""
Returns a reference to a sub-dictionary of the context and the last key,
after resolving a potentially segmented key where required sub-dictionaries are created as needed.

:param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary
"""
ctx = self.ctx
ctx_path = key.split('.')

for index, path in enumerate(ctx_path[:-1]):
try:
ctx = ctx[path]
except KeyError: # see below why this is the only exception we have to catch here
ctx[path] = AttributeDict() # create the sub-dict and update the context
ctx = ctx[path]
continue

# Notes:
# * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking
# * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables
# (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself
# * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable
# would be an AttributeDict we can append things to it since the order of tasks is maintained.
if type(ctx) != AttributeDict: # pylint: disable=C0123
raise ValueError(
f'Can not update the context for key `{key}`:'
f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index+1])}`, expected AttributeDict'
)

return ctx, ctx_path[-1]

def insert_awaitable(self, awaitable: Awaitable) -> None:
"""Insert an awaitable that should be terminated before before continuing to the next step.

:param awaitable: the thing to await
:type awaitable: :class:`aiida.engine.processes.workchains.awaitable.Awaitable`
"""
self._awaitables.append(awaitable)
ctx, key = self._resolve_nested_context(awaitable.key)

# Already assign the awaitable itself to the location in the context container where it is supposed to end up
# once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the
# order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the
# awaitable as a placeholder, in the `resolve_awaitable`, it can be found and replaced by the resolved value.
if awaitable.action == AwaitableAction.ASSIGN:
self.ctx[awaitable.key] = awaitable
ctx[key] = awaitable
elif awaitable.action == AwaitableAction.APPEND:
self.ctx.setdefault(awaitable.key, []).append(awaitable)
ctx.setdefault(key, []).append(awaitable)
else:
assert f'Unknown awaitable action: {awaitable.action}'
raise AssertionError(f'Unsupported awaitable action: {awaitable.action}')

self._awaitables.append(
awaitable
) # add only if everything went ok, otherwise we end up in an inconsistent state
self._update_process_status()

def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
Expand All @@ -149,23 +183,25 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:

:param awaitable: the awaitable to resolve
"""
self._awaitables.remove(awaitable)

ctx, key = self._resolve_nested_context(awaitable.key)

if awaitable.action == AwaitableAction.ASSIGN:
self.ctx[awaitable.key] = value
ctx[key] = value
elif awaitable.action == AwaitableAction.APPEND:
# Find the same awaitable inserted in the context
container = self.ctx[awaitable.key]
container = ctx[key]
for index, placeholder in enumerate(container):
if placeholder.pk == awaitable.pk and isinstance(placeholder, Awaitable):
if isinstance(placeholder, Awaitable) and placeholder.pk == awaitable.pk:
container[index] = value
break
else:
assert f'Awaitable `{awaitable.pk} was not found in `ctx.{awaitable.pk}`'
raise AssertionError(f'Awaitable `{awaitable.pk} was not found in `ctx.{awaitable.key}`')
else:
assert f'Unknown awaitable action: {awaitable.action}'
raise AssertionError(f'Unsupported awaitable action: {awaitable.action}')

awaitable.resolved = True
self._awaitables.remove(awaitable) # remove only if everything went ok, otherwise we may lose track

if not self.has_terminated():
# the process may be terminated, for example, if the process was killed or excepted
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from aiida.engine import WorkChain


class SomeWorkChain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(
cls.submit_workchains,
cls.inspect_workchains,
)

def submit_workchains(self):
for i in range(3):
future = self.submit(SomeWorkChain)
key = f'workchain.sub{i}'
self.to_context(**{key: future})

def inspect_workchains(self):
for i in range(3):
assert self.ctx.workchain[f'sub{i}'].is_finished_ok
11 changes: 11 additions & 0 deletions docs/source/topics/workflows/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,17 @@ The ``self.ctx.workchains`` now contains a list with the nodes of the completed
Note that the use of ``append_`` is not just limited to the ``to_context`` method.
You can also use it in exactly the same way with ``ToContext`` to append a process to a list in the context in multiple outline steps.

Nested context keys
^^^^^^^^^^^^^^^^^^^

To simplify the organization of the context, the keys may contain dots ``.``, transparently creating namespaces in the process.
As an example compare the following to the parallel submission example above:

.. include:: include/snippets/workchains/run_workchain_submit_append.py
:code: python

This allows to create intuitively grouped and easily accessible structures of child calculations or workchains.

.. _topics:workflows:usage:workchains:reporting:

Reporting
Expand Down
178 changes: 177 additions & 1 deletion tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from aiida.common import exceptions
from aiida.common.links import LinkType
from aiida.common.utils import Capturing
from aiida.engine import ExitCode, Process, ToContext, WorkChain, if_, while_, return_, launch, calcfunction
from aiida.engine import ExitCode, Process, ToContext, WorkChain, if_, while_, return_, launch, calcfunction, append_
from aiida.engine.persistence import ObjectLoader
from aiida.manage.manager import get_manager
from aiida.orm import load_node, Bool, Float, Int, Str
Expand Down Expand Up @@ -780,6 +780,182 @@ def result(self):

run_and_check_success(Workchain)

def test_nested_to_context(self):
val = Int(5).store()

test_case = self

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(**{'sub1.sub2.result_a': self.submit(SimpleWc)})
return ToContext(**{'sub1.sub2.result_b': self.submit(SimpleWc)})

def result(self):
test_case.assertEqual(self.ctx.sub1.sub2.result_a.outputs.result, val)
test_case.assertEqual(self.ctx.sub1.sub2.result_b.outputs.result, val)

run_and_check_success(Workchain)

def test_nested_to_context_with_append(self):
val1 = Int(5).store()
val2 = Int(6).store()

test_case = self

class SimpleWc1(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val1)

class SimpleWc2(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val2)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(**{'sub1.workchains': append_(self.submit(SimpleWc1))})
return ToContext(**{'sub1.workchains': append_(self.submit(SimpleWc2))})
dev-zero marked this conversation as resolved.
Show resolved Hide resolved

def result(self):
test_case.assertEqual(self.ctx.sub1.workchains[0].outputs.result, val1)
test_case.assertEqual(self.ctx.sub1.workchains[1].outputs.result, val2)

run_and_check_success(Workchain)

def test_nested_to_context_no_overlap(self):
val = Int(5).store()

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(**{'result_a': self.submit(SimpleWc)})
return ToContext(**{'result_a.sub1': self.submit(SimpleWc)})

def result(self):
raise RuntimeError('Never reached: the second to_context above should fail')

process = Workchain()
with pytest.raises(ValueError):
launch.run(process)

def test_nested_to_context_no_overlap_with_append(self):
val = Int(5).store()

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(workchains=append_(self.submit(SimpleWc))) # make the workchains point to a list
return ToContext(**{'workchains.sub1.sub2': self.submit(SimpleWc)}) # now try to treat it as a sub-dict

def result(self):
raise RuntimeError('Never reached: the second to_context above should fail')

process = Workchain()
with pytest.raises(ValueError):
launch.run(process)

def test_nested_to_context_no_overlap_with_append2(self):
val = Int(5).store()

class SimpleWc(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.result)
spec.outputs.dynamic = True

def result(self):
self.out('result', val)

class Workchain(WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.begin, cls.result)

def begin(self):
self.to_context(workchains=append_(self.submit(SimpleWc))) # make the workchains point to a list
return ToContext(
**{'workchains.sub1': self.submit(SimpleWc)}
) # now try to treat the final path element it as a sub-dict

def result(self):
raise RuntimeError('Never reached: the second to_context above should fail')

process = Workchain()
with pytest.raises(ValueError):
launch.run(process)

def test_namespace_nondb_mapping(self):
"""
Regression test for a bug in _flatten_inputs
Expand Down