Updating floating point precision for OC22/total_energy predictions and enabling OC22 challenge submission file generation #421

Merged
merged 19 commits into from
Jan 24, 2023
Changes from 5 commits
Commits
c6ecccb
write tot_e predicts in float32
wood-b Sep 30, 2022
dc520a5
adding total_energy=True to base OC22 configs
wood-b Oct 6, 2022
4a4cdd2
assert oc22 predictions are fp32
wood-b Oct 6, 2022
6fbfc90
update to method for writing predictions to keep track of precision
wood-b Oct 6, 2022
414693d
submission file to support oc22
mshuaibii Oct 6, 2022
64341fc
Merge branch 'predict_fp' of https://github.com/Open-Catalyst-Project…
wood-b Oct 7, 2022
f8999c6
move energy values to cpu before writing predicts and updated make_su…
wood-b Oct 7, 2022
266c903
Merge branch 'main' of https://github.com/Open-Catalyst-Project/ocp i…
wood-b Oct 7, 2022
437ee2c
minor fix
mshuaibii Oct 7, 2022
07b68a3
minor fix
wood-b Oct 7, 2022
4a4521e
Merge branch 'predict_fp' of https://github.com/Open-Catalyst-Project…
wood-b Oct 16, 2022
2d21fd9
update to include prediction_dtype flag and remove check in make_subm…
wood-b Oct 16, 2022
a4826c2
Merge branch 'main' into predict_fp
wood-b Dec 12, 2022
649c4df
added documentation for the prediction_type flag and oc22 evalai
wood-b Dec 13, 2022
db14641
Merge branch 'main' of https://github.com/Open-Catalyst-Project/ocp i…
wood-b Jan 19, 2023
df67f10
Merge branch 'main' of https://github.com/Open-Catalyst-Project/ocp i…
wood-b Jan 21, 2023
0460264
updated oc22 docs in TRAIN.md and minor changes to make_submission_fi…
wood-b Jan 21, 2023
7175bbd
add joint training documentation
mshuaibii Jan 24, 2023
dcfa6a0
Merge branch 'main' into predict_fp
abhshkdz Jan 24, 2023
2 changes: 2 additions & 0 deletions configs/oc22/is2re/base.yml
@@ -4,8 +4,10 @@ dataset:
   train:
     src: data/oc22/is2re/train
     normalize_labels: False
+    total_energy: True
   val:
     src: data/oc22/is2re/val_id
+    total_energy: True

 logger: wandb

2 changes: 2 additions & 0 deletions configs/oc22/s2ef/base.yml
@@ -4,8 +4,10 @@ dataset:
   train:
     src: data/oc22/s2ef/train
     normalize_labels: False
+    total_energy: True
   val:
     src: data/oc22/s2ef/val_id
+    total_energy: True

 logger: wandb

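The trainer change in this PR reads the new flag with `self.config["dataset"].get("total_energy", False)`, so any config that omits it keeps the previous float16 behaviour. A minimal sketch of that lookup, assuming the config is available as a plain nested dict; the helper name `prediction_dtype_name` and the second example path are hypothetical and not part of the PR:

```python
# Hypothetical helper (not part of the PR): mirrors the flag lookup
# the trainer diff performs on the dataset section of the config.
def prediction_dtype_name(config: dict) -> str:
    """Return the name of the dtype predictions would be written in."""
    # A missing flag defaults to False, i.e. the float16 path.
    if config["dataset"].get("total_energy", False):
        return "float32"
    return "float16"


# Assumed split-level view of the config as the trainer sees it.
with_flag = {"dataset": {"src": "data/oc22/s2ef/val_id", "total_energy": True}}
without_flag = {"dataset": {"src": "data/example/val"}}  # hypothetical path

print(prediction_dtype_name(with_flag))
print(prediction_dtype_name(without_flag))
```

Because the default is `False`, the flag has to be set explicitly in every OC22 config that should produce float32 predictions, which is what the two base configs above now do.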
34 changes: 27 additions & 7 deletions ocpmodels/trainers/forces_trainer.py
@@ -22,8 +22,8 @@
 from ocpmodels.common.utils import check_traj_files
 from ocpmodels.modules.evaluator import Evaluator
 from ocpmodels.modules.normalizer import Normalizer
-from ocpmodels.trainers.base_trainer import BaseTrainer
 from ocpmodels.modules.scaling.util import ensure_fitted
+from ocpmodels.trainers.base_trainer import BaseTrainer


 @registry.register_trainer("forces")
@@ -208,14 +208,22 @@ def predict(
                     )
                 ]
                 predictions["id"].extend(systemids)
-                predictions["energy"].extend(
-                    out["energy"].to(torch.float16).tolist()
-                )
                 batch_natoms = torch.cat(
                     [batch.natoms for batch in batch_list]
                 )
                 batch_fixed = torch.cat([batch.fixed for batch in batch_list])
-                forces = out["forces"].cpu().detach().to(torch.float16)
+                # total energy requires predictions are saved in float32
+                # default is ads energy not total energy
+                if self.config["dataset"].get("total_energy", False):
+                    predictions["energy"].extend(
+                        out["energy"].to(torch.float32).tolist()
+                    )
+                    forces = out["forces"].cpu().detach().to(torch.float32)
+                else:
+                    predictions["energy"].extend(
+                        out["energy"].to(torch.float16).tolist()
+                    )
+                    forces = out["forces"].cpu().detach().to(torch.float16)
                 per_image_forces = torch.split(forces, batch_natoms.tolist())
                 per_image_forces = [
                     force.numpy() for force in per_image_forces
@@ -247,9 +255,21 @@ def predict(
                     self.ema.restore()
                 return predictions

-        predictions["forces"] = np.array(predictions["forces"])
+        if self.config["dataset"].get("total_energy", False):
+            predictions["forces"] = np.array(
+                predictions["forces"], dtype="float32"
+            )
+            predictions["energy"] = np.array(
+                predictions["energy"], dtype="float32"
+            )
+        else:
+            predictions["forces"] = np.array(
+                predictions["forces"], dtype="float16"
+            )
+            predictions["energy"] = np.array(
+                predictions["energy"], dtype="float16"
+            )
         predictions["chunk_idx"] = np.array(predictions["chunk_idx"])
-        predictions["energy"] = np.array(predictions["energy"])
         predictions["id"] = np.array(predictions["id"])
         self.save_results(
             predictions, results_file, keys=["energy", "forces", "chunk_idx"]
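The trainer diff above switches total-energy predictions from float16 to float32. The sketch below is one way to see why that matters: it round-trips a single value through IEEE 754 half and single precision using only the standard-library `struct` module, standing in for the `torch`/`numpy` casts in the PR. The energy value is a hypothetical number chosen for illustration, on the assumption that total energies can be a few hundred eV in magnitude.

```python
import struct


def round_trip(value: float, fmt: str) -> float:
    """Pack a Python float into the given IEEE 754 format and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


total_energy = -300.13  # hypothetical total energy in eV, not from the dataset

as_fp16 = round_trip(total_energy, "<e")  # half precision (float16)
as_fp32 = round_trip(total_energy, "<f")  # single precision (float32)

err_fp16 = abs(as_fp16 - total_energy)
err_fp32 = abs(as_fp32 - total_energy)

print(f"float16: {as_fp16!r} (abs. error {err_fp16:.3g} eV)")
print(f"float32: {as_fp32!r} (abs. error {err_fp32:.3g} eV)")
```

Representable float16 values between 256 and 512 in magnitude are spaced 0.25 apart, so a value of this size is rounded to the nearest quarter of an eV, while the float32 error stays far below that.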
105 changes: 71 additions & 34 deletions scripts/make_submission_file.py
@@ -11,30 +11,36 @@

 import numpy as np

+SPLITS = {
+    "OC20": ["id", "ood_ads", "ood_cat", "ood_both"],
+    "OC22": ["id", "ood"],
+}
+
-def write_is2re_relaxations(paths, filename, hybrid):
+
+def write_is2re_relaxations(args, dataset):
     import ase.io
     from tqdm import tqdm

     submission_file = {}

-    if not hybrid:
-        for idx, split in enumerate(["id", "ood_ads", "ood_cat", "ood_both"]):
+    if not args.hybrid:
+        for split in SPLITS[dataset]:
             ids = []
             energies = []
-            systems = glob.glob(os.path.join(paths[idx], "*.traj"))
+            systems = glob.glob(os.path.join(vars(args)[split], "*.traj"))
             for system in tqdm(systems):
                 sid, _ = os.path.splitext(os.path.basename(system))
                 ids.append(str(sid))
                 # Read the last frame in the ML trajectory. Modify "-1" if you wish to modify which frame to use.
                 traj = ase.io.read(system, "-1")
                 energies.append(traj.get_potential_energy())

             submission_file[f"{split}_ids"] = np.array(ids)
             submission_file[f"{split}_energy"] = np.array(energies)

     else:
-        for idx, split in enumerate(["id", "ood_ads", "ood_cat", "ood_both"]):
-            preds = np.load(paths[idx])
+        for split in SPLITS[dataset]:
+            preds = np.load(vars(args)[split])
             ids = []
             energies = []
             for sid, energy in zip(preds["ids"], preds["energy"]):
@@ -45,54 +51,71 @@ def write_is2re_relaxations(paths, filename, hybrid):
             submission_file[f"{split}_ids"] = np.array(ids)
             submission_file[f"{split}_energy"] = np.array(energies)

-    np.savez_compressed(filename, **submission_file)
+    np.savez_compressed(args.out_path, **submission_file)


-def write_predictions(paths, filename):
-    submission_file = {}
+def write_predictions(args, dataset):
+    if args.is2re_relaxations:
+        write_is2re_relaxations(args, dataset=dataset)
+    else:
+        submission_file = {}
+
+        for split in SPLITS[dataset]:
+            res = np.load(vars(args)[split], allow_pickle=True)
+            verify_dtype(res, dataset)
+            contents = res.files
+            for i in contents:
+                key = "_".join([split, i])
+                submission_file[key] = res[i]
+
+        np.savez_compressed(args.out_path, **submission_file)

-    for idx, split in enumerate(["id", "ood_ads", "ood_cat", "ood_both"]):
-        res = np.load(paths[idx], allow_pickle=True)
-        contents = res.files
-        for i in contents:
-            key = "_".join([split, i])
-            submission_file[key] = res[i]

-    np.savez_compressed(filename, **submission_file)
+def verify_dtype(preds, dataset):
+    if dataset == "OC22":
+        if "energy" in preds:
+            assert preds["energy"].dtype in [
+                np.float32,
+                np.float64,
+            ], "Predictions written in the wrong precision. Ensure `total_energy` flag is True in the config."
+        if "forces" in preds:
+            assert preds["forces"].dtype in [
+                np.float32,
+                np.float64,
+            ], "Predictions written in the wrong precision. Ensure `total_energy` flag is True in the config."


 def main(args):
-    id_path = args.id
-    ood_ads_path = args.ood_ads
-    ood_cat_path = args.ood_cat
-    ood_both_path = args.ood_both
+    if args.oc22:
+        for split in SPLITS["OC22"]:
+            assert vars(args).get(split), f"Missing {split} split for OC22"
+        dataset = "OC22"
+    else:
+        for split in SPLITS["OC20"]:
+            assert vars(args).get(split), f"Missing {split} split for OC20"
+        dataset = "OC20"

-    paths = [id_path, ood_ads_path, ood_cat_path, ood_both_path]
     if not args.out_path.endswith(".npz"):
         args.out_path = args.out_path + ".npz"

-    if not args.is2re_relaxations:
-        write_predictions(paths, filename=args.out_path)
-    else:
-        write_is2re_relaxations(
-            paths, filename=args.out_path, hybrid=args.hybrid
-        )
+    write_predictions(args, dataset=dataset)
     print(f"Results saved to {args.out_path} successfully.")


 if __name__ == "__main__":
     """
     Create a submission file for evalAI. Ensure that for the task you are
-    submitting for you have generated results files on each of the 4 splits -
-    id, ood_ads, ood_cat, ood_both.
+    submitting for you have generated results files on each of the splits:
+        OC20: id, ood_ads, ood_cat, ood_both
+        OC22: id, ood

     Results file can be obtained as follows for the various tasks:

     S2EF: config["mode"] = "predict"
     IS2RE: config["mode"] = "predict"
     IS2RS: config["mode"] = "run-relaxations" and config["task"]["write_pos"] = True

-    Use this script to join the 4 results files in the format evalAI expects
+    Use this script to join the results files (4 for OC20, 2 for OC22) in the format evalAI expects
     submissions.

     If writing IS2RE predictions from relaxations, paths must be directories
@@ -106,10 +129,21 @@ def main(args):
     """

     parser = argparse.ArgumentParser()
-    parser.add_argument("--id", help="Path to ID results")
-    parser.add_argument("--ood-ads", help="Path to OOD-Ads results")
-    parser.add_argument("--ood-cat", help="Path to OOD-Cat results")
-    parser.add_argument("--ood-both", help="Path to OOD-Both results")
+    parser.add_argument(
+        "--id", help="Path to ID results. Required for OC20 and OC22."
+    )
+    parser.add_argument(
+        "--ood-ads", help="Path to OOD-Ads results. Required only for OC20."
+    )
+    parser.add_argument(
+        "--ood-cat", help="Path to OOD-Cat results. Required only for OC20."
+    )
+    parser.add_argument(
+        "--ood-both", help="Path to OOD-Both results. Required only for OC20."
+    )
+    parser.add_argument(
+        "--ood", help="Path to OOD OC22 results. Required only for OC22."
+    )
     parser.add_argument("--out-path", help="Path to write predictions to.")
     parser.add_argument(
         "--is2re-relaxations",
@@ -121,6 +155,9 @@ def main(args):
         action="store_true",
         help="Write IS2RE results from S2EF prediction files. Paths specified correspond to S2EF NPZ files.",
     )
+    parser.add_argument(
+        "--oc22", action="store_true", help="Write OC22 prediction files."
+    )

     args = parser.parse_args()
     main(args)
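The script changes above come down to two steps: check that every split required by the chosen dataset has a path, then merge the per-split results under keys prefixed with the split name. A minimal, standard-library-only sketch of both steps follows. It reuses `SPLITS` and the argument names from the diff, but `missing_splits` and `merge_results` are hypothetical helpers, plain dicts and lists stand in for the loaded `.npz` contents, and the file names and values are placeholders.

```python
import argparse

SPLITS = {
    "OC20": ["id", "ood_ads", "ood_cat", "ood_both"],
    "OC22": ["id", "ood"],
}


def build_parser() -> argparse.ArgumentParser:
    """Subset of the script's arguments; argparse maps --ood-ads to ood_ads."""
    parser = argparse.ArgumentParser()
    for flag in ["--id", "--ood-ads", "--ood-cat", "--ood-both", "--ood", "--out-path"]:
        parser.add_argument(flag)
    parser.add_argument("--oc22", action="store_true")
    return parser


def missing_splits(args: argparse.Namespace) -> list:
    """Hypothetical helper: required splits for which no path was given."""
    dataset = "OC22" if args.oc22 else "OC20"
    return [split for split in SPLITS[dataset] if not vars(args).get(split)]


def merge_results(per_split: dict) -> dict:
    """Prefix each result key with its split name, as write_predictions does."""
    merged = {}
    for split, result in per_split.items():
        for key, value in result.items():
            merged["_".join([split, key])] = value
    return merged


parser = build_parser()
oc22_args = parser.parse_args(
    ["--oc22", "--id", "id.npz", "--ood", "ood.npz", "--out-path", "submission"]
)
oc20_args = parser.parse_args(["--id", "id.npz"])

print(missing_splits(oc22_args))  # nothing missing for OC22
print(missing_splits(oc20_args))  # the three OOD splits are missing for OC20

# Placeholder contents standing in for two loaded results files.
merged = merge_results(
    {
        "id": {"ids": ["1_0"], "energy": [-300.25]},
        "ood": {"ids": ["2_0"], "energy": [-150.5]},
    }
)
print(sorted(merged))
```

With the flags from the diff, an OC22 submission would presumably be generated with something like `python scripts/make_submission_file.py --id <id results> --ood <ood results> --out-path <output> --oc22`, where the bracketed paths are placeholders.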