diff --git a/src/aiidalab_qe/plugins/bands/workchain.py b/src/aiidalab_qe/plugins/bands/workchain.py index 484ec082f..c3a1fe59f 100644 --- a/src/aiidalab_qe/plugins/bands/workchain.py +++ b/src/aiidalab_qe/plugins/bands/workchain.py @@ -215,6 +215,8 @@ def get_builder(codes, structure, parameters, **kwargs): # pop the inputs that are excluded from the expose_inputs bands.pop("relax") + # run the scf workchain from the app's workchain + bands.pop("scf") bands.pop("structure", None) bands.pop("clean_workdir", None) return bands @@ -224,4 +226,6 @@ def get_builder(codes, structure, parameters, **kwargs): "workchain": PwBandsWorkChain, "exclude": ("clean_workdir", "structure", "relax"), "get_builder": get_builder, + "requires_scf": True, + "input_from_ctx": {"bands.pw.parent_folder": "scf_folder"}, } diff --git a/src/aiidalab_qe/plugins/pdos/workchain.py b/src/aiidalab_qe/plugins/pdos/workchain.py index bd9915f2e..d6cde4c5c 100644 --- a/src/aiidalab_qe/plugins/pdos/workchain.py +++ b/src/aiidalab_qe/plugins/pdos/workchain.py @@ -66,6 +66,8 @@ def get_builder(codes, structure, parameters, **kwargs): # pop the inputs that are exclueded from the expose_inputs pdos.pop("structure", None) pdos.pop("clean_workdir", None) + # run the scf workchain from the app's workchain + pdos.pop("scf") else: raise ValueError("The dos_code and projwfc_code are required.") return pdos @@ -75,4 +77,6 @@ def get_builder(codes, structure, parameters, **kwargs): "workchain": PdosWorkChain, "exclude": ("clean_workdir", "structure", "relax"), "get_builder": get_builder, + "requires_scf": True, + "input_from_ctx": {"nscf.pw.parent_folder": "scf_folder"}, } diff --git a/src/aiidalab_qe/workflows/__init__.py b/src/aiidalab_qe/workflows/__init__.py index ac292cb4a..c3d66f79a 100644 --- a/src/aiidalab_qe/workflows/__init__.py +++ b/src/aiidalab_qe/workflows/__init__.py @@ -5,9 +5,10 @@ from aiida.plugins import DataFactory # AiiDA Quantum ESPRESSO plugin inputs. -from aiida_quantumespresso.common.types import ElectronicType, RelaxType, SpinType +from aiida_quantumespresso.common.types import ElectronicType, SpinType from aiida_quantumespresso.utils.mapping import prepare_process_inputs from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain +from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain XyData = DataFactory("core.array.xy") StructureData = DataFactory("core.structure") @@ -59,6 +60,9 @@ def define(cls, spec): spec.expose_inputs(PwRelaxWorkChain, namespace='relax', exclude=('clean_workdir', 'structure'), namespace_options={'required': False, 'populate_defaults': False, 'help': 'Inputs for the `PwRelaxWorkChain`, if not specified at all, the relaxation step is skipped.'}) + spec.expose_inputs(PwBaseWorkChain, namespace='scf', + exclude=('clean_workdir', 'pw.structure'), + namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the SCF calculation.'}) i = 0 for name, entry_point in plugin_entries.items(): plugin_workchain = entry_point["workchain"] @@ -88,13 +92,17 @@ def define(cls, spec): cls.run_relax, cls.inspect_relax ), + if_(cls.should_run_scf)( + cls.run_scf, + cls.inspect_scf, + ), cls.run_plugin, cls.inspect_plugin, ) spec.exit_code(401, 'ERROR_SUB_PROCESS_FAILED_RELAX', message='The PwRelaxWorkChain sub process failed') - spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_PDOS', - message='The PdosWorkChain sub process failed') + spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_SCF', + message='The SCF sub process failed') spec.output('structure', valid_type=StructureData, required=False) # yapf: enable @@ -123,17 +131,15 @@ def get_builder_from_protocol( builder = cls.get_builder() # Set the structure. builder.structure = structure + protocol = parameters["workchain"]["protocol"] + args = (codes.get("pw"), structure, protocol) # relax relax_overrides = { "base": parameters["advanced"], "base_final_scf": parameters["advanced"], } - protocol = parameters["workchain"]["protocol"] relax_builder = PwRelaxWorkChain.get_builder_from_protocol( - code=codes.get("pw"), - structure=structure, - protocol=protocol, - relax_type=RelaxType(parameters["workchain"]["relax_type"]), + *args, electronic_type=ElectronicType(parameters["workchain"]["electronic_type"]), spin_type=SpinType(parameters["workchain"]["spin_type"]), initial_magnetic_moments=parameters["advanced"]["initial_magnetic_moments"], @@ -145,6 +151,16 @@ def get_builder_from_protocol( relax_builder.pop("clean_workdir", None) relax_builder.pop("base_final_scf", None) # never run a final scf builder.relax = relax_builder + # scf + scf_overrides = parameters["advanced"] + protocol = parameters["workchain"]["protocol"] + scf_builder = PwBaseWorkChain.get_builder_from_protocol( + *args, overrides=scf_overrides, **kwargs + ) + # pop the inputs that are excluded from the expose_inputs + scf_builder.pop("structure", None) + scf_builder.pop("clean_workdir", None) + builder.scf = scf_builder if properties is None: properties = [] @@ -156,6 +172,9 @@ def get_builder_from_protocol( codes, structure, copy.deepcopy(parameters), **kwargs ) setattr(builder, name, plugin_builder) + # check if the plugin requires a scf calculation + if entry_point["requires_scf"]: + builder.run_scf = True else: builder.pop(name, None) @@ -176,7 +195,6 @@ def setup(self): # logic based on the properties input self.ctx.run_relax = "relax" in self.inputs.properties - self.ctx.run_pdos = "pdos" in self.inputs.properties def should_run_relax(self): """Check if the geometry of the input structure should be optimized.""" @@ -212,6 +230,53 @@ def inspect_relax(self): ) self.out("structure", self.ctx.current_structure) + def should_run_scf(self): + """Check if the SCF calculation should be run.""" + run_scf = False + for name, entry_point in plugin_entries.items(): + if name in self.inputs.properties: + # check if the plugin requires a scf calculation + if entry_point["requires_scf"]: + run_scf = True + return run_scf + + def run_scf(self): + """Run the PwBaseWorkChain in scf mode""" + inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace="scf")) + inputs.metadata.call_link_label = "scf" + inputs.pw.structure = self.ctx.current_structure + + # Make sure to carry the number of bands from the relax workchain if it was run and it wasn't explicitly defined + # in the inputs. One of the base workchains in the relax workchain may have changed the number automatically in + # the sanity checks on band occupations. + if self.ctx.current_number_of_bands: + inputs.pw.parameters = inputs.pw.parameters.get_dict() + inputs.pw.parameters.setdefault("SYSTEM", {}).setdefault( + "nbnd", self.ctx.current_number_of_bands + ) + + inputs = prepare_process_inputs(PwBaseWorkChain, inputs) + running = self.submit(PwBaseWorkChain, **inputs) + + self.report(f"launching PwBaseWorkChain<{running.pk}> in scf mode") + + return ToContext(workchain_scf=running) + + def inspect_scf(self): + """Verify that the PwBaseWorkChain for the scf run finished successfully.""" + workchain = self.ctx.workchain_scf + + if not workchain.is_finished_ok: + self.report( + f"scf PwBaseWorkChain failed with exit status {workchain.exit_status}" + ) + return self.exit_codes.ERROR_SUB_PROCESS_FAILED_SCF + + self.ctx.scf_folder = workchain.outputs.remote_folder + self.ctx.current_number_of_bands = ( + workchain.outputs.output_parameters.base.attributes.get("number_of_bands") + ) + def should_run_plugin(self, name): return name in self.inputs @@ -228,6 +293,9 @@ def run_plugin(self): ) inputs.metadata.call_link_label = name inputs.structure = self.ctx.current_structure + # set the scf parent folder and other inputs from the context + for key, value in entry_point.get("input_from_ctx", {}).items(): + setattr(inputs, key, self.ctx[value]) inputs = prepare_process_inputs(plugin_workchain, inputs) running = self.submit(plugin_workchain, **inputs) self.report(f"launching plugin {name} <{running.pk}>")