diff --git a/aiida_quantumespresso/workflows/protocols/pw/base.yaml b/aiida_quantumespresso/workflows/protocols/pw/base.yaml index 0b3ce8041..81219e511 100644 --- a/aiida_quantumespresso/workflows/protocols/pw/base.yaml +++ b/aiida_quantumespresso/workflows/protocols/pw/base.yaml @@ -5,14 +5,14 @@ default_inputs: meta_parameters: conv_thr_per_atom: 0.2e-9 etot_conv_thr_per_atom: 1.e-5 - metadata: - options: - resources: - num_machines: 1 - max_wallclock_seconds: 43200 # Twelve hours - withmpi: True pseudo_family: 'SSSP/1.1/PBE/efficiency' pw: + metadata: + options: + resources: + num_machines: 1 + max_wallclock_seconds: 43200 # Twelve hours + withmpi: True parameters: CONTROL: calculation: scf diff --git a/aiida_quantumespresso/workflows/pw/base.py b/aiida_quantumespresso/workflows/pw/base.py index 859b82da6..46d113f46 100644 --- a/aiida_quantumespresso/workflows/pw/base.py +++ b/aiida_quantumespresso/workflows/pw/base.py @@ -200,14 +200,18 @@ def get_builder_from_protocol( parameters['SYSTEM']['nspin'] = 2 parameters['SYSTEM']['starting_magnetization'] = starting_magnetization - builder.pw['code'] = code # pylint: disable=no-member - builder.pw['pseudos'] = pseudo_family.get_pseudos(structure=structure) # pylint: disable=no-member - builder.pw['structure'] = structure # pylint: disable=no-member - builder.pw['parameters'] = orm.Dict(dict=parameters) # pylint: disable=no-member - builder.pw['metadata'] = inputs['metadata'] # pylint: disable=no-member + # pylint: disable=no-member + builder.pw['code'] = code + builder.pw['pseudos'] = pseudo_family.get_pseudos(structure=structure) + builder.pw['structure'] = structure + builder.pw['parameters'] = orm.Dict(dict=parameters) + builder.pw['metadata'] = inputs['pw']['metadata'] + if 'parallelization' in inputs['pw']: + builder.pw['parallelization'] = orm.Dict(dict=inputs['pw']['parallelization']) builder.clean_workdir = orm.Bool(inputs['clean_workdir']) builder.kpoints_distance = orm.Float(inputs['kpoints_distance']) builder.kpoints_force_parity = orm.Bool(inputs['kpoints_force_parity']) + # pylint: enable=no-member return builder diff --git a/tests/workflows/protocols/pw/test_base.py b/tests/workflows/protocols/pw/test_base.py index c44275c0f..356239ad1 100644 --- a/tests/workflows/protocols/pw/test_base.py +++ b/tests/workflows/protocols/pw/test_base.py @@ -93,3 +93,37 @@ def test_initial_magnetic_moments(fixture_code, generate_structure): assert parameters['SYSTEM']['nspin'] == 2 assert parameters['SYSTEM']['starting_magnetization'] == {'Si': 0.25} + + +def test_metadata_overrides(fixture_code, generate_structure): + """Test that pw metadata is correctly passed through overrides.""" + code = fixture_code('quantumespresso.pw') + structure = generate_structure() + + overrides = {'pw': {'metadata': {'options': {'resources': {'num_machines': 1e90}, 'max_wallclock_seconds': 1}}}} + builder = PwBaseWorkChain.get_builder_from_protocol( + code, + structure, + overrides=overrides, + ) + metadata = builder.pw.metadata + + assert metadata['options']['resources']['num_machines'] == 1e90 + assert metadata['options']['max_wallclock_seconds'] == 1 + + +def test_parallelization_overrides(fixture_code, generate_structure): + """Test that pw parallelization settings are correctly passed through overrides.""" + code = fixture_code('quantumespresso.pw') + structure = generate_structure() + + overrides = {'pw': {'parallelization': {'npool': 4, 'ndiag': 12}}} + builder = PwBaseWorkChain.get_builder_from_protocol( + code, + structure, + overrides=overrides, + ) + parallelization = builder.pw.parallelization + + assert parallelization['npool'] == 4 + assert parallelization['ndiag'] == 12