Skip to content

Commit

Permalink
v0.9.3
Browse files Browse the repository at this point in the history
v0.9.3
  • Loading branch information
YutackPark authored Jul 26, 2024
2 parents 1c97a88 + 493c0a6 commit 6d00fc7
Show file tree
Hide file tree
Showing 58 changed files with 864 additions and 663 deletions.
37 changes: 18 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ repos:
- id: check-shebang-scripts-are-executable
- id: check-merge-conflict
- id: check-vcs-permalinks
- id: check-yaml
- id: debug-statements
- id: destroyed-symlinks
- id: detect-private-key
Expand All @@ -41,32 +40,32 @@ repos:
rev: 5.13.2
hooks:
- id: isort
exclude: 'pair_e3gnn/pair_e3gnn_parallel.h'
exclude: 'pair_e3gnn/comm_brick.h'
exclude: 'pair_e3gnn/comm_brick.cpp'

- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
args: ['--skip-string-normalization', '--line-length=79', '--preview']
exclude: 'pair_e3gnn/pair_e3gnn_parallel.h'
exclude: 'pair_e3gnn/comm_brick.h'
exclude: 'pair_e3gnn/comm_brick.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn.h'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn_parallel.h'
exclude: 'sevenn/pair_e3gnn/comm_brick.h'
exclude: 'sevenn/pair_e3gnn/comm_brick.cpp'

- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
exclude: 'sevenn/pair_e3gnn/pair_e3gnn.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn.h'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn_parallel.h'
exclude: 'sevenn/pair_e3gnn/comm_brick.h'
exclude: 'sevenn/pair_e3gnn/comm_brick.cpp'
require_serial: true
exclude: 'pair_e3gnn/pair_e3gnn_parallel.h'
exclude: 'pair_e3gnn/comm_brick.h'
exclude: 'pair_e3gnn/comm_brick.cpp'

- repo: https://github.com/pre-commit/mirrors-clang-format
rev: 'v17.0.6'
hooks:
- id: clang-format
exclude: 'pair_e3gnn/pair_e3gnn_parallel.h'
exclude: 'pair_e3gnn/comm_brick.h'
exclude: 'pair_e3gnn/comm_brick.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn.h'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp'
exclude: 'sevenn/pair_e3gnn/pair_e3gnn_parallel.h'
exclude: 'sevenn/pair_e3gnn/comm_brick.h'
exclude: 'sevenn/pair_e3gnn/comm_brick.cpp'
235 changes: 103 additions & 132 deletions README.md

Large diffs are not rendered by default.

20 changes: 0 additions & 20 deletions logo_ascii

This file was deleted.

58 changes: 58 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
[project]
name = "sevenn"
version = "0.9.3"
authors = [
{ name="Yutack Park", email="parkyutack@snu.ac.kr" },
{ name="Jaesun Kim" },
]
description = "Scalable EquiVariance Enabled Neural Network"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Operating System :: POSIX :: Linux",
]
dependencies = [
"ase",
"braceexpand",
"pyyaml",
"e3nn",
"tqdm",
"scikit-learn",
"torch_geometric",
"numpy<2.0",
]


[project.scripts]
sevenn = "sevenn.main.sevenn:main"
sevenn_get_model = "sevenn.main.sevenn_get_model:main"
sevenn_graph_build = "sevenn.main.sevenn_graph_build:main"
sevenn_inference = "sevenn.main.sevenn_inference:main"
sevenn_patch_lammps = "sevenn.main.sevenn_patch_lammps:main"
sevenn_preset = "sevenn.main.sevenn_preset:main"

[project.urls]
Homepage = "https://github.com/MDIL-SNU/SevenNet"
Issues = "https://github.com/MDIL-SNU/SevenNet/issues"

