Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ruff fixes #184

Merged
merged 42 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
e02ec19
build against numpy2
DanielYang59 Aug 5, 2024
d581717
Need confirm: remove setuptool upper pin
DanielYang59 Aug 5, 2024
110f4eb
bump pymatgen version to support NP2
DanielYang59 Aug 9, 2024
9996c5d
include wandb into test dependency as it's required by test_trainer
DanielYang59 Aug 9, 2024
1c74710
explicitly use weights_only False to avoid FutureWarning
DanielYang59 Aug 9, 2024
6fa3755
use tem path in test
DanielYang59 Aug 9, 2024
3023f84
bump pymatgen to resolve np2 compatibility issue
DanielYang59 Sep 10, 2024
6a4246d
use int64
DanielYang59 Sep 11, 2024
62f3638
avoid release candidate for numpy build
DanielYang59 Sep 11, 2024
cc1f833
Merge branch 'CederGroupHub:main' into numpy2
DanielYang59 Sep 11, 2024
d628068
revert dtype change for pytorch
DanielYang59 Sep 11, 2024
d0c6b5b
fix error type
DanielYang59 Sep 11, 2024
b4f05c5
NEED CONFIRM: patch numpy random generator
DanielYang59 Sep 11, 2024
7f7d6b6
pip install on win IS SLOW, do something while waiting
DanielYang59 Sep 11, 2024
c70f804
use cython type
DanielYang59 Sep 11, 2024
1d0b153
include numpy c header
DanielYang59 Sep 11, 2024
4f044eb
int64 might be better, intp is still platform dependent
DanielYang59 Sep 11, 2024
45a1fd7
maybe fix setup.py ModuleNotFoundError: No module named 'numpy'
janosh Sep 11, 2024
b8f6cfa
ruff fixes
DanielYang59 Sep 11, 2024
e4ceb6a
Revert "maybe fix setup.py ModuleNotFoundError: No module named 'numpy'"
DanielYang59 Sep 11, 2024
d00a425
I don't think we need to manually install build deps, hoping I'm not …
DanielYang59 Sep 11, 2024
af86ce6
add --system arg to uv pip install
DanielYang59 Sep 11, 2024
d7aaded
merge uv install stage?
DanielYang59 Sep 11, 2024
47a77e0
fix typo in error message
DanielYang59 Sep 11, 2024
570bfed
NEED CONFIRM: use torch.int32 over int for torch tensor
DanielYang59 Sep 11, 2024
ff8bf84
raise error instead of systemexit when failure
DanielYang59 Sep 11, 2024
1f40088
make module level var all cap
DanielYang59 Sep 11, 2024
b60f5e4
fix: /Users/runner/work/chgnet/chgnet/chgnet/model/dynamics.py:297: D…
DanielYang59 Sep 11, 2024
e53337e
let's see if any array is not int64
DanielYang59 Sep 12, 2024
1b84de4
use setup-uv as Janosh suggested
DanielYang59 Sep 12, 2024
59787a6
try to fix package install
DanielYang59 Sep 12, 2024
bda704b
NEED CONFIRM: drop python 3.9
DanielYang59 Sep 12, 2024
4eed205
Revert "NEED CONFIRM: drop python 3.9"
DanielYang59 Sep 12, 2024
1d1ec65
use uv pip install for now
DanielYang59 Sep 12, 2024
4cafef7
revert to traditional uv pip install for now as setup-uv doesn't have…
DanielYang59 Sep 12, 2024
d1bf176
fix more missing type changes
DanielYang59 Sep 12, 2024
107fb02
revert changes to cygraph.pyx
DanielYang59 Sep 12, 2024
190d6e8
revert changes related to np2
DanielYang59 Sep 12, 2024
c9e8e5f
revert more np2 related changes
DanielYang59 Sep 12, 2024
8875487
revert more np related changes
DanielYang59 Sep 12, 2024
8c3ffa2
pin pmg 24.8.9
DanielYang59 Sep 12, 2024
aeb9053
Revert "fix: /Users/runner/work/chgnet/chgnet/chgnet/model/dynamics.p…
DanielYang59 Sep 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _create_graph_fast(
"""
center_index = np.ascontiguousarray(center_index)
neighbor_index = np.ascontiguousarray(neighbor_index)
image = np.ascontiguousarray(image, dtype=np.int_)
image = np.ascontiguousarray(image, dtype=np.int64)
distance = np.ascontiguousarray(distance)
gc_saved = gc.get_threshold()
gc.set_threshold(0)
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def from_dict(cls, dct: dict, **kwargs) -> Self:
@classmethod
def from_file(cls, path: str, **kwargs) -> Self:
"""Build a CHGNet from a saved file."""
state = torch.load(path, map_location=torch.device("cpu"))
state = torch.load(path, map_location=torch.device("cpu"), weights_only=False)
return cls.from_dict(state["model"], **kwargs)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion chgnet/utils/vasp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def parse_vasp_dir(
dict: a dictionary of lists with keys for structure, uncorrected_total_energy,
energy_per_atom, force, magmom, stress.
"""
if os.path.isdir(base_dir) is False:
if not os.path.isdir(base_dir):
raise NotADirectoryError(f"{base_dir=} is not a directory")

oszicar_path = zpath(f"{base_dir}/OSZICAR")
Expand Down
13 changes: 9 additions & 4 deletions pyproject.toml
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ license = { text = "Modified BSD" }
dependencies = [
"ase>=3.23.0",
"cython>=3",
"numpy>=1.26,<2",
"numpy>=1.26",
"nvidia-ml-py3>=7.352.0",
"pymatgen>=2023.10.11",
"pymatgen>=2024.9.10",
"torch>=1.11.0",
"typing-extensions>=4.12",
]
Expand All @@ -29,7 +29,7 @@ classifiers = [
]

[project.optional-dependencies]
test = ["pytest-cov>=4", "pytest>=8"]
test = ["pytest-cov>=4", "pytest>=8", "wandb>=0.17"]
# needed to run interactive example notebooks
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
Expand All @@ -48,7 +48,12 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
"chgnet.pretrained" = ["*", "**/*"]

[build-system]
requires = ["Cython", "setuptools>=65,<70", "wheel"]
requires = [
"Cython",
"setuptools>=65",
"wheel",
"numpy>=2.0.0",
]
build-backend = "setuptools.build_meta"

[tool.ruff]
Expand Down
61 changes: 31 additions & 30 deletions tests/test_crystal_graph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from __future__ import annotations

from time import perf_counter
from unittest.mock import patch

import numpy as np
from pymatgen.core import Structure

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter

np.random.seed(0)

structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
converter = CrystalGraphConverter(atom_graph_cutoff=5, bond_graph_cutoff=3)
converter_legacy = CrystalGraphConverter(
Expand Down Expand Up @@ -127,55 +126,57 @@ def test_crystal_graph_different_cutoff_fast():


def test_crystal_graph_perturb_legacy():
np.random.seed(0)
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably add a seed kwarg to Structure.perturb that's simply passed into np.random.default_rng(seed=seed) over in pymatgen

Copy link
Contributor Author

@DanielYang59 DanielYang59 Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point, especially with the generator implementation.

I didn't realize this downside before but it turns out with the new generator implementation every rng is a isolated instance, making it impossible to set a global seed and control all random states (or it's just me unaware of that).


start = perf_counter()
graph = converter_legacy(structure_perturbed)
print("Legacy test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_perturb_fast():
np.random.seed(0)
structure_perturbed = structure.copy()
structure_perturbed.perturb(distance=0.1)
fixed_rng = np.random.default_rng(0)
with patch("numpy.random.default_rng", return_value=fixed_rng):
structure_perturbed.perturb(distance=0.1)

start = perf_counter()
graph = converter_fast(structure_perturbed)
print("Fast test_crystal_graph_perturb time:", perf_counter() - start) # noqa: T201

assert list(graph.atom_frac_coord.shape) == [8, 3]
assert list(graph.atom_graph.shape) == [410, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 3).sum().item() == 53
assert (graph.atom_graph[:, 1] == 6).sum().item() == 50

assert list(graph.bond_graph.shape) == [688, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 90
assert (graph.bond_graph[:, 1] == 36).sum().item() == 17
assert (graph.bond_graph[:, 3] == 36).sum().item() == 17
assert (graph.bond_graph[:, 2] == 306).sum().item() == 10
assert list(graph.atom_graph.shape) == [420, 2]
assert (graph.atom_graph[:, 0] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 3).sum().item() == 54
assert (graph.atom_graph[:, 1] == 6).sum().item() == 54

assert list(graph.bond_graph.shape) == [850, 5]
assert (graph.bond_graph[:, 0] == 1).sum().item() == 156
assert (graph.bond_graph[:, 1] == 36).sum().item() == 18
assert (graph.bond_graph[:, 3] == 36).sum().item() == 18
assert (graph.bond_graph[:, 2] == 306).sum().item() == 0
assert (graph.bond_graph[:, 4] == 120).sum().item() == 0
assert list(graph.lattice.shape) == [3, 3]
assert list(graph.undirected2directed.shape) == [205]
assert list(graph.directed2undirected.shape) == [410]
assert list(graph.undirected2directed.shape) == [210]
assert list(graph.directed2undirected.shape) == [420]


def test_crystal_graph_isotropic_strained_legacy():
Expand Down
8 changes: 4 additions & 4 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,19 +153,19 @@ def test_wandb_init(mock_wandb):
)


def test_wandb_log_frequency(mock_wandb):
def test_wandb_log_frequency(tmp_path, mock_wandb):
trainer = Trainer(model=chgnet, wandb_path="test-project/test-run", epochs=1)

# Test epoch logging
trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir="")
trainer.train(train_loader, val_loader, wandb_log_freq="epoch", save_dir=tmp_path)
assert (
mock_wandb.log.call_count == 2 * trainer.epochs
), "Expected one train and one val log per epoch"

mock_wandb.log.reset_mock()

# Test batch logging
trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir="")
trainer.train(train_loader, val_loader, wandb_log_freq="batch", save_dir=tmp_path)
expected_batch_calls = trainer.epochs * len(train_loader)
assert (
mock_wandb.log.call_count > expected_batch_calls
Expand All @@ -183,5 +183,5 @@ def test_wandb_log_frequency(mock_wandb):

# Test no logging when wandb_path is not provided
trainer_no_wandb = Trainer(model=chgnet, epochs=1)
trainer_no_wandb.train(train_loader, val_loader)
trainer_no_wandb.train(train_loader, val_loader, save_dir=tmp_path)
mock_wandb.log.assert_not_called()
2 changes: 1 addition & 1 deletion tests/test_vasp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_parse_vasp_dir_without_magmoms(tmp_path: Path):

def test_parse_vasp_dir_no_data():
# test non-existing directory
with pytest.raises(FileNotFoundError, match="is not a directory"):
with pytest.raises(NotADirectoryError, match="is not a directory"):
parse_vasp_dir(f"{ROOT}/tests/files/non-existent")

# test existing directory without VASP files
Expand Down
Loading