Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CHGNet scorer implementation #8

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ algorithm.
bond_length_acceptability_cutoff: float = 1.0
reward_k: float = 2.0 # the reward constant
mcts_out_dir: str = "mcts" # path to the directory where generated CIF files will be stored
scorer: str = "zmq" # supported values: 'zmq', 'random'
scorer: str = "zmq" # supported values: 'zmq', 'random', 'CHGNet'
scorer_host: str = "localhost" # required if `scorer` is 'zmq'
scorer_port: int = 5555 # required if `scorer` is 'zmq'
use_context_sensitive_tree_builder: bool = True
Expand All @@ -522,6 +522,8 @@ algorithm.
n_space_groups: int = 0
bypass_only_child: bool = False
n_rollouts: int = 1 # the number of rollouts to perform per simulation
scorer_device: str = "cpu" # required if `scorer` is `CHGNet`; valid values are `cpu` and `cuda`
chgnet_model_name: str = "0.3.0" # required if `scorer` is `CHGNet`
```

</details>
Expand Down Expand Up @@ -557,6 +559,10 @@ obtaining the score from another process, via the ZMQ library, and in such a cas
value of `zmq`. See [this script](resources/alignn_zmq_example.py) for an example of setting up
[ALIGNN](https://github.com/usnistgov/alignn) to listen for and respond to prediction requests using ZMQ.

If you want to use a scorer without ZMQ, `CHGNetScorer` is provided. To avoid running out of memory on a single GPU, the CHGNet scorer
is built to run inference on a separate device (`cpu` or `cuda:1`). To use this feature, you need to install `CHGNet`,
for example with `pip install chgnet`. For more details, please refer to the [CHGNet official website](https://chgnet.lbl.gov/).

### Using a Pre-trained Model

To use a pre-trained model, first download it:
Expand Down
8 changes: 7 additions & 1 deletion bin/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
RandomScorer,
UCTSelector,
ZMQScorer,
CHGNetScorer,
)


Expand All @@ -37,7 +38,7 @@ class MCTSDefaults:
bond_length_acceptability_cutoff: float = 1.0
reward_k: float = 2.0 # the reward constant
mcts_out_dir: str = "mcts" # path to the directory where generated CIF files will be stored
scorer: str = "zmq" # supported values: 'zmq', 'random'
scorer: str = "zmq" # supported values: 'zmq', 'random', 'CHGNet'
scorer_host: str = "localhost" # required if `scorer` is 'zmq'
scorer_port: int = 5555 # required if `scorer` is 'zmq'
use_context_sensitive_tree_builder: bool = True
Expand All @@ -46,6 +47,8 @@ class MCTSDefaults:
n_space_groups: int = 0
bypass_only_child: bool = False
n_rollouts: int = 1 # the number of rollouts to perform per simulation
scorer_device: str = "cpu" # required if `scorer` is 'CHGNet'; valid values are `cpu` and `cuda`
chgnet_model_name: str = "0.3.0" # required if `scorer` is 'CHGNet'


if __name__ == "__main__":
Expand Down Expand Up @@ -93,6 +96,9 @@ class MCTSDefaults:
cif_scorer = ZMQScorer(host=C.scorer_host, port=C.scorer_port)
elif C.scorer == "random":
cif_scorer = RandomScorer()
elif C.scorer == "CHGNet":
scorer_device = "cuda:1" if "cuda" in C.scorer_device else "cpu"
cif_scorer = CHGNetScorer(host_device=C.device, scorer_device=scorer_device, model_name=C.chgnet_model_name)
else:
raise Exception(f"unsupported scorer: {C.scorer}")

Expand Down
1 change: 1 addition & 0 deletions crystallm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
CIFScorer,
RandomScorer,
ZMQScorer,
CHGNetScorer,
)

from ._configuration import parse_config
Expand Down
1 change: 0 additions & 1 deletion crystallm/_mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def __call__(self, token_sequence, iter_num):
print(f"exception while scoring: {e}")
print(traceback.format_exc())
return -1.0

if math.isnan(score):
print(f"reward cannot be computed as score is nan")
return -1.0
Expand Down
6 changes: 3 additions & 3 deletions crystallm/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def bond_length_reasonableness_score(cif_str, tolerance=0.32, h_factor=2.5):

def is_space_group_consistent(cif_str):
structure = Structure.from_str(cif_str, fmt="cif")
parser = CifParser.from_string(cif_str)
parser = CifParser.from_str(cif_str)
cif_data = parser.as_dict()

# Extract the stated space group from the CIF file
Expand All @@ -88,7 +88,7 @@ def is_space_group_consistent(cif_str):


def is_formula_consistent(cif_str):
parser = CifParser.from_string(cif_str)
parser = CifParser.from_str(cif_str)
cif_data = parser.as_dict()

formula_data = Composition(extract_data_formula(cif_str))
Expand All @@ -100,7 +100,7 @@ def is_formula_consistent(cif_str):

def is_atom_site_multiplicity_consistent(cif_str):
# Parse the CIF string
parser = CifParser.from_string(cif_str)
parser = CifParser.from_str(cif_str)
cif_data = parser.as_dict()

# Extract the chemical formula sum from the CIF data
Expand Down
38 changes: 38 additions & 0 deletions crystallm/_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,41 @@ def score(self, cif: str) -> float:

except zmq.Again as e:
raise TimeoutError("ZeroMQ request timed out") from e

class CHGNetScorer(CIFScorer):
    """
    A CIF scorer which returns the energy predicted in-process by a
    pre-trained CHGNet model, instead of asking another process for it.
    """

    def __init__(self, host_device: str, scorer_device: str, model_name: str):
        """
        Load the pre-trained CHGNet model used for scoring.

        :param host_device: the CrystaLLM device name (stored and logged only)
        :param scorer_device: the name of the device the CHGNet model is loaded on
        :param model_name: the name of the pre-trained CHGNet model to load

        TODO: upload a script of energy evaluator and execute it at run-time
        """
        # Imported lazily so that `chgnet` is only required when this scorer is selected.
        from chgnet.model.model import CHGNet

        self._host_device = host_device
        self._scorer_device = scorer_device
        print(f"CrystaLLM using: {host_device}")
        print(f"Pytorch Scorer using: {scorer_device}")
        print(f"CHGNET model name: {model_name}")
        # Passed by keyword so the call works whether or not `model_name`
        # is a keyword-only parameter of `CHGNet.load`.
        self._chgnet = CHGNet.load(model_name=model_name, use_device=scorer_device)

    @staticmethod
    def _parse_structure(cif: str):
        """
        Parse a CIF string into a single structure.

        The primitive cell is tried first; if that fails for any reason, the
        structure is parsed again without the primitive-cell reduction.

        :param cif: the contents of a CIF file
        :returns: the first structure described by the CIF
        :raises Exception: whatever the parser raises if both attempts fail
        """
        # Imported lazily, matching how `chgnet` is imported in `__init__`.
        from pymatgen.io.cif import CifParser

        try:
            structures = CifParser.from_str(cif).parse_structures(primitive=True)
        except Exception:
            # A fresh parser is created for the retry, as a failed parse may
            # have left the first parser in a partially consumed state.
            structures = CifParser.from_str(cif).parse_structures(primitive=False)
        # `parse_structures` returns a list; score one explicit structure rather
        # than relying on the model to unwrap a single-element list.
        return structures[0]

    def score(self, cif: str) -> float:
        """
        Score a CIF with the loaded CHGNet model.

        :param cif: the contents of a CIF file
        :returns: the value under the `e` key of the CHGNet prediction, or
            `nan` if the CIF could not be parsed or the prediction failed
        """
        try:
            structure = self._parse_structure(cif)
            prediction = self._chgnet.predict_structure(structure)
            # Converted inside the `try` so that an unexpected prediction
            # shape also falls back to `nan` instead of escaping to the caller.
            score = float(prediction["e"])
        except Exception as ex:
            print(f"exception making prediction: {ex}")
            score = float("nan")
        print(f"sending reply: {score}")
        return score
2 changes: 1 addition & 1 deletion crystallm/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def replace_symmetry_operators(cif_str, space_group_symbol):
v = op.translation_vector
symmops.append(SymmOp.from_rotation_and_translation(op.rotation_matrix, v))

ops = [op.as_xyz_string() for op in symmops]
ops = [op.as_xyz_str() for op in symmops]
data["_symmetry_equiv_pos_site_id"] = [f"{i}" for i in range(1, len(ops) + 1)]
data["_symmetry_equiv_pos_as_xyz"] = ops

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pandas==1.5.3
pymatgen==2023.3.23
pymatgen>=2023.3.23
numpy==1.24.2
scikit-learn==1.2.2
tqdm==4.65.0
Expand Down