From 9e684baf39a69be1acc9039d04af13d3e1065905 Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Fri, 5 May 2023 17:19:53 +0200 Subject: [PATCH] `PwCalculation`: refactor `parent_folder` validation --- src/aiida_quantumespresso/calculations/pw.py | 53 ++++++++----------- src/aiida_quantumespresso/workflows/pdos.py | 4 +- .../workflows/pw/bands.py | 5 +- 3 files changed, 26 insertions(+), 36 deletions(-) diff --git a/src/aiida_quantumespresso/calculations/pw.py b/src/aiida_quantumespresso/calculations/pw.py index ae1757d6c..cf199deba 100644 --- a/src/aiida_quantumespresso/calculations/pw.py +++ b/src/aiida_quantumespresso/calculations/pw.py @@ -176,12 +176,30 @@ def define(cls, spec): 'is `False` and/or `electron_maxstep` is 0.') # yapf: enable - @staticmethod - def validate_inputs_base(value, _): - """Validate the top level namespace.""" + @classmethod + def validate_inputs(cls, value, port_namespace): + """Validate the top level namespace. + + Check that the restart input parameters are set correctly. In case of 'nscf' and 'bands' calculations, this + means ``parent_folder`` is provided. For other calculations, if the ``parent_folder`` is provided, the restart + settings must be set to use some of the outputs. + + Note that the validator will only check the logic in case the ``parent_folder`` is a port in the + ``port_namespace``. This is because the ``PwCalculation`` can be wrapped inside a work chain that only provides + the ``parent_folder`` input at a later step in the outline. To avoid raising any warnings, such a work chain + must exclude the ``parent_folder`` port when exposing the inputs of the ``PwCalculation``. + """ + result = super().validate_inputs(value, port_namespace) + + if result is not None: + return result + parameters = value['parameters'].get_dict() calculation_type = parameters.get('CONTROL', {}).get('calculation', 'scf') + if 'parent_folder' not in port_namespace: + return + # If a `parent_folder` input is provided, make sure the inputs are set to restart if 'parent_folder' in value and calculation_type not in ('nscf', 'bands'): if not any([ @@ -198,39 +216,14 @@ def validate_inputs_base(value, _): " parameters['ELECTRONS']['startingwfc'] = 'file'\n" ) - @classmethod - def validate_inputs(cls, value, port_namespace): - """Validate the top level namespace. - - Check that the restart input parameters are set correctly. In case of 'nscf' and 'bands' calculations, this - means ``parent_folder`` is provided. For other calculations, if the ``parent_folder`` is provided, the restart - settings must be set to use some of the outputs. - - Note that the validator is split in two methods: ``validate_inputs`` and ``validate_inputs_base``. This is to - facilitate work chains that wrap this calculation that will provide the ``parent_folder`` themselves and so do - not require the user to provide it at launch of the work chain. This will fail because of the validation in this - validator, however, which is why the rest of the logic is moved to ``validate_inputs_base``. The wrapping work - chain can change the ``validate_input`` validator for ``validate_inputs_base`` thereby allowing the parent - folder to be defined during the work chains runtime, while still keep the rest of the namespace validation. - """ - result = super().validate_inputs(value, port_namespace) - - if result is not None: - return result - - parameters = value['parameters'].get_dict() - calculation_type = parameters.get('CONTROL', {}).get('calculation', 'scf') - if calculation_type in ('nscf', 'bands'): if 'parent_folder' not in value: warnings.warn( f'`parent_folder` not provided for `{calculation_type}` calculation. For work chains wrapping this ' - 'calculation, you can disable this warning by setting the validator of the `PwCalculation` port to ' - '`PwCalculation.validate_inputs_base`.' + 'calculation, you can disable this warning by excluding the `parent_folder` when exposing the ' + 'inputs of the `PwCalculation`.' ) - return cls.validate_inputs_base(value, port_namespace) - @classproperty def filename_input_hubbard_parameters(cls): """Return the relative file name of the file containing the Hubbard parameters. diff --git a/src/aiida_quantumespresso/workflows/pdos.py b/src/aiida_quantumespresso/workflows/pdos.py index 0118b9e4e..888379f63 100644 --- a/src/aiida_quantumespresso/workflows/pdos.py +++ b/src/aiida_quantumespresso/workflows/pdos.py @@ -52,7 +52,6 @@ from aiida.orm.nodes.data.base import to_aiida_type import jsonschema -from aiida_quantumespresso.calculations.pw import PwCalculation from aiida_quantumespresso.utils.mapping import prepare_process_inputs from .protocols.utils import ProtocolMixin @@ -242,13 +241,12 @@ def define(cls, spec): spec.expose_inputs( PwBaseWorkChain, namespace='nscf', - exclude=('clean_workdir', 'pw.structure'), + exclude=('clean_workdir', 'pw.structure', 'pw.parent_folder'), namespace_options={ 'help': 'Inputs for the `PwBaseWorkChain` of the `nscf` calculation.', 'validator': validate_nscf } ) - spec.inputs['nscf']['pw'].validator = PwCalculation.validate_inputs_base spec.expose_inputs( DosCalculation, namespace='dos', diff --git a/src/aiida_quantumespresso/workflows/pw/bands.py b/src/aiida_quantumespresso/workflows/pw/bands.py index b7bd45b3a..e6c5f77cc 100644 --- a/src/aiida_quantumespresso/workflows/pw/bands.py +++ b/src/aiida_quantumespresso/workflows/pw/bands.py @@ -5,7 +5,6 @@ from aiida.engine import ToContext, WorkChain, if_ from aiida_quantumespresso.calculations.functions.seekpath_structure_analysis import seekpath_structure_analysis -from aiida_quantumespresso.calculations.pw import PwCalculation from aiida_quantumespresso.utils.mapping import prepare_process_inputs from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain @@ -60,7 +59,7 @@ def define(cls, spec): exclude=('clean_workdir', 'pw.structure'), namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the SCF calculation.'}) spec.expose_inputs(PwBaseWorkChain, namespace='bands', - exclude=('clean_workdir', 'pw.structure', 'pw.kpoints', 'pw.kpoints_distance'), + exclude=('clean_workdir', 'pw.structure', 'pw.kpoints', 'pw.kpoints_distance', 'pw.parent_folder'), namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the BANDS calculation.'}) spec.input('structure', valid_type=orm.StructureData, help='The inputs structure.') spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False), @@ -71,7 +70,7 @@ def define(cls, spec): help='Explicit kpoints to use for the BANDS calculation. Specify either this or `bands_kpoints_distance`.') spec.input('bands_kpoints_distance', valid_type=orm.Float, required=False, help='Minimum kpoints distance for the BANDS calculation. Specify either this or `bands_kpoints`.') - spec.inputs['bands']['pw'].validator = PwCalculation.validate_inputs_base + spec.inputs.validator = validate_inputs spec.outline( cls.setup,