Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run one SCF calculation for all plugins #631

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/aiidalab_qe/plugins/bands/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def get_builder(codes, structure, parameters, **kwargs):

# pop the inputs that are excluded from the expose_inputs
bands.pop("relax")
# run the scf workchain from the app's workchain
bands.pop("scf")
bands.pop("structure", None)
bands.pop("clean_workdir", None)
return bands
Expand All @@ -224,4 +226,6 @@ def get_builder(codes, structure, parameters, **kwargs):
"workchain": PwBandsWorkChain,
"exclude": ("clean_workdir", "structure", "relax"),
"get_builder": get_builder,
"requires_scf": True,
"input_from_ctx": {"bands.pw.parent_folder": "scf_folder"},
}
4 changes: 4 additions & 0 deletions src/aiidalab_qe/plugins/pdos/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def get_builder(codes, structure, parameters, **kwargs):
# pop the inputs that are exclueded from the expose_inputs
pdos.pop("structure", None)
pdos.pop("clean_workdir", None)
# run the scf workchain from the app's workchain
pdos.pop("scf")
else:
raise ValueError("The dos_code and projwfc_code are required.")
return pdos
Expand All @@ -75,4 +77,6 @@ def get_builder(codes, structure, parameters, **kwargs):
"workchain": PdosWorkChain,
"exclude": ("clean_workdir", "structure", "relax"),
"get_builder": get_builder,
"requires_scf": True,
"input_from_ctx": {"nscf.pw.parent_folder": "scf_folder"},
}
86 changes: 77 additions & 9 deletions src/aiidalab_qe/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from aiida.plugins import DataFactory

# AiiDA Quantum ESPRESSO plugin inputs.
from aiida_quantumespresso.common.types import ElectronicType, RelaxType, SpinType
from aiida_quantumespresso.common.types import ElectronicType, SpinType
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain

XyData = DataFactory("core.array.xy")
StructureData = DataFactory("core.structure")
Expand Down Expand Up @@ -59,6 +60,9 @@ def define(cls, spec):
spec.expose_inputs(PwRelaxWorkChain, namespace='relax', exclude=('clean_workdir', 'structure'),
namespace_options={'required': False, 'populate_defaults': False,
'help': 'Inputs for the `PwRelaxWorkChain`, if not specified at all, the relaxation step is skipped.'})
spec.expose_inputs(PwBaseWorkChain, namespace='scf',
exclude=('clean_workdir', 'pw.structure'),
namespace_options={'help': 'Inputs for the `PwBaseWorkChain` for the SCF calculation.'})
i = 0
for name, entry_point in plugin_entries.items():
plugin_workchain = entry_point["workchain"]
Expand Down Expand Up @@ -88,13 +92,17 @@ def define(cls, spec):
cls.run_relax,
cls.inspect_relax
),
if_(cls.should_run_scf)(
cls.run_scf,
cls.inspect_scf,
),
cls.run_plugin,
cls.inspect_plugin,
)
spec.exit_code(401, 'ERROR_SUB_PROCESS_FAILED_RELAX',
message='The PwRelaxWorkChain sub process failed')
spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_PDOS',
message='The PdosWorkChain sub process failed')
spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_SCF',
message='The SCF sub process failed')
spec.output('structure', valid_type=StructureData, required=False)
# yapf: enable