[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=61.0"]

[tool.setuptools.package-data]
sevenn = [
"logo_ascii",
"pair_e3gnn/*.cpp",
"pair_e3gnn/*.h",
"pair_e3gnn/patch_lammps.sh",
"presets/*.yaml",
"pretrained_potentials/SevenNet_0__11July2024/checkpoint_sevennet_0.pth",
"pretrained_potentials/SevenNet_0__22May2024/checkpoint_sevennet_0.pth"
]

[tool.setuptools.packages.find]
include = ["sevenn*"]
exclude = ["tests*", "example_inputs*", ]
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=80
known_third_party=ase,braceexpand,e3nn,numpy,setuptools,sklearn,torch,torch_geometric,torch_scatter,tqdm,yaml
known_third_party=ase,braceexpand,e3nn,numpy,setuptools,sklearn,torch,torch_geometric,tqdm,yaml
known_first_party=
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from setuptools import find_packages, setup
# from setuptools import find_packages, setup
import setuptools

setuptools.setup()

"""
setup(
name='sevenn',
version='0.9.2',
Expand All @@ -25,3 +29,4 @@
]
},
)
"""
38 changes: 26 additions & 12 deletions sevenn/_const.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from enum import Enum
from typing import Dict

Expand All @@ -6,7 +7,7 @@
import sevenn._keys as KEY
from sevenn.nn.activation import ShiftedSoftPlus

SEVENN_VERSION = '0.9.2'
SEVENN_VERSION = '0.9.3'
IMPLEMENTED_RADIAL_BASIS = ['bessel']
IMPLEMENTED_CUTOFF_FUNCTION = ['poly_cut', 'XPLOR']
# TODO: support None. This became difficult because of paralell model
Expand Down Expand Up @@ -44,6 +45,14 @@
ACTIVATION_FOR_ODD = {'tanh': torch.tanh, 'abs': torch.abs}
ACTIVATION_DICT = {'e': ACTIVATION_FOR_EVEN, 'o': ACTIVATION_FOR_ODD}

_prefix = os.path.abspath(f'{os.path.dirname(__file__)}/pretrained_potentials')
SEVENNET_0_11July2024 = (
f'{_prefix}/SevenNet_0__11July2024/checkpoint_sevennet_0.pth'
)
SEVENNET_0_22May2024 = (
f'{_prefix}/SevenNet_0__22May2024/checkpoint_sevennet_0.pth'
)


# to avoid torch script to compile torch_geometry.data
AtomGraphDataType = Dict[str, torch.Tensor]
Expand Down Expand Up @@ -90,12 +99,12 @@ def error_record_condition(x):
KEY.NUM_CONVOLUTION: 3,
KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'},
KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'},
#KEY.AVG_NUM_NEIGH: True, # deprecated
#KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated
# KEY.AVG_NUM_NEIGH: True, # deprecated
# KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated
KEY.CONV_DENOMINATOR: 'avg_num_neigh',
KEY.TRAIN_DENOMINTAOR: False,
KEY.TRAIN_SHIFT_SCALE: False,
#KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True
KEY.TRAIN_SHIFT_SCALE: False,
# KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True
KEY.USE_BIAS_IN_LINEAR: False,
KEY.READOUT_AS_FCN: False,
# Applied af readout as fcn is True
Expand All @@ -121,7 +130,10 @@ def error_record_condition(x):
},
KEY.CUTOFF: float,
KEY.NUM_CONVOLUTION: int,
KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) or x in ["avg_num_neigh", "sqrt_avg_num_neigh"],
KEY.CONV_DENOMINATOR: lambda x: isinstance(x, float) or x in [
'avg_num_neigh',
'sqrt_avg_num_neigh',
],
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: list,
KEY.TRAIN_SHIFT_SCALE: bool,
KEY.TRAIN_DENOMINTAOR: bool,
Expand All @@ -147,7 +159,6 @@ def model_defaults(config):
return defaults



DEFAULT_DATA_CONFIG = {
KEY.DTYPE: 'single',
KEY.DATA_FORMAT: 'ase',
Expand All @@ -158,9 +169,9 @@ def model_defaults(config):
KEY.RATIO: 0.1,
KEY.BATCH_SIZE: 6,
KEY.PREPROCESS_NUM_CORES: 1,
#KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
KEY.SHIFT: "per_atom_energy_mean",
KEY.SCALE: "force_rms",
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
KEY.SHIFT: 'per_atom_energy_mean',
KEY.SCALE: 'force_rms',
KEY.DATA_SHUFFLE: True,
}

