Fuse all hydras #814
Conversation
Looks great! LGTM! :hydra:
src/fairchem/core/common/utils.py
Outdated
    errno.ENOENT, "Checkpoint file not found", checkpoint_path
)
logging.info(f"Loading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
Do we want to use map_location here, in case we are loading on CPU vs. GPU?
Good point.
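As a rough sketch of what the suggestion amounts to (the function name and signature below are illustrative, not the actual fairchem API), `map_location` remaps tensor storages at load time, so a checkpoint saved from CUDA can be loaded on a CPU-only machine:

```python
import torch

def load_checkpoint(checkpoint_path: str, map_location: str = "cpu") -> dict:
    # Without map_location, torch.load tries to restore tensors onto the
    # devices they were saved from, which fails on a machine without those
    # devices. Remapping to "cpu" (or the trainer's device) avoids that.
    return torch.load(checkpoint_path, map_location=map_location)
```

The caller (e.g. the trainer) can then pass its own target device instead of the `"cpu"` default.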
checkpoint = torch.load(checkpoint_path)
config = checkpoint["config"]["model"]
name = config.pop("name")
model = registry.get_model_class(name)(**config)
Should model.to() be called to move the model to the target device?
I think we shouldn't try to move the model to a device everywhere; that should only happen once, in the trainer.
This is awesome! A couple of minor comments:
- See pass_through_comment
- It would be great to update ocp_hydra_example.yml in this PR, which I think just means adding the property key in outputs. Also, if you make changes to ocp_hydra_example.yml, this https://github.com/FAIR-Chem/fairchem/blob/main/configs/ocp_hydra_example.yml#L1 should be forces or ocp
otf_graph: bool = True,
pass_through_head_outputs: bool = False,
Should pass_through_head_outputs be head-specific instead of on/off for all heads? Instead of a single bool there would be one per head. Maybe there is an edge case where multiple heads are specified in the config, one of which returns multiple values and another that does not.
Maybe for now, if pass_through_head_outputs: True, we check that there is only 1 head specified? Or just ignore this edge case for the moment.
To keep things simple, I'm trying to constrain the output to either dict(str, tensor) or dict(str, dict(str, tensor)) (so either all tensors or all dicts of tensors); otherwise it's too hard to manage.
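That constraint could be enforced with a small check along these lines (a hypothetical sketch, not the actual fairchem code; `validate_head_outputs` is an invented name, and the values may be tensors or dicts of tensors):

```python
def validate_head_outputs(outputs: dict) -> bool:
    """Return True if every head returned a dict (of tensors), False if
    every head returned a bare tensor; raise if the two styles are mixed."""
    is_dict = [isinstance(v, dict) for v in outputs.values()]
    if any(is_dict) and not all(is_dict):
        raise ValueError(
            "Hydra heads must either all return tensors or all return "
            "dicts of tensors; mixing the two is not supported."
        )
    return bool(is_dict) and all(is_dict)
```

Failing fast with a clear error keeps the downstream trainer logic simple: it only ever has to handle one of the two output shapes per model.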
src/fairchem/core/models/base.py
Outdated
def forward(self, data: Batch):
    # get device from input; at least one input must be a tensor to figure out its device
This looks good to me.
Nit: is it possible to do just device=data.pos.device? That might be a bit safer, since it won't have to iterate over all values on each batch.
My only worry here is performance, but I don't think it should be too bad if the length of data.values() is reasonable and the runtime of a single batch is >500ms?
I think this line should be negligible; I'd rather not assume anything about what's in the data batch, but I can just save the device and only do this once as well.
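Caching the inferred device might look roughly like this (an illustrative sketch; the class and function names are invented, and any object exposing a `.device` attribute stands in for a tensor):

```python
def first_tensor_device(batch_values):
    # Scan the batch for the first value that carries a .device attribute
    # (i.e. a tensor) and return its device.
    for value in batch_values:
        device = getattr(value, "device", None)
        if device is not None:
            return device
    raise ValueError("at least one input must be a tensor to infer the device")

class DeviceCachingModel:
    """Infer the device from the first batch seen, then reuse it."""

    def __init__(self):
        self._device = None

    def forward(self, data):
        if self._device is None:  # pay the scan cost only once
            self._device = first_tensor_device(data.values())
        return self._device
```

This keeps the "assume nothing about the batch contents" property while making the per-batch cost of the scan a one-time expense.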
LGTM 🎉
In order to make FineTuneHydra work in FM, we decided to consolidate Main:hydra, FMHydra, and FinetuneHydra, so no more hydra hydras, yay!
Main Changes
- Refactored tasks to work with the FM trainer

Config Changes
The updated hydra works with both main and FM, with the following changes:
- Added output so that the trainer knows which output property of the model to assign to each target (previously this was implicit and would break if multiple heads had the same name)

Config for finetuning with data-only

Config for finetuning with backbone-only

Followup: we need to make small updates to FM so that all the configs have heads in them (in all configs, the model will fully instantiate the backbone and heads, not tasks). This will allow us to completely remove FMHydra and FinetuneHydra.
For now, removed the logic to rewrite the finetune config; we'll decide on that after getting this to work with FM first.
Testing