Expand Down Expand Up @@ -123,17 +131,15 @@ def get_builder_from_protocol(
builder = cls.get_builder()
# Set the structure.
builder.structure = structure
protocol = parameters["workchain"]["protocol"]
args = (codes.get("pw"), structure, protocol)
# relax
relax_overrides = {
"base": parameters["advanced"],
"base_final_scf": parameters["advanced"],
}
protocol = parameters["workchain"]["protocol"]
relax_builder = PwRelaxWorkChain.get_builder_from_protocol(
code=codes.get("pw"),
structure=structure,
protocol=protocol,
relax_type=RelaxType(parameters["workchain"]["relax_type"]),
*args,
electronic_type=ElectronicType(parameters["workchain"]["electronic_type"]),
spin_type=SpinType(parameters["workchain"]["spin_type"]),
initial_magnetic_moments=parameters["advanced"]["initial_magnetic_moments"],
Expand All @@ -145,6 +151,16 @@ def get_builder_from_protocol(
relax_builder.pop("clean_workdir", None)
relax_builder.pop("base_final_scf", None) # never run a final scf
builder.relax = relax_builder
# scf
scf_overrides = parameters["advanced"]
protocol = parameters["workchain"]["protocol"]
scf_builder = PwBaseWorkChain.get_builder_from_protocol(
*args, overrides=scf_overrides, **kwargs
)
# pop the inputs that are excluded from the expose_inputs
scf_builder.pop("structure", None)
scf_builder.pop("clean_workdir", None)
builder.scf = scf_builder

if properties is None:
properties = []
Expand All @@ -156,6 +172,9 @@ def get_builder_from_protocol(
codes, structure, copy.deepcopy(parameters), **kwargs
)
setattr(builder, name, plugin_builder)
# check if the plugin requires a scf calculation
if entry_point["requires_scf"]:
builder.run_scf = True
else:
builder.pop(name, None)

Expand All @@ -176,7 +195,6 @@ def setup(self):

# logic based on the properties input
self.ctx.run_relax = "relax" in self.inputs.properties
self.ctx.run_pdos = "pdos" in self.inputs.properties

def should_run_relax(self):
"""Check if the geometry of the input structure should be optimized."""
Expand Down Expand Up @@ -212,6 +230,53 @@ def inspect_relax(self):
)
self.out("structure", self.ctx.current_structure)

def should_run_scf(self):
"""Check if the SCF calculation should be run."""
run_scf = False
for name, entry_point in plugin_entries.items():
if name in self.inputs.properties:
# check if the plugin requires a scf calculation
if entry_point["requires_scf"]:
run_scf = True
return run_scf

def run_scf(self):
"""Run the PwBaseWorkChain in scf mode"""
inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace="scf"))
inputs.metadata.call_link_label = "scf"
inputs.pw.structure = self.ctx.current_structure

# Make sure to carry the number of bands from the relax workchain if it was run and it wasn't explicitly defined
# in the inputs. One of the base workchains in the relax workchain may have changed the number automatically in
# the sanity checks on band occupations.
if self.ctx.current_number_of_bands:
inputs.pw.parameters = inputs.pw.parameters.get_dict()
inputs.pw.parameters.setdefault("SYSTEM", {}).setdefault(
"nbnd", self.ctx.current_number_of_bands
)

inputs = prepare_process_inputs(PwBaseWorkChain, inputs)
running = self.submit(PwBaseWorkChain, **inputs)

self.report(f"launching PwBaseWorkChain<{running.pk}> in scf mode")

return ToContext(workchain_scf=running)

def inspect_scf(self):
"""Verify that the PwBaseWorkChain for the scf run finished successfully."""
workchain = self.ctx.workchain_scf

if not workchain.is_finished_ok:
self.report(
f"scf PwBaseWorkChain failed with exit status {workchain.exit_status}"
)
return self.exit_codes.ERROR_SUB_PROCESS_FAILED_SCF

self.ctx.scf_folder = workchain.outputs.remote_folder
self.ctx.current_number_of_bands = (
workchain.outputs.output_parameters.base.attributes.get("number_of_bands")
)

def should_run_plugin(self, name):
return name in self.inputs

Expand All @@ -228,6 +293,9 @@ def run_plugin(self):
)
inputs.metadata.call_link_label = name
inputs.structure = self.ctx.current_structure
# set the scf parent folder and other inputs from the context
for key, value in entry_point.get("input_from_ctx", {}).items():
setattr(inputs, key, self.ctx[value])
inputs = prepare_process_inputs(plugin_workchain, inputs)
running = self.submit(plugin_workchain, **inputs)
self.report(f"launching plugin {name} <{running.pk}>")
Expand Down
Loading