Skip to content

Commit

Permalink
ruff fixes (#184)
Browse files Browse the repository at this point in the history
* build against numpy2

* Need confirm: remove setuptool upper pin

* bump pymatgen version to support NP2

* include wandb into test dependency as it's required by test_trainer

* explicitly use weights_only False to avoid FutureWarning

* use tem path in test

* bump pymatgen to resolve np2 compatibility issue

* use int64

* avoid release candidate for numpy build

* revert dtype change for pytorch

* fix error type

* NEED CONFIRM: patch numpy random generator

* pip install on win IS SLOW, do something while waiting

* use cython type

* include numpy c header

* int64 might be better, intp is still platform dependent

* maybe fix setup.py ModuleNotFoundError: No module named 'numpy'

by lazy-importing numpy

* ruff fixes

* Revert "maybe fix setup.py ModuleNotFoundError: No module named 'numpy'"

This reverts commit 45a1fd7.

* I don't think we need to manually install build deps, hoping I'm not wrong

* add --system arg to uv pip install

* merge uv install stage?

* fix typo in error message

* NEED CONFIRM: use torch.int32 over int for torch tensor

* raise error instead of systemexit when failure

* make module level var all cap

* fix: /Users/runner/work/chgnet/chgnet/chgnet/model/dynamics.py:297: DeprecationWarning: Use FrechetCellFilter for better convergence w.r.t. cell variables.
    atoms = ase_filter(atoms)

* let's see if any array is not int64

* use setup-uv as Janosh suggested

* try to fix package install

* NEED CONFIRM: drop python 3.9

* Revert "NEED CONFIRM: drop python 3.9"

This reverts commit bda704b.

* use uv pip install for now

* revert to traditional uv pip install for now as setup-uv doesn't have permission for some reason

* fix more missing type changes

* revert changes to cygraph.pyx

* revert changes related to np2

* revert more np2 related changes

* revert more np related changes

* pin pmg 24.8.9

* Revert "fix: /Users/runner/work/chgnet/chgnet/chgnet/model/dynamics.py:297: DeprecationWarning: Use FrechetCellFilter for better convergence w.r.t. cell variables."

This reverts commit b60f5e4.

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
DanielYang59 and janosh authored Sep 12, 2024
1 parent fbfe69c commit 9281cf4
Show file tree
Hide file tree
Showing 16 changed files with 83 additions and 81 deletions.
12 changes: 3 additions & 9 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,10 @@ jobs:
cache: pip
cache-dependency-path: pyproject.toml

- name: Install uv
run: pip install uv

- name: Install dependencies
- name: Install chgnet through uv
run: |
uv pip install cython 'setuptools<70' --system
python setup.py build_ext --inplace
uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }}
pip install uv
uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system
- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
Expand Down
14 changes: 8 additions & 6 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import gc
import sys
import warnings
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

import numpy as np
import torch
Expand All @@ -13,6 +13,8 @@
from chgnet.graph.graph import Graph, Node

if TYPE_CHECKING:
from typing import Literal

from pymatgen.core import Structure
from typing_extensions import Self

Expand All @@ -21,7 +23,7 @@
except (ImportError, AttributeError):
make_graph = None

datatype = torch.float32
DATATYPE = torch.float32


