Colabs for how to interact with the model #28

Open
aaronatp opened this issue Aug 9, 2024 · 3 comments

Comments

@aaronatp

aaronatp commented Aug 9, 2024

Hi @amilmerchant, I have been following this GitHub repo for a few months and have been looking forward to using the Colabs. When will they be uploaded? I would be happy to create the Colabs myself if you could provide some direction :)

@WeileiZeng

As a beginner who is interested in the work, I would appreciate some code that could get the model running.

@aaronatp
Author

Hi @amilmerchant, any updates by chance?

@WeileiZeng

WeileiZeng commented Aug 27, 2024

In case this helps, the following code can construct the model.

# ./model/run.py (same directory as ./model/nequip.py)
from nequip import model_from_config, default_config

cfg = default_config()
cfg.scale = 1.0
cfg.shift = 0.0
model = model_from_config(cfg)
print(model)  # successfully constructed the model

I am not sure where to get the training data. The model expects its input in the following format.

# Model: NequipEnergyModel in ./model/nequip.py
# Model input: a batched jraph.GraphsTuple
import jraph

graph = jraph.GraphsTuple(
    nodes=nodes,
    edges=edges,
    receivers=receivers,
    senders=senders,
    globals=globals_,
    n_node=n_node,
    n_edge=n_edge,
)
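To make the index conventions of `jraph.GraphsTuple` concrete, here is a minimal NumPy sketch that builds the `nodes`, `edges`, `senders`, `receivers`, `n_node`, and `n_edge` arrays for a small batch of structures. It assumes edges connect every ordered pair of atoms closer than a cutoff, ignores periodic boundary conditions (real crystals would need periodic images), uses displacement vectors as edge features, and passes node features through unchanged. The helper names `radius_graph` and `batch_graph_arrays` are hypothetical, and this thread does not say which node and edge features `NequipEnergyModel` actually expects, so check `./model/nequip.py` before relying on them.

```python
import numpy as np


def radius_graph(positions, cutoff):
    """Directed edges between all distinct atom pairs closer than `cutoff`.

    Non-periodic sketch (assumption): no periodic images are considered.
    Returns per-structure sender/receiver indices and, for each edge,
    the displacement vector from sender to receiver.
    """
    positions = np.asarray(positions, dtype=np.float64)
    # disp[i, j] = positions[j] - positions[i], shape (n_atoms, n_atoms, 3).
    disp = positions[None, :, :] - positions[:, None, :]
    dist = np.linalg.norm(disp, axis=-1)
    # Keep pairs inside the cutoff and drop self-edges.
    mask = (dist < cutoff) & ~np.eye(len(positions), dtype=bool)
    senders, receivers = np.nonzero(mask)
    return senders, receivers, disp[senders, receivers]


def batch_graph_arrays(structures, cutoff):
    """Concatenate several (features, positions) structures into one batch.

    In a batched GraphsTuple, senders/receivers index into the concatenated
    node array, so each structure's indices are shifted by the number of
    nodes that come before it.
    """
    nodes, edges, senders, receivers, n_node, n_edge = [], [], [], [], [], []
    offset = 0
    for features, positions in structures:
        s, r, d = radius_graph(positions, cutoff)
        nodes.append(np.asarray(features))
        edges.append(d)
        senders.append(s + offset)
        receivers.append(r + offset)
        n_node.append(len(positions))
        n_edge.append(len(s))
        offset += len(positions)
    return dict(
        nodes=np.concatenate(nodes),
        edges=np.concatenate(edges),
        senders=np.concatenate(senders),
        receivers=np.concatenate(receivers),
        n_node=np.array(n_node),
        n_edge=np.array(n_edge),
    )


# Two toy structures; node features are hypothetical one-hot species vectors.
water = (np.eye(2)[[0, 1, 1]], [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
dimer = (np.eye(2)[[0, 0]], [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]])
arrays = batch_graph_arrays([water, dimer], cutoff=1.3)
print(arrays["n_node"], arrays["n_edge"])  # [3 2] [4 2]
```

The H-H distance in the toy water (about 1.52) exceeds the cutoff, so that structure gets only the four O-H directed edges. With jraph installed, the resulting dictionary could be passed as `jraph.GraphsTuple(globals=None, **arrays)`.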

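# Model output: per-atom outputs summed into one value per graph
import functools

import e3nn_jax as e3nn
import jax
import jraph

# Treat each e3nn.IrrepsArray as a single leaf instead of flattening it.
tree_map = functools.partial(
    jax.tree_util.tree_map, is_leaf=lambda x: isinstance(x, e3nn.IrrepsArray)
)
global_output = tree_map(
    lambda n: jraph.segment_sum(n, node_gr_idx, n_graph), atomic_output
)
# global_output is the output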

# Equivalently, in a single call (positional arguments must come before is_leaf):
global_output = jax.tree_util.tree_map(
    lambda n: jraph.segment_sum(n, node_gr_idx, n_graph),
    atomic_output,
    is_leaf=lambda x: isinstance(x, e3nn.IrrepsArray),
)
# where atomic_output holds the network's per-atom outputs, node_gr_idx maps
# each atom to the index of its graph, and n_graph is the number of graphs
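The readout above turns per-atom outputs into one value per structure: `node_gr_idx` assigns each atom to its graph, and `jraph.segment_sum` adds up the rows that share an index. Here is a minimal NumPy sketch of the same reduction, assuming `atomic_output` is a plain `(n_atoms, ...)` array rather than an `e3nn.IrrepsArray`. `segment_sum_np` is a hypothetical stand-in for `jraph.segment_sum`, and building `node_gr_idx` by repeating each graph index `n_node` times is an assumption about how that index is produced (it is the usual jraph pattern).

```python
import numpy as np


def segment_sum_np(data, segment_ids, num_segments):
    """Sum the rows of `data` that share a segment id (NumPy stand-in)."""
    data = np.asarray(data, dtype=np.float64)
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    # np.add.at is unbuffered, so repeated indices accumulate instead of overwriting.
    np.add.at(out, segment_ids, data)
    return out


# A batch of two structures with 3 and 2 atoms.
n_node = np.array([3, 2])
n_graph = len(n_node)
# node_gr_idx[i] is the index of the structure that atom i belongs to.
node_gr_idx = np.repeat(np.arange(n_graph), n_node)  # [0 0 0 1 1]

# Hypothetical per-atom energies from the network, shape (5, 1).
atomic_output = np.array([[-1.0], [-0.5], [-0.5], [-2.0], [-2.0]])
global_output = segment_sum_np(atomic_output, node_gr_idx, n_graph)
print(global_output.ravel())  # [-2. -4.]
```

On JAX arrays, `jraph.segment_sum` wraps `jax.ops.segment_sum`, which has the same semantics and can be used inside `jax.jit`.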
