Skip to content

Commit

Permalink
PwCalculation: refactor parent_folder validation
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed May 5, 2023
1 parent 7f4c4a1 commit 2a20c3f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 36 deletions.
53 changes: 23 additions & 30 deletions src/aiida_quantumespresso/calculations/pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,30 @@ def define(cls, spec):
'is `False` and/or `electron_maxstep` is 0.')
# yapf: enable

@staticmethod
def validate_inputs_base(value, _):
"""Validate the top level namespace."""
@classmethod
def validate_inputs(cls, value, port_namespace):
"""Validate the top level namespace.
Check that the restart input parameters are set correctly. In case of 'nscf' and 'bands' calculations, this
means ``parent_folder`` is provided. For other calculations, if the ``parent_folder`` is provided, the restart
settings must be set to use some of the outputs.
Note that the validator will only check the logic in case the ``parent_folder`` is a port in the
``port_namespace``. This is because the ``PwCalculation`` can be wrapped inside a work chain that only provides
the ``parent_folder`` input at a later step in the outline. To avoid raising any warnings, such a work chain
must exclude the ``parent_folder`` port when exposing the inputs of the ``PwCalculation``.
"""
result = super().validate_inputs(value, port_namespace)

if result is not None:
return result

parameters = value['parameters'].get_dict()
calculation_type = parameters.get('CONTROL', {}).get('calculation', 'scf')

if 'parent_folder' not in port_namespace:
return

# If a `parent_folder` input is provided, make sure the inputs are set to restart
if 'parent_folder' in value and calculation_type not in ('nscf', 'bands'):
if not any([
Expand All @@ -198,39 +216,14 @@ def validate_inputs_base(value, _):
" parameters['ELECTRONS']['startingwfc'] = 'file'\n"
)

@classmethod
def validate_inputs(cls, value, port_namespace):
"""Validate the top level namespace.
Check that the restart input parameters are set correctly. In case of 'nscf' and 'bands' calculations, this
means ``parent_folder`` is provided. For other calculations, if the ``parent_folder`` is provided, the restart
settings must be set to use some of the outputs.
Note that the validator is split in two methods: ``validate_inputs`` and ``validate_inputs_base``. This is to
facilitate work chains that wrap this calculation that will provide the ``parent_folder`` themselves and so do
not require the user to provide it at launch of the work chain. This will fail because of the validation in this
validator, however, which is why the rest of the logic is moved to ``validate_inputs_base``. The wrapping work
chain can change the ``validate_input`` validator for ``validate_inputs_base`` thereby allowing the parent
folder to be defined during the work chains runtime, while still keep the rest of the namespace validation.
"""
result = super().validate_inputs(value, port_namespace)

if result is not None:
return result

parameters = value['parameters'].get_dict()
calculation_type = parameters.get('CONTROL', {}).get('calculation', 'scf')

if calculation_type in ('nscf', 'bands'):
if 'parent_folder' not in value:
warnings.warn(
f'`parent_folder` not provided for `{calculation_type}` calculation. For work chains wrapping this '
'calculation, you can disable this warning by setting the validator of the `PwCalculation` port to '
'`PwCalculation.validate_inputs_base`.'
'calculation, you can disable this warning by excluding the `parent_folder` when exposing the '
'inputs of the `PwCalculation`.'
)

return cls.validate_inputs_base(value, port_namespace)

@classproperty
def filename_input_hubbard_parameters(cls):
"""Return the relative file name of the file containing the Hubbard parameters.
Expand Down
4 changes: 1 addition & 3 deletions src/aiida_quantumespresso/workflows/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from aiida.orm.nodes.data.base import to_aiida_type
import jsonschema

from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.utils.mapping import prepare_process_inputs

from .protocols.utils import ProtocolMixin
Expand Down Expand Up @@ -242,13 +241,12 @@ def define(cls, spec):
spec.expose_inputs(
PwBaseWorkChain,
namespace='nscf',
exclude=('clean_workdir', 'pw.structure'),
exclude=('clean_workdir', 'pw.structure', 'pw.parent_folder'),
namespace_options={
'help': 'Inputs for the `PwBaseWorkChain` of the `nscf` calculation.',
'validator': validate_nscf
}
)
spec.inputs['nscf']['pw'].validator = PwCalculation.validate_inputs_base
spec.expose_inputs(
DosCalculation,
namespace='dos',
Expand Down
5 changes: 2 additions & 3 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from aiida.engine import ToContext, WorkChain, if_

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
Expand Down Expand Up @@ -60,7 +59,7 @@ def define(cls, spec):
exclude=('clean_workdir', 'pw.structure'),
namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the SCF calculation.'})
spec.expose_inputs(PwBaseWorkChain, namespace='bands',
exclude=('clean_workdir', 'pw.structure', 'pw.kpoints', 'pw.kpoints_distance'),
exclude=('clean_workdir', 'pw.structure', 'pw.kpoints', 'pw.kpoints_distance', 'pw.parent_folder'),
namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the BANDS calculation.'})
spec.input('structure', valid_type=orm.StructureData, help='The inputs structure.')
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
Expand All @@ -71,7 +70,7 @@ def define(cls, spec):
help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.')
spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False,
help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.')
spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base

spec.inputs.validator = validate_inputs
spec.outline(
cls.setup,
Expand Down

0 comments on commit 2a20c3f

Please sign in to comment.