Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ruff fixes #184

Merged
merged 42 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
e02ec19
build against numpy2
DanielYang59 Aug 5, 2024
d581717
Need confirm: remove setuptool upper pin
DanielYang59 Aug 5, 2024
110f4eb
bump pymatgen version to support NP2
DanielYang59 Aug 9, 2024
9996c5d
include wandb into test dependency as it's required by test_trainer
DanielYang59 Aug 9, 2024
1c74710
explicitly use weights_only False to avoid FutureWarning
DanielYang59 Aug 9, 2024
6fa3755
use tem path in test
DanielYang59 Aug 9, 2024
3023f84
bump pymatgen to resolve np2 compatibility issue
DanielYang59 Sep 10, 2024
6a4246d
use int64
DanielYang59 Sep 11, 2024
62f3638
avoid release candidate for numpy build
DanielYang59 Sep 11, 2024
cc1f833
Merge branch 'CederGroupHub:main' into numpy2
DanielYang59 Sep 11, 2024
d628068
revert dtype change for pytorch
DanielYang59 Sep 11, 2024
d0c6b5b
fix error type
DanielYang59 Sep 11, 2024
b4f05c5
NEED CONFIRM: patch numpy random generator
DanielYang59 Sep 11, 2024
7f7d6b6
pip install on win IS SLOW, do something while waiting
DanielYang59 Sep 11, 2024
c70f804
use cython type
DanielYang59 Sep 11, 2024
1d0b153
include numpy c header
DanielYang59 Sep 11, 2024
4f044eb
int64 might be better, intp is still platform dependent
DanielYang59 Sep 11, 2024
45a1fd7
maybe fix setup.py ModuleNotFoundError: No module named 'numpy'
janosh Sep 11, 2024
b8f6cfa
ruff fixes
DanielYang59 Sep 11, 2024
e4ceb6a
Revert "maybe fix setup.py ModuleNotFoundError: No module named 'numpy'"
DanielYang59 Sep 11, 2024
d00a425
I don't think we need to manually install build deps, hoping I'm not …
DanielYang59 Sep 11, 2024
af86ce6
add --system arg to uv pip install
DanielYang59 Sep 11, 2024
d7aaded
merge uv install stage?
DanielYang59 Sep 11, 2024
47a77e0
fix typo in error message
DanielYang59 Sep 11, 2024
570bfed
NEED CONFIRM: use torch.int32 over int for torch tensor
DanielYang59 Sep 11, 2024
ff8bf84
raise error instead of systemexit when failure
DanielYang59 Sep 11, 2024
1f40088
make module level var all cap
DanielYang59 Sep 11, 2024
b60f5e4
fix: /Users/runner/work/chgnet/chgnet/chgnet/model/dynamics.py:297: D…
DanielYang59 Sep 11, 2024
e53337e
let's see if any array is not int64
DanielYang59 Sep 12, 2024
1b84de4
use setup-uv as Janosh suggested
DanielYang59 Sep 12, 2024
59787a6
try to fix package install
DanielYang59 Sep 12, 2024
bda704b
NEED CONFIRM: drop python 3.9
DanielYang59 Sep 12, 2024
4eed205
Revert "NEED CONFIRM: drop python 3.9"
DanielYang59 Sep 12, 2024
1d1ec65
use uv pip install for now
DanielYang59 Sep 12, 2024
4cafef7
revert to traditional uv pip install for now as setup-uv doesn't have…
DanielYang59 Sep 12, 2024
d1bf176
fix more missing type changes
DanielYang59 Sep 12, 2024
107fb02
revert changes to cygraph.pyx
DanielYang59 Sep 12, 2024
190d6e8
revert changes related to np2
DanielYang59 Sep 12, 2024
c9e8e5f
revert more np2 related changes
DanielYang59 Sep 12, 2024
8875487
revert more np related changes
DanielYang59 Sep 12, 2024
8c3ffa2
pin pmg 24.8.9
DanielYang59 Sep 12, 2024
aeb9053
Revert "fix: /Users/runner/work/chgnet/chgnet/chgnet/model/dynamics.p…
DanielYang59 Sep 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,10 @@ jobs:
cache: pip
cache-dependency-path: pyproject.toml

- name: Install uv
run: pip install uv

- name: Install dependencies
- name: Install chgnet through uv
run: |
uv pip install cython 'setuptools<70' --system

python setup.py build_ext --inplace

uv pip install -e .[test,logging] --system --resolution=${{ matrix.version.resolution }}
pip install uv
uv pip install -e .[test,logging] --resolution=${{ matrix.version.resolution }} --system

- name: Run Tests
run: pytest --capture=no --cov --cov-report=xml
Expand Down
14 changes: 8 additions & 6 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import gc
import sys
import warnings
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

import numpy as np
import torch
Expand All @@ -13,6 +13,8 @@
from chgnet.graph.graph import Graph, Node

if TYPE_CHECKING:
from typing import Literal

from pymatgen.core import Structure
from typing_extensions import Self

Expand All @@ -21,7 +23,7 @@
except (ImportError, AttributeError):
make_graph = None

datatype = torch.float32
DATATYPE = torch.float32


class CrystalGraphConverter(nn.Module):
Expand Down Expand Up @@ -122,10 +124,10 @@ def forward(
requires_grad=False,
)
atom_frac_coord = torch.tensor(
structure.frac_coords, dtype=datatype, requires_grad=True
structure.frac_coords, dtype=DATATYPE, requires_grad=True
)
lattice = torch.tensor(
structure.lattice.matrix, dtype=datatype, requires_grad=True
structure.lattice.matrix, dtype=DATATYPE, requires_grad=True
)
center_index, neighbor_index, image, distance = structure.get_neighbor_list(
r=self.atom_graph_cutoff, sites=structure.sites, numerical_tol=1e-8
Expand All @@ -150,7 +152,7 @@ def forward(
# Report structures that failed creating bond graph
# This happen occasionally with pymatgen version issue
structure.to(filename="bond_graph_error.cif")
raise SystemExit(
Copy link
Contributor Author

@DanielYang59 DanielYang59 Sep 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is intended, but I don't think SystemExit should be raised here? ff8bf84

raise RuntimeError(
f"Failed creating bond graph for {graph_id}, check bond_graph_error.cif"
) from exc
bond_graph = torch.tensor(bond_graph, dtype=torch.int32)
Expand All @@ -175,7 +177,7 @@ def forward(
atomic_number=atomic_number,
atom_frac_coord=atom_frac_coord,
atom_graph=atom_graph,
neighbor_image=torch.tensor(image, dtype=datatype),
neighbor_image=torch.tensor(image, dtype=DATATYPE),
directed2undirected=directed2undirected,
undirected2directed=undirected2directed,
bond_graph=bond_graph,
Expand Down
2 changes: 1 addition & 1 deletion chgnet/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def line_graph_adjacency_list(self, cutoff) -> tuple[list[list[int]], list[int]]
if len(self.directed_edges_list) != 2 * len(self.undirected_edges_list):
raise ValueError(
f"Error: number of directed edges={len(self.directed_edges_list)} != 2 "
f"* number of undirected edges={len(self.directed_edges_list)}!"
f"* number of undirected edges={len(self.undirected_edges_list)}!"
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
f"This indicates directed edges are not complete"
)
line_graph = []
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def fit(
if isinstance(structure, Structure):
atomic_number = torch.tensor(
[site.specie.Z for site in structure],
dtype=int,
dtype=torch.int32,
requires_grad=False,
)
else:
Expand Down
9 changes: 5 additions & 4 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import io
import pickle
import sys
import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen, NVTBerendsen
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
from ase.md.verlet import VelocityVerlet
from ase.optimize.bfgs import BFGS
Expand Down Expand Up @@ -610,11 +610,12 @@ def __init__(
except Exception:
bulk_modulus_au = 2 / 160.2176
compressibility_au = 1 / bulk_modulus_au
print(
warnings.warn(
"Warning!!! Equation of State fitting failed, setting bulk "
"modulus to 2 GPa. NPT simulation can proceed with incorrect "
"pressure relaxation time."
"User input for bulk modulus is recommended."
"User input for bulk modulus is recommended.",
stacklevel=2,
)
self.bulk_modulus = bulk_modulus

Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def from_dict(cls, dct: dict, **kwargs) -> Self:
@classmethod
def from_file(cls, path: str, **kwargs) -> Self:
"""Build a CHGNet from a saved file."""
state = torch.load(path, map_location=torch.device("cpu"))
state = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
return cls.from_dict(state["model"], **kwargs)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import datetime
import inspect
import os
import random
import shutil
import time
from datetime import datetime
from typing import TYPE_CHECKING, Literal, get_args

import numpy as np
Expand Down Expand Up @@ -285,7 +285,7 @@ def train(
raise ValueError("Model needs to be initialized")
global best_checkpoint # noqa: PLW0603
if save_dir is None:
save_dir = f"{datetime.now():%m-%d-%Y}"
save_dir = f"{datetime.datetime.now(tz=datetime.timezone.utc):%m-%d-%Y}"
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved

print(f"Begin Training: using {self.device} device")
print(f"training targets: {self.targets}")
Expand Down
2 changes: 1 addition & 1 deletion chgnet/utils/vasp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def parse_vasp_dir(
dict: a dictionary of lists with keys for structure, uncorrected_total_energy,
energy_per_atom, force, magmom, stress.
"""
if os.path.isdir(base_dir) is False:
if not os.path.isdir(base_dir):
raise NotADirectoryError(f"{base_dir=} is not a directory")

oszicar_path = zpath(f"{base_dir}/OSZICAR")
Expand Down
35 changes: 18 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@ dependencies = [
"cython>=3",
"numpy>=1.26,<2",
"nvidia-ml-py3>=7.352.0",
"pymatgen>=2023.10.11",
"pymatgen==2024.8.9",
"torch>=1.11.0",
"typing-extensions>=4.12",
]
classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Chemistry",
"Topic :: Scientific/Engineering :: Physics",
]

[project.optional-dependencies]
test = ["pytest-cov>=4", "pytest>=8"]
test = ["pytest-cov>=4", "pytest>=8", "wandb>=0.17"]
# needed to run interactive example notebooks
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
Expand All @@ -48,7 +48,11 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
"chgnet.pretrained" = ["*", "**/*"]

[build-system]
requires = ["Cython", "setuptools>=65,<70", "wheel"]
requires = [
"Cython",
"setuptools>=65",
"wheel",
]
build-backend = "setuptools.build_meta"

[tool.ruff]
Expand All @@ -58,35 +62,32 @@ target-version = "py39"
select = ["ALL"]
ignore = [
"ANN001", # TODO add missing type annotations
"ANN003",
"ANN101",
"ANN102",
"ANN003", # Missing type annotation for **{name}
"ANN101", # Missing type annotation for {name} in method
"ANN102", # Missing type annotation for {name} in classmethod
"B019", # Use of functools.lru_cache on methods can lead to memory leaks
"BLE001",
"BLE001", # use of general except Exception
"C408", # unnecessary-collection-call
"C901", # function is too complex
"COM812", # trailing comma missing
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
"D205", # 1 blank line required between summary line and description
"DTZ005", # use of datetime.now() without timezone
"E731", # do not assign a lambda expression, use a def
"EM",
"EM", # error message related
"ERA001", # found commented out code
"ISC001",
"NPY002", # TODO replace legacy np.random.seed
"PLR0912", # too many branches
"PLR0913", # too many args in function def
"PLR0915", # too many statements
"PLW2901", # Outer for loop variable overwritten by inner assignment target
"PT006", # pytest-parametrize-names-wrong-type
"PTH", # prefer Path to os.path
"S108",
"S108", # Probable insecure usage of temporary file or directory
"S301", # pickle can be unsafe
"S310",
"S311",
"TRY003",
"TRY300",
"S310", # Audit URL open for permitted schemes
"S311", # pseudo-random generators not suitable for cryptographic purposes
"TRY003", # Avoid specifying long messages outside the exception class
"TRY300", # Consider moving this statement to an else block
]
pydocstyle.convention = "google"
isort.required-imports = ["from __future__ import annotations"]
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@

ext_modules = [Extension("chgnet.graph.cygraph", ["chgnet/graph/cygraph.pyx"])]

setup(ext_modules=ext_modules, setup_requires=["Cython"])
setup(
ext_modules=ext_modules,
setup_requires=["Cython"],
)
2 changes: 1 addition & 1 deletion tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _set_make_graph() -> Generator[None, None, None]:


@pytest.mark.parametrize(
"atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (5, None), (4, 2)]
("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (5, None), (4, 2)]
)
def test_crystal_graph_converter_cutoff(
atom_graph_cutoff: float | None, bond_graph_cutoff: float | None
Expand Down
61 changes: 31 additions & 30 deletions tests/test_crystal_graph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from time import perf_counter
from unittest.mock import patch

import numpy as np
from pymatgen.core import Structure

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter

np.random.seed(0)

structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3)
converter_legacy = CrystalGraphConverter(
Expand Down Expand Up @@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast():


def test_crystal_graph_perturb_legacy():
np.random.seed(0)
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably add a seed kwarg to Structure.perturb that's simply passed into np.random.default_rng(seed=seed) over in pymatgen

Copy link
Contributor Author

@DanielYang59 DanielYang59 Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point, especially with the generator implementation.

I didn't realize this downside before but it turns out with the new generator implementation every rng is a isolated instance, making it impossible to set a global seed and control all random states (or it's just me unaware of that).


start = perf_counter()
graph = converter_legacy(structure_perturbed)
print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_perturb_fast():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_fast(structure_perturbed)
print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_isotropic_strained_legacy():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_atom_embedding(atom_feature_dim: int, max_num_elements: int) -> None:
assert "index out of range" in str(exc_info.value)


@pytest.mark.parametrize("atom_graph_cutoff, bond_graph_cutoff", [(5, 3), (6, 4)])
@pytest.mark.parametrize(("atom_graph_cutoff", "bond_graph_cutoff"), [(5, 3), (6, 4)])
def test_bond_encoder(atom_graph_cutoff: float, bond_graph_cutoff: float) -> None:
undirected2directed = torch.tensor([0, 1])
image = torch.zeros((2, 3))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@pytest.mark.parametrize(
"algorithm, ase_filter, assign_magmoms",
("algorithm", "ase_filter", "assign_magmoms"),
[("legacy", FrechetCellFilter, True), ("fast", ExpCellFilter, False)],
)
def test_relaxation(
Expand Down
Loading
Loading