Molecule generation model (GeoDiff) #54
Conversation
tests/test_modeling_utils.py (Outdated)

@@ -15,6 +15,7 @@

```python
import inspect
import math
import pdb
```
we should be careful to always remove those before merging :-) Totally fine to keep them for testing though!
Hey @natolambert, this PR looks super cool. I think we can merge it quite quickly! Leaving some feedback directly in the code.

Regarding the notebook: in general, the notebook looks very nice IMO. If we have to install certain dependencies, so be it!

Regarding the model API, I would suggest changing it a bit. E.g.:

```python
# generate geometry with model, then filter it
model_outputs = model.forward(batch, t)
# this model uses additional conditioning of the outputs depending on the current timestep
epsilon = model.get_residual(pos, sigmas[t], model_outputs)["sample"]
```

Do you think we can merge this into a single forward pass, e.g. something like:

```python
epsilon = model(batch, t, sigma)["sample"]
```

?
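A minimal sketch of what that merged call could look like; the wrapper class, `residual_fn`, and `batch.pos` are hypothetical names for illustration, not this PR's actual API:

```python
import torch


class EpsWrapper(torch.nn.Module):
    """Folds the two-step call (forward + get_residual) into one forward."""

    def __init__(self, network, residual_fn):
        super().__init__()
        self.network = network          # the underlying graph network
        self.residual_fn = residual_fn  # timestep-dependent conditioning

    def forward(self, batch, t, sigma):
        model_outputs = self.network(batch, t)
        # apply the conditioning inside the same call
        epsilon = self.residual_fn(batch.pos, sigma, model_outputs)
        return {"sample": epsilon}
```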
@@ -68,6 +68,13 @@ def __init__(

```python
elif beta_schedule == "squaredcos_cap_v2":
    # Glide cosine schedule
    self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
```
Very cool! Nice to see that "vanilla" DDPM can be used that easily
According to the author, the model performs better quantitatively with a different scheduler, but visually I noticed no difference.
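For reference, a sketch of a sigmoid beta schedule in the style the diff adds; the [-6, 6] ramp follows the GeoDiff codebase, and the default beta bounds here are assumptions:

```python
import torch


def sigmoid_beta_schedule(
    num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2
) -> torch.Tensor:
    # Pass a linear ramp through a sigmoid so the betas saturate smoothly
    # at beta_start and beta_end instead of ramping linearly.
    ramp = torch.linspace(-6, 6, num_train_timesteps)
    return torch.sigmoid(ramp) * (beta_end - beta_start) + beta_start
```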
```python
from torch.nn import Embedding, Linear, Module, ModuleList, Sequential

from rdkit.Chem.rdchem import BondType as BT
from torch_geometric.nn import MessagePassing, radius, radius_graph
```
We need to make sure that this model is only imported if `torch_geometric` or `torch_scatter` is present. We should therefore adopt the same logic we used for the optional `transformers` import.

Could you do the following (a sketch of the helpers follows this list):

1. Add an `is_torch_scatter_available` and an `is_torch_geometric_available` to https://github.com/huggingface/diffusers/blob/main/src/diffusers/utils/__init__.py (just like it has been done for other dependencies)
2. Import `DualEncoderEpsNetwork` only if both dependencies are available here, next to `from .vae import AutoencoderKL, VQModel` / `if is_transformers_available():`
3. Only import the model if the dependency is available in the main init as well (`diffusers/src/diffusers/__init__.py`, line 26 in 89f2011: `if is_transformers_available():`)
4. Having written 3), run `make fix-styles` to automatically create the dummy class
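A rough sketch of what those helpers might look like, assuming the same `importlib`-based pattern used for other optional dependencies (the real `diffusers` utils may differ in detail):

```python
import importlib.util

_torch_geometric_available = importlib.util.find_spec("torch_geometric") is not None
_torch_scatter_available = importlib.util.find_spec("torch_scatter") is not None


def is_torch_geometric_available() -> bool:
    return _torch_geometric_available


def is_torch_scatter_available() -> bool:
    return _torch_scatter_available
```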
So the pain point here is that it needs specific versions too, because this code was written before breaking changes in `torch_geometric`. Can I write similar functions for that?
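Presumably yes; a version-aware variant could look something like this (a sketch, not what `diffusers` actually ships, gating on the major version that predates the breaking changes):

```python
from importlib.metadata import PackageNotFoundError, version


def is_torch_geometric_available(max_major: int = 2) -> bool:
    # Only report the dependency as usable when the installed release
    # predates the breaking changes (i.e. torch_geometric < 2).
    try:
        major = int(version("torch_geometric").split(".")[0])
    except PackageNotFoundError:
        return False
    return major < max_major
```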
```python
class DualEncoderEpsNetwork(ModelMixin, ConfigMixin):
    def __init__(
```
Make sure to not forget the `@register_to_config` decorator here (see `diffusers/src/diffusers/models/unet_2d.py`, line 13 in 89f2011: `@register_to_config`).
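For illustration, the decorator sits directly on `__init__` and records its arguments into the model config; the parameter names below are placeholders, not the model's real signature:

```python
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config


class DualEncoderEpsNetwork(ModelMixin, ConfigMixin):
    @register_to_config  # stores the __init__ kwargs in the model's config
    def __init__(self, hidden_dim: int = 128, num_convs: int = 6):
        super().__init__()
        # ... build the network from the registered config values ...
```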
```python
    return score_pos


class DualEncoderEpsNetwork(ModelMixin, ConfigMixin):
```
Since this is the first graph neural network in this library, we should be extra careful with the naming.
Is this a universally understandable name? Do you think other graph networks would also use this architecture? Should we make the name more generic in that case? Can we link to a paper here that defined this model architecture?
Link to the original code is above now too. I followed up with those authors asking if theirs was the original: https://github.com/DeepGraphLearning/ConfGF
I'll work through your comments soon @patrickvonplaten.
src/diffusers/models/molecule_gnn.py (Outdated)

```python
class CFConv(MessagePassing):
    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff, smooth):
```
If you import `nn` from torch, you can't use it as a parameter/variable name; it's sure to create bugs later on.
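A sketch of the rename (to `mlp` here, an arbitrary choice) so the parameter no longer shadows `torch.nn` inside the class:

```python
from torch_geometric.nn import MessagePassing


class CFConv(MessagePassing):
    # `mlp` instead of `nn`, so torch.nn stays reachable inside methods
    def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):
        super().__init__(aggr="add")
        self.mlp = mlp
```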
Also, if you are redefining `CFConv` from the SchNet paper here, I think you should init the linear/shifted-softplus/linear layers (what you put in `nn`) here instead of passing them as arguments, as passing them makes it hard to follow/find the logic of the paper in your code.
(If you want to allow something more general, the default init for this layer should still be the original one from the paper.)
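A sketch of that suggestion, with the paper's Linear -> shifted-softplus -> Linear filter network built inside the layer by default; the class names and dimensions here are illustrative, not the PR's code:

```python
import math

import torch
from torch import nn


class ShiftedSoftplus(nn.Module):
    """softplus(x) - log(2), as used in SchNet."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.softplus(x) - math.log(2.0)


class CFConvFilterNet(nn.Module):
    """Default SchNet filter-generating network: Linear -> ssp -> Linear."""

    def __init__(self, num_gaussians: int, num_filters: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            nn.Linear(num_filters, num_filters),
        )

    def forward(self, edge_attr: torch.Tensor) -> torch.Tensor:
        return self.net(edge_attr)
```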
Oh yeah, I hadn't seen this `nn` usage until now; I really dislike that. It's very confusing (and here it's clear how copy-pasted blocks of this code were).
I think this was the result of copying code from multiple files too.
RE the second point, I don't feel as strongly about it. It was taken directly from a few codebases back. In this case, the classes `CFConv` and `InteractionBlock` could effectively become one?
Well, if you open the SchNet paper, there is Figure 2 (I think) which describes the model arch. I think it could be interesting if you went through the figure along with your code: having `CFConv` and `InteractionBlock` separate is not a problem, but it would be easier to match `CFConv` with the paper if it defined the `nn` layers directly. For further readers of this code, it will be clearer, IMO.
@@ -0,0 +1,640 @@

```python
# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
```
I think overall it would help if you could add class documentation for your different components and type hinting, at least in the inits and forwards.
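For instance, a docstring-and-type-hints treatment of the MLP could look like this; it is a sketch based only on the fragments visible in this diff, so the activation handling and layer layout are assumptions:

```python
from typing import List

import torch
from torch import nn


class MultiLayerPerceptron(nn.Module):
    """Feed-forward network used inside the GNN blocks.

    Args:
        input_dim: size of the input features.
        hidden_dims: output size of each successive linear layer.
        activation: name of the activation inserted between layers.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], activation: str = "relu") -> None:
        super().__init__()
        self.dims = [input_dim] + hidden_dims
        acts = {"relu": nn.ReLU, "softplus": nn.Softplus}
        layers: List[nn.Module] = []
        for d_in, d_out in zip(self.dims[:-1], self.dims[1:]):
            layers += [nn.Linear(d_in, d_out), acts[activation]()]
        self.net = nn.Sequential(*layers[:-1])  # no activation after the last layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```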
```python
        super(MultiLayerPerceptron, self).__init__()

        self.dims = [input_dim] + hidden_dims
        if isinstance(activation, str):
```
If someone passes an activation function instead of an activation function name, this will fail silently.
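One way to make that failure loud instead of silent (a sketch; the helper name is made up):

```python
from typing import Callable, Optional, Union

from torch import nn


def resolve_activation(activation: Union[str, Callable, None]) -> Optional[Callable]:
    # Accept either a name or an already-built callable, and raise on
    # anything else rather than silently dropping the activation.
    if activation is None:
        return None
    if isinstance(activation, str):
        return getattr(nn.functional, activation)
    if callable(activation):
        return activation
    raise ValueError(f"Unsupported activation: {activation!r}")
```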
```python
class SchNetEncoder(Module):
    def __init__(
        self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False
```
Isn't the paper default for `num_interactions` 3?
For the SchNet paper? No clue. This was copied from the author.
Ok, I guess we can close this comment then (it is 3 in the SchNet paper; I looked it up alongside your code yesterday). But it's not that important, and if the other model relies on 6 as the default, it might be easier to keep 6.
src/diffusers/models/molecule_gnn.py (Outdated)

```python
class GINEConv(MessagePassing):
    def __init__(self, nn: Callable, eps: float = 0.0, train_eps: bool = False, activation="softplus", **kwargs):
```
Since you are importing classes from `torch_geometric`, why are you redefining `GINEConv` instead of using the default version? (Could be worth some doc to explain what is different - the activation function choice?)
Followed up with author @MinkaiXu directly, and he got this from the implementation he built on. I can look a little more.
I looked at the source and it's the same with the optional addition of an activation.
Worth adding it in the class doc IMO
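Something along these lines would capture it (a sketch of the docstring only, stating what the thread established above):

```python
from torch_geometric.nn import MessagePassing


class GINEConv(MessagePassing):
    """Variant of torch_geometric's GINEConv.

    Identical to the upstream implementation except for the optional
    activation (default: softplus) applied to the edge features in the
    message step, following the GeoDiff codebase.
    """
```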
@@ -0,0 +1,640 @@

```python
# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
```
I don't know what the standard of this library is wrt asserts vs raising exceptions, but you might need to check this
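For context on the trade-off (a generic illustration, not this repo's actual convention): `assert` statements disappear under `python -O`, so library-style validation usually raises explicit exceptions instead.

```python
import torch


def check_positions(pos: torch.Tensor) -> None:
    # assert pos.shape[-1] == 3  -> stripped when Python runs with -O
    if pos.shape[-1] != 3:
        raise ValueError(f"Expected 3D coordinates, got shape {tuple(pos.shape)}")
```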
@anton-l or @patil-suraj any comment on this?
```python
hiddens = []
conv_input = node_attr  # (num_node, hidden)
```
What is `node_attr` for?
Do you mean with respect to the model/application, or just in this code? I have only an intermediate understanding of both.
No, I just meant that you do

```python
node_attr = self.node_emb(z)
conv_input = node_attr
```

why not directly

```python
conv_input = self.node_emb(z)
```

?
We should be able to add it now that things are calmer 🥳
@natolambert @clefourrier @patrickvonplaten any plan to integrate this PR?
I took a quick look and everything is pretty good (thank you all for your efforts!). Specifically, thanks a lot @natolambert for discussing many details together!
@georgosgeorgos and @MinkaiXu, sorry for the delay on merging. We got a little distracted by Stable Diffusion. I'll plan on merging the updates from main into this + the notebook, then we should be able to close the PR soon. The colab should run in its current form! So if you look at that, happy to take comments!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Fighting the stale bot! We haven't forgotten about this, and actually moved it up the priority list today. Soon!
I'm glad to help, but not quite familiar with the whole process ---
@MinkaiXu, @patrickvonplaten moved fast and removed it here. An interesting little difference I was not aware of :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@patrickvonplaten @anton-l @patil-suraj: so I realized the tests I implemented are not that useful, because in order to test them you need specific dependency versions (per the todo list below: `torch_geometric<2`, `pytorch 1.8`, and `torch_scatter`). These are kind of a pain to install. How should we think about integrating these tests?
Why do we need to install pytorch from source? Also is
@patil-suraj TLDR: this re-implements a research paper's code, and it was built on a version before a lot of breaking changes, which has made this PR-for-a-colab a bit unwieldy. I'm actually not sure if any version of
Do we already have a working google colab for this model?
Oh yeah, I can just put it all in the colab and not port it into diffusers. May be easier. In summary, I would merge in
sounds good!
We closed this PR in favor of a colab-only solution, unless someone has time to update the source model to the new versions of
@patrickvonplaten @natolambert where is the working colab with geodiff in the
Let's leave it open until we have a colab version. We can also make use of community pipelines: https://huggingface.co/docs/diffusers/using-diffusers/custom_pipelines :-)
Hey! Thanks for adding GeoDiff into the pipeline :D TorsionDiff (https://arxiv.org/abs/2206.01729) might be another cool approach for the library. It requires fewer diffusion steps for conformational generation than GeoDiff! Code: https://github.com/gcorso/torsional-diffusion
Added a new model file to add functionality for this paper that does molecule generation via diffusion.

Pretrained models are available for two tasks, `drugs` and `qm9`.

Will work on additional examples for this model and update this PR.

Some todo items:

- `MoleculeGNN`: needs `torch_geometric<2`, `pytorch 1.8`, and `torch_scatter` (a recommended installation method is in the colab)

Some comments: