Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PwCalculation: refactor parent_folder validation #927

Merged
merged 1 commit into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 23 additions & 30 deletions src/aiida_quantumespresso/calculations/pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,30 @@ def define(cls, spec):
'is `False` and/or `electron_maxstep` is 0.')
# yapf: enable

@staticmethod
def validate_inputs_base(value, _):
"""Validate the top level namespace."""
@classmethod
def validate_inputs(cls, value, port_namespace):
"""Validate the top level namespace.

Check that the restart input parameters are set correctly. In case of 'nscf' and 'bands' calculations, this
means ``parent_folder`` is provided. For other calculations, if the ``parent_folder`` is provided, the restart
settings must be set to use some of the outputs.

Note that the validator will only check the logic in case the ``parent_folder`` is a port in the
``port_namespace``. This is because the ``PwCalculation`` can be wrapped inside a work chain that only provides
the ``parent_folder`` input at a later step in the outline. To avoid raising any warnings, such a work chain
must exclude the ``parent_folder`` port when exposing the inputs of the ``PwCalculation``.
"""
result = super().validate_inputs(value, port_namespace)

if result is not None:
return result

parameters = value['parameters'].get_dict()
calculation_type = parameters.get('CONTROL', {}).get('calculation', 'scf')

if 'parent_folder' not in port_namespace:
return

# If a `parent_folder` input is provided, make sure the inputs are set to restart
if 'parent_folder' in value and calculation_type not in ('nscf', 'bands'):
if not any([
Expand All @@ -198,39 +216,14 @@ def validate_inputs_base(value, _):
" parameters['ELECTRONS']['startingwfc'] = 'file'\n"
)

@classmethod
def validate_inputs(cls, value, port_namespace):
"""Validate the top level namespace.

Check that the restart input parameters are set correctly. In case of 'nscf' and 'bands' calculations, this
means ``parent_folder`` is provided. For other calculations, if the ``parent_folder`` is provided, the restart
settings must be set to use some of the outputs.

Note that the validator is split in two methods: ``validate_inputs`` and ``validate_inputs_base``. This is to
facilitate work chains that wrap this calculation that will provide the ``parent_folder`` themselves and so do
not require the user to provide it at launch of the work chain. This will fail because of the validation in this
validator, however, which is why the rest of the logic is moved to ``validate_inputs_base``. The wrapping work
chain can change the ``validate_input`` validator for ``validate_inputs_base`` thereby allowing the parent
folder to be defined during the work chains runtime, while still keep the rest of the namespace validation.
"""
result = super().validate_inputs(value, port_namespace)

if result is not None:
return result

parameters = value['parameters'].get_dict()
calculation_type = parameters.get('CONTROL', {}).get('calculation', 'scf')

if calculation_type in ('nscf', 'bands'):
if 'parent_folder' not in value:
warnings.warn(
f'`parent_folder` not provided for `{calculation_type}` calculation. For work chains wrapping this '
'calculation, you can disable this warning by setting the validator of the `PwCalculation` port to '
'`PwCalculation.validate_inputs_base`.'
'calculation, you can disable this warning by excluding the `parent_folder` when exposing the '
'inputs of the `PwCalculation`.'
)

return cls.validate_inputs_base(value, port_namespace)

@classproperty
def filename_input_hubbard_parameters(cls):
"""Return the relative file name of the file containing the Hubbard parameters.
Expand Down
4 changes: 1 addition & 3 deletions src/aiida_quantumespresso/workflows/pdos.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from aiida.orm.nodes.data.base import to_aiida_type
import jsonschema

from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.utils.mapping import prepare_process_inputs

from .protocols.utils import ProtocolMixin
Expand Down Expand Up @@ -242,13 +241,12 @@ def define(cls, spec):
spec.expose_inputs(
PwBaseWorkChain,
namespace='nscf',
exclude=('clean_workdir', 'pw.structure'),
exclude=('clean_workdir', 'pw.structure', 'pw.parent_folder'),
namespace_options={
'help': 'Inputs for the `PwBaseWorkChain` of the `nscf` calculation.',
'validator': validate_nscf
}
)
spec.inputs['nscf']['pw'].validator = PwCalculation.validate_inputs_base
spec.expose_inputs(
DosCalculation,
namespace='dos',
Expand Down
5 changes: 2 additions & 3 deletions src/aiida_quantumespresso/workflows/pw/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from aiida.engine import ToContext, WorkChain, if_

from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
Expand Down Expand Up @@ -60,7 +59,7 @@ def define(cls, spec):
exclude=('clean_workdir', 'pw.structure'),
namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the SCF calculation.'})
spec.expose_inputs(PwBaseWorkChain, namespace='bands',
exclude=('clean_workdir', 'pw.structure', 'pw.kpoints', 'pw.kpoints_distance'),
exclude=('clean_workdir', 'pw.structure', 'pw.kpoints', 'pw.kpoints_distance', 'pw.parent_folder'),
namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the BANDS calculation.'})
spec.input('structure', valid_type=orm.StructureData, help='The inputs structure.')
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
Expand All @@ -71,7 +70,7 @@ def define(cls, spec):
help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.')
spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False,
help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.')
spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base

spec.inputs.validator = validate_inputs
spec.outline(
cls.setup,
Expand Down