class CrystalGraphConverter(nn.Module):
Expand Down Expand Up @@ -122,10 +124,10 @@ def forward(
requires_grad=False,
)
atom_frac_coord = torch.tensor(
structure.frac_coords, dtype=datatype, requires_grad=True
structure.frac_coords, dtype=DATATYPE, requires_grad=True
)
lattice = torch.tensor(
structure.lattice.matrix, dtype=datatype, requires_grad=True
structure.lattice.matrix, dtype=DATATYPE, requires_grad=True
)
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8
Expand All @@ -150,7 +152,7 @@ def forward(
# Report structures that failed creating bond graph
# This happen occasionally with pymatgen version issue
structure.to(filename="bond_graph_error.cif")
raise SystemExit(
raise RuntimeError(
f"Failed creating bond graph for {graph_id}, check bond_graph_error.cif"
) from exc
bond_graph = torch.tensor(bond_graph, dtype=torch.int32)
Expand All @@ -175,7 +177,7 @@ def forward(
atomic_number=atomic_number,
atom_frac_coord=atom_frac_coord,
atom_graph=atom_graph,
neighbor_image=torch.tensor(image, dtype=datatype),
neighbor_image=torch.tensor(image, dtype=DATATYPE),
directed2undirected=directed2undirected,
undirected2directed=undirected2directed,
bond_graph=bond_graph,
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]
if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list):
raise ValueError(
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 "
f"* number of undirected edges={len(self.directed_edges_list)}!"
f"* number of undirected edges={len(self.undirected_edges_list)}!"
f"This indicates directed edges are not complete"
)
line_graph = []
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def fit(
if isinstance(structure, Structure):
atomic_number = torch.tensor(
[site.specie.Z for site in structure],
dtype=int,
dtype=torch.int32,
requires_grad=False,
)
else:
Expand Down
9 changes: 5 additions & 4 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import io
import pickle
import sys
import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen, NVTBerendsen
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase.md.verlet import VelocityVerlet
from ase.optimize.bfgs import BFGS
Expand Down Expand Up @@ -610,11 +610,12 @@ def __init__(
except Exception:
bulk_modulus_au = 2 / 160.2176
compressibility_au = 1 / bulk_modulus_au
print(
warnings.warn(
"Warning!!! Equation of State fitting failed, setting bulk "
"modulus to 2 GPa. NPT simulation can proceed with incorrect "
"pressure relaxation time."
"User input for bulk modulus is recommended."
"User input for bulk modulus is recommended.",
stacklevel=2,
)
self.bulk_modulus = bulk_modulus

Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def from_dict(cls, dct: dict, **kwargs) -> Self:
@classmethod
def from_file(cls, path: str, **kwargs) -> Self:
"""Build a CHGNet from a saved file."""
state = torch.load(path, map_location=torch.device("cpu"))
state = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
return cls.from_dict(state["model"], **kwargs)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import datetime
import inspect
import os
import random
import shutil
import time
from datetime import datetime
from typing import TYPE_CHECKING, Literal, get_args

import numpy as np
Expand Down Expand Up @@ -285,7 +285,7 @@ def train(
raise ValueError("Model needs to be initialized")
global best_checkpoint # noqa: PLW0603
if save_dir is None:
save_dir = f"{datetime.now():%m-%d-%Y}"
save_dir = f"{datetime.datetime.now(tz=datetime.timezone.utc):%m-%d-%Y}"

print(f"Begin Training: using {self.device} device")
print(f"training targets: {self.targets}")
Expand Down
2 changes: 1 addition & 1 deletion chgnet/utils/vasp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def parse_vasp_dir(
dict: a dictionary of lists with keys for structure, uncorrected_total_energy,
energy_per_atom, force, magmom, stress.
"""
if os.path.isdir(base_dir) is False:
if not os.path.isdir(base_dir):
raise NotADirectoryError(f"{base_dir=} is not a directory")

oszicar_path = zpath(f"{base_dir}/OSZICAR")
Expand Down
35 changes: 18 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@ dependencies = [
"cython>=3",
"numpy>=1.26,<2",
"nvidia-ml-py3>=7.352.0",
"pymatgen>=2023.10.11",
"pymatgen==2024.8.9",
"torch>=1.11.0",
"typing-extensions>=4.12",
]
classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Chemistry",
"Topic :: Scientific/Engineering :: Physics",
]

[project.optional-dependencies]
test = ["pytest-cov>=4", "pytest>=8"]
test = ["pytest-cov>=4", "pytest>=8", "wandb>=0.17"]
# needed to run interactive example notebooks
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
Expand All @@ -49,7 +49,11 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
"chgnet.pretrained" = ["*", "**/*"]

[build-system]
requires = ["Cython", "setuptools>=65,<70", "wheel"]
requires = [
"Cython",
"setuptools>=65",
"wheel",
]
build-backend = "setuptools.build_meta"

[tool.ruff]
Expand All @@ -59,35 +63,32 @@ target-version = "py39"
select = ["ALL"]
ignore = [
"ANN001", # TODO add missing type annotations
"ANN003",
"ANN101",
"ANN102",
"ANN003", # Missing type annotation for **{name}
"ANN101", # Missing type annotation for {name} in method
"ANN102", # Missing type annotation for {name} in classmethod
"B019", # Use of functools.lru_cache on methods can lead to memory leaks
"BLE001",
"BLE001", # use of general except Exception
"C408", # unnecessary-collection-call
"C901", # function is too complex
"COM812", # trailing comma missing
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D205", # 1 blank line required between summary line and description
"DTZ005", # use of datetime.now() without timezone
"E731", # do not assign a lambda expression, use a def
"EM",
"EM", # error message related
"ERA001", # found commented out code
"ISC001",
"NPY002", # TODO replace legacy np.random.seed
"PLR0912", # too many branches
"PLR0913", # too many args in function def
"PLR0915", # too many statements
"PLW2901", # Outer for loop variable overwritten by inner assignment target
"PT006", # pytest-parametrize-names-wrong-type
"PTH", # prefer Path to os.path
"S108",
"S108", # Probable insecure usage of temporary file or directory
"S301", # pickle can be unsafe
"S310",
"S311",
"TRY003",
"TRY300",
"S310", # Audit URL open for permitted schemes
"S311", # pseudo-random generators not suitable for cryptographic purposes
"TRY003", # Avoid specifying long messages outside the exception class
"TRY300", # Consider moving this statement to an else block
]
pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@

ext_modules = [Extension("chgnet.graph.cygraph", ["chgnet/graph/cygraph.pyx"])]

setup(ext_modules=ext_modules, setup_requires=["Cython"])
setup(
ext_modules=ext_modules,
setup_requires=["Cython"],
)
2 changes: 1 addition & 1 deletion tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _set_make_graph() -> Generator[None, None, None]:


@pytest.mark.parametrize(
"atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (5, None), (4, 2)]
("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (5, None), (4, 2)]
)
def test_crystal_graph_converter_cutoff(
atom_graph_cutoff: float | None, bond_graph_cutoff: float | None
Expand Down
61 changes: 31 additions & 30 deletions tests/test_crystal_graph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from time import perf_counter
from unittest.mock import patch

import numpy as np
from pymatgen.core import Structure

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter

np.random.seed(0)

structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3)
converter_legacy = CrystalGraphConverter(
Expand Down Expand Up @@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast():


def test_crystal_graph_perturb_legacy():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_legacy(structure_perturbed)
print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_perturb_fast():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_fast(structure_perturbed)
print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_isotropic_strained_legacy():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_atom_embedding(atom_feature_dim: int, max_num_elements: int) -> None:
assert "index out of range" in str(exc_info.value)


@pytest.mark.parametrize("atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (6, 4)])
@pytest.mark.parametrize(("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (6, 4)])
def test_bond_encoder(atom_graph_cutoff: float, bond_graph_cutoff: float) -> None:
undirected2directed = torch.tensor([0, 1])
image = torch.zeros((2, 3))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@pytest.mark.parametrize(
"algorithm, ase_filter, assign_magmoms",
("algorithm", "ase_filter", "assign_magmoms"),
[("legacy", FrechetCellFilter, True), ("fast", ExpCellFilter, False)],
)
def test_relaxation(
Expand Down
Loading

0 comments on commit 9281cf4

Please sign in to comment.