Expand All @@ -174,7 +185,7 @@ def model_defaults(config):
KEY.RATIO: float,
KEY.BATCH_SIZE: int,
KEY.PREPROCESS_NUM_CORES: int,
#KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT,
KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
KEY.DATA_SHUFFLE: bool,
Expand All @@ -193,8 +204,11 @@ def data_defaults(config):
KEY.RANDOM_SEED: 1,
KEY.EPOCH: 300,
KEY.LOSS: 'mse',
KEY.LOSS_PARAM: {},
KEY.OPTIMIZER: 'adam',
KEY.OPTIM_PARAM: {},
KEY.SCHEDULER: 'exponentiallr',
KEY.SCHEDULER_PARAM: {},
KEY.FORCE_WEIGHT: 0.1,
KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default
KEY.PER_EPOCH: 5,
Expand Down Expand Up @@ -228,7 +242,7 @@ def data_defaults(config):
KEY.STRESS_WEIGHT: float,
KEY.USE_TESTSET: None, # Not used
KEY.NUM_WORKERS: None, # Not used
KEY.PER_EPOCH: int,
KEY.PER_EPOCH: int,
KEY.CONTINUE: {
KEY.CHECKPOINT: str,
KEY.RESET_OPTIMIZER: bool,
Expand Down
17 changes: 12 additions & 5 deletions sevenn/main/sevenn_get_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse

import torch
import os

import sevenn._const as _const
import sevenn.util
Expand All @@ -10,22 +9,30 @@
f'sevenn version={_const.SEVENN_VERSION}, sevenn_get_model.'
+ ' Deploy model for LAMMPS from the checkpoint'
)
checkpoint_help = 'checkpoint path'
checkpoint_help = (
'path to the checkpoint | SevenNet-0 | 7net-0 |'
' {SevenNet-0|7net-0}_{11July2024|22May2024}'
)
output_name_help = 'filename prefix'
get_parallel_help = 'deploy parallel model'


def main(args=None):
checkpoint, output_prefix, get_parallel = cmd_parse_get_model(args)
get_serial = not get_parallel
cp_file = torch.load(checkpoint, map_location=torch.device('cpu'))

if output_prefix is None:
output_prefix = (
'deployed_parallel' if not get_serial else 'deployed_serial'
)

model, config = sevenn.util.model_from_checkpoint(checkpoint)
checkpoint_path = None
if os.path.isfile(checkpoint):
checkpoint_path = checkpoint
else:
checkpoint_path = sevenn.util.pretrained_name_to_path(checkpoint)

model, config = sevenn.util.model_from_checkpoint(checkpoint_path)
stct_dct = model.state_dict()

if get_serial:
Expand Down
40 changes: 40 additions & 0 deletions sevenn/main/sevenn_patch_lammps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse
import os
import subprocess

from torch import __version__

from sevenn._const import SEVENN_VERSION

# python wrapper of patch_lammps.sh script
# importlib.resources is correct way to do these things
# but it changes so frequently to use
pair_e3gnn_dir = os.path.abspath(f'{os.path.dirname(__file__)}/../pair_e3gnn')

description = (
f'sevenn version={SEVENN_VERSION}, patch LAMMPS for pair_e3gnn styles'
)


def main(args=None):
lammps_dir = cmd_parse_main(args)
cxx_standard = '17' if __version__.startswith('2') else '14'
if cxx_standard == '17':
print('Torch version >= 2.0 detacted, use CXX STANDARD 17')
else:
print('Torch version < 2.0 detacted, use CXX STANDARD 14')
script = f'{pair_e3gnn_dir}/patch_lammps.sh'
cmd = f'{script} {lammps_dir} {cxx_standard}'
res = subprocess.run(cmd.split())
return res.returncode # is it meaningless?


def cmd_parse_main(args=None):
ag = argparse.ArgumentParser(description=description)
ag.add_argument('lammps_dir', help='Path to LAMMPS source', type=str)
args = ag.parse_args()
return args.lammps_dir


if __name__ == '__main__':
main()
30 changes: 30 additions & 0 deletions sevenn/main/sevenn_preset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import argparse

import sevenn._const as _const

description_preset = (
f'sevenn version={_const.SEVENN_VERSION}, sevenn_preset.'
+ ' copy paste preset training yaml file to current directory'
+ ' ex) sevennet_preset fine_tune > my_input.yaml'
)

preset_help = "Name of preset"


def main(args=None):
preset = cmd_parse_preset(args)
prefix = os.path.abspath(f'{os.path.dirname(__file__)}/../presets')

with open(f"{prefix}/{preset}.yaml", "r") as f:
print(f.read())


def cmd_parse_preset(args=None):
ag = argparse.ArgumentParser(description=description_preset)
ag.add_argument(
'preset', choices=['fine_tune', 'sevennet-0', 'base'],
help = preset_help
)
args = ag.parse_args()
return args.preset
Loading

0 comments on commit 6d00fc7

Please sign in to comment.