Skip to content

Commit

Permalink
Fix tests for RelaxType
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Jan 11, 2021
1 parent cf76270 commit 8a745bb
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 16 deletions.
2 changes: 0 additions & 2 deletions aiida_quantumespresso/cli/workflows/pw/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def launch_workflow(
from aiida.plugins import WorkflowFactory
from qe_tools import CONSTANTS

from aiida_quantumespresso.common.types import RelaxType
from aiida_quantumespresso.utils.resources import get_default_options, get_automatic_parallelization_options

builder = WorkflowFactory('quantumespresso.pw.relax').get_builder()
Expand Down Expand Up @@ -81,7 +80,6 @@ def launch_workflow(
raise click.BadParameter(str(exception))

builder.structure = structure
builder.relax_type = Str(RelaxType.ATOMS.value)
builder.base.kpoints_distance = Float(kpoints_distance)
builder.base.pw.code = code
builder.base.pw.pseudos = pseudo_family.get_pseudos(structure=structure)
Expand Down
1 change: 0 additions & 1 deletion aiida_quantumespresso/workflows/protocols/pw/relax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ default_inputs:
clean_workdir: True
max_meta_convergence_iterations: 5
meta_convergence: True
relax_type: atoms_cell
volume_convergence: 0.02
base:
pw:
Expand Down
9 changes: 5 additions & 4 deletions tests/workflows/protocols/pw/test_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ def test_spin_type(fixture_code, generate_structure):


def test_relax_type(fixture_code, generate_structure):
"""Test ``PwBandsWorkChain.get_builder_from_protocol`` overriding the ``relax_type`` input."""
"""Test ``PwBandsWorkChain.get_builder_from_protocol`` setting the ``relax_type`` input."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()

overrides = {'relax': {'relax_type': RelaxType.NONE.value}}
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides)
assert builder.relax['relax_type'].value == RelaxType.NONE.value
builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.NONE)
assert builder.relax['base']['pw']['parameters']['CONTROL']['calculation'] == 'scf'
with pytest.raises(KeyError):
builder.relax['base']['pw']['parameters']['CELL'] # pylint: disable=pointless-statement
4 changes: 2 additions & 2 deletions tests/workflows/protocols/pw/test_bands/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ relax:
withmpi: true
parameters:
CELL:
cell_dofree: all
press_conv_thr: 0.5
CONTROL:
calculation: scf
calculation: vc-relax
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
tprnfor: true
Expand All @@ -66,7 +67,6 @@ relax:
Si: Si<md5=57fa15d98af99972c7b7aa5c179b0bb8>
max_meta_convergence_iterations: 5
meta_convergence: true
relax_type: atoms_cell
volume_convergence: 0.02
scf:
kpoints_distance: 0.15
Expand Down
50 changes: 45 additions & 5 deletions tests/workflows/protocols/pw/test_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,50 @@ def test_spin_type(fixture_code, generate_structure):
assert parameters['SYSTEM']['starting_magnetization'] == {'Si': 0.1}


@pytest.mark.parametrize('relax_type', RelaxType)
def test_relax_type(fixture_code, generate_structure, relax_type):
"""Docs."""
def test_relax_type(fixture_code, generate_structure):
"""Test ``PwRelaxWorkChain.get_builder_from_protocol`` with ``spin_type`` keyword."""
code = fixture_code('quantumespresso.pw')
structure = generate_structure()
builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, overrides={'relax_type': relax_type.value})
assert builder.relax_type.value == relax_type.value

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.NONE)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'scf'
with pytest.raises(KeyError):
builder.base['pw']['parameters']['CELL'] # pylint: disable=pointless-statement

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'relax'
with pytest.raises(KeyError):
builder.base['pw']['parameters']['CELL'] # pylint: disable=pointless-statement

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.VOLUME)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'volume'
assert builder.base['pw']['settings'].get_dict() == {'FIXED_COORDS': [[True, True, True], [True, True, True]]}

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.SHAPE)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'shape'
assert builder.base['pw']['settings'].get_dict() == {'FIXED_COORDS': [[True, True, True], [True, True, True]]}

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.CELL)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'all'
assert builder.base['pw']['settings'].get_dict() == {'FIXED_COORDS': [[True, True, True], [True, True, True]]}

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS_VOLUME)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'volume'
with pytest.raises(KeyError):
builder.base['pw']['settings'] # pylint: disable=pointless-statement

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS_SHAPE)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'shape'
with pytest.raises(KeyError):
builder.base['pw']['settings'] # pylint: disable=pointless-statement

builder = PwRelaxWorkChain.get_builder_from_protocol(code, structure, relax_type=RelaxType.ATOMS_CELL)
assert builder.base['pw']['parameters']['CONTROL']['calculation'] == 'vc-relax'
assert builder.base['pw']['parameters']['CELL']['cell_dofree'] == 'all'
with pytest.raises(KeyError):
builder.base['pw']['settings'] # pylint: disable=pointless-statement
4 changes: 2 additions & 2 deletions tests/workflows/protocols/pw/test_relax/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ base:
withmpi: true
parameters:
CELL:
cell_dofree: all
press_conv_thr: 0.5
CONTROL:
calculation: scf
calculation: vc-relax
etot_conv_thr: 2.0e-05
forc_conv_thr: 0.0001
tprnfor: true
Expand Down Expand Up @@ -67,6 +68,5 @@ base_final_scf:
clean_workdir: true
max_meta_convergence_iterations: 5
meta_convergence: true
relax_type: atoms_cell
structure: Si2
volume_convergence: 0.02

0 comments on commit 8a745bb

Please sign in to comment.