Skip to content

Commit

Permalink
Fix Incar check_params for Union type (#3958)
Browse files Browse the repository at this point in the history
* add test case for Union type LREAL

* remove debug msg

* update type check mechanism

* use eval for type checking

* use isinstance syntax

* try to increase dependency palettable version

* bump monty to 2024.7.29

* pin torch version until matgl release

* Revert "pin torch version until matgl release"

This reverts commit 215c888.

* skip failing matgl tests for now

---------

Signed-off-by: Janosh Riebesell <janosh.riebesell@gmail.com>
Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
DanielYang59 and janosh authored Aug 2, 2024
1 parent b35b99e commit 940eb60
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 6 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ classifiers = [
dependencies = [
"joblib>=1",
"matplotlib>=3.8",
"monty>=2024.5.24",
"monty>=2024.7.29",
"networkx>=2.2",
"palettable>=3.1.1",
'numpy>=1.25.0,<2.0 ; platform_system == "Windows"',
'numpy>=1.25.0 ; platform_system != "Windows"',
"palettable>=3.3.3",
"pandas>=2",
"plotly>=4.5.0",
"pybtex>=0.24.0",
Expand Down
2 changes: 1 addition & 1 deletion src/pymatgen/io/vasp/incar_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@
"type": "bool"
},
"LREAL": {
"type": "Union[bool, str]",
"type": "(bool, str)",
"values": [
false,
true,
Expand Down
6 changes: 3 additions & 3 deletions src/pymatgen/io/vasp/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,10 +1028,10 @@ def check_params(self) -> None:
continue

# Check value and its type
param_type = incar_params[tag].get("type")
allowed_values = incar_params[tag].get("values")
param_type: str = incar_params[tag].get("type")
allowed_values: list[Any] = incar_params[tag].get("values")

if param_type is not None and type(val).__name__ != param_type:
if param_type is not None and not isinstance(val, eval(param_type)):
warnings.warn(f"{tag}: {val} is not a {param_type}", BadIncarWarning, stacklevel=2)

# Only check value when it's not None,
Expand Down
4 changes: 4 additions & 0 deletions tests/core/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1769,6 +1769,7 @@ def test_relax_ase_opt_kwargs(self):
assert traj[0] != traj[-1]
assert os.path.isfile(traj_file)

@pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency")
def test_calculate_m3gnet(self):
pytest.importorskip("matgl")
calculator = self.get_structure("Si").calculate()
Expand All @@ -1780,6 +1781,7 @@ def test_calculate_m3gnet(self):
assert np.linalg.norm(calculator.results["forces"]) == approx(7.8123485e-06, abs=0.2)
assert np.linalg.norm(calculator.results["stress"]) == approx(1.7861567, abs=2)

@pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency")
def test_relax_m3gnet(self):
matgl = pytest.importorskip("matgl")
struct = self.get_structure("Si")
Expand All @@ -1790,6 +1792,7 @@ def test_relax_m3gnet(self):
actual = relaxed.dynamics[key]
assert actual == val, f"expected {key} to be {val}, {actual=}"

@pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency")
def test_relax_m3gnet_fixed_lattice(self):
matgl = pytest.importorskip("matgl")
struct = self.get_structure("Si")
Expand All @@ -1798,6 +1801,7 @@ def test_relax_m3gnet_fixed_lattice(self):
assert isinstance(relaxed.calc, matgl.ext.ase.M3GNetCalculator)
assert relaxed.dynamics["optimizer"] == "BFGS"

@pytest.mark.skip("TODO: #3958 wait for matgl resolve of torch dependency")
def test_relax_m3gnet_with_traj(self):
pytest.importorskip("matgl")
struct = self.get_structure("Si")
Expand Down
1 change: 1 addition & 0 deletions tests/io/vasp/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ def test_check_params(self):
"AMIN": 0.01,
"ICHARG": 1,
"MAGMOM": [1, 2, 4, 5],
"LREAL": True, # special case: Union type
"NBAND": 250, # typo in tag
"METAGGA": "SCAM", # typo in value
"EDIFF": 5 + 1j, # value should be a float
Expand Down

0 comments on commit 940eb60

Please sign in to comment.