
Fuse all hydras #814

Merged
merged 16 commits into main from rgao_fuse_hydras_0 on Aug 20, 2024

Conversation

rayg1234
Collaborator

@rayg1234 rayg1234 commented Aug 15, 2024

In order to make FineTuneHydra work in FM, we decided to consolidate Main:hydra, FMHydra, and FinetuneHydra, so no more hydras of hydras, yay!

Main Changes

  • Fuse finetune-hydra with main-hydra
  • Modify the output dict format for main-hydra so that it can be used with the FM trainer
  • Modify the logic that assigns outputs to targets so it accepts nested dicts as outputs (the consistent hydra format); this allows tasks to work with the FM trainer (see the sketch after this list)
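
For illustration only, a minimal sketch of how a trainer could resolve a target from either a flat or a nested output dict; the helper name resolve_output and its arguments are assumptions for this sketch, not the actual fairchem API:

import torch

def resolve_output(outputs: dict, target_property: str) -> torch.Tensor:
    # Flat case: outputs maps property names directly to tensors.
    if target_property in outputs and isinstance(outputs[target_property], torch.Tensor):
        return outputs[target_property]
    # Nested case (consistent hydra format): each head returns its own dict,
    # so search the per-head sub-dicts for the requested property.
    for head_outputs in outputs.values():
        if isinstance(head_outputs, dict) and target_property in head_outputs:
            return head_outputs[target_property]
    raise KeyError(f"Property {target_property!r} not found in model outputs")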

Config Changes

The updated hydra works with both main and FM with the following changes

  • main:hydra models need to have a "property" field in each output so that the trainer knows which output property of the model to assign to the target (previously this was implicit and would break if multiple heads had the same name); see the sketch after this list
  • Got rid of all the modes in FineTuneHydra and simplified it to only load from a checkpoint; it now has the following logic:
    • if backbone is specified, build the backbone, else attempt to get it from the starting_checkpoint
    • if heads is specified, build the heads, else attempt to get them from the starting_checkpoint
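
For illustration, a minimal sketch of what an outputs section with the new "property" field might look like; the output names and exact structure are assumptions based on the description above, not the full schema:

outputs:
  energy:
    property: energy
  forces:
    property: forces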

Config for finetuning with data-only

model:
  name: hydra
  finetune_config:
    starting_checkpoint: "./checkpoints/2024-08-07-20-20-16-test/checkpoint.pt"

Config for finetuning with backbone-only

model:
  name: hydra
  finetune_config:
    starting_checkpoint: "./checkpoints/2024-08-07-20-20-16-test/checkpoint.pt"
  heads:
    energy:
      module: equiformer_v2_energy_head
    forces:
      module: equiformer_v2_force_head

Follow-up: we need to make small updates to FM so that all the configs have heads in them (in all configs, the model will fully instantiate the backbone and heads, not tasks). This will allow us to completely remove FMHydra and FinetuneHydra.

For now, the logic to rewrite the finetune config has been removed; we will decide on it after getting this to work with FM first.

Testing

  • All previous e2e and finetune tests
  • Retrain an OC20 model and finetune it with OC22

@rayg1234 rayg1234 changed the title make hydra compat with multitask Fuse all hydras Aug 16, 2024
@rayg1234 rayg1234 marked this pull request as ready for review August 16, 2024 01:06
@rayg1234 rayg1234 added the enhancement (New feature or request) and minor (Minor version release) labels Aug 16, 2024

codecov bot commented Aug 16, 2024

Codecov Report

Attention: Patch coverage is 93.22034% with 4 lines in your changes missing coverage. Please review.

Files                                Patch %   Lines
src/fairchem/core/models/base.py     92.50%    3 Missing ⚠️
src/fairchem/core/common/utils.py    92.30%    1 Missing ⚠️

Files                                          Coverage Δ
src/fairchem/core/trainers/base_trainer.py    89.32% <ø> (-0.08%) ⬇️
src/fairchem/core/trainers/ocp_trainer.py     69.49% <100.00%> (+0.52%) ⬆️
src/fairchem/core/common/utils.py             67.80% <92.30%> (+1.69%) ⬆️
src/fairchem/core/models/base.py              88.48% <92.50%> (+1.60%) ⬆️

... and 7 files with indirect coverage changes

misko
misko previously approved these changes Aug 16, 2024
Collaborator

@misko misko left a comment


Looks great! LGTM! :hydra:

if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(
        errno.ENOENT, "Checkpoint file not found", checkpoint_path
    )
logging.info(f"Loading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
Collaborator

Do we want to use map_location here, in case we are loading on CPU vs. GPU?

Collaborator Author

good point
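
For illustration, a minimal sketch of the kind of change being discussed: loading the checkpoint onto CPU and letting the trainer move things to the target device later. The function name load_checkpoint_cpu is a placeholder for this sketch, not the code that landed in the PR.

import torch

def load_checkpoint_cpu(checkpoint_path: str) -> dict:
    # Load onto CPU regardless of which device the checkpoint was saved from;
    # the trainer can move tensors/modules to the target device afterwards.
    return torch.load(checkpoint_path, map_location=torch.device("cpu"))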

checkpoint = torch.load(checkpoint_path)
config = checkpoint["config"]["model"]
name = config.pop("name")
model = registry.get_model_class(name)(**config)
Collaborator

Should model.to() be called to move the model to the target device?

Collaborator Author

I think we shouldn't try to move the model to a device everywhere; that should only happen once, in the trainer.

wood-b
wood-b previously approved these changes Aug 17, 2024
Collaborator

@wood-b wood-b left a comment


This is awesome! A couple of minor comments:

  1. See the pass_through comment below
  2. It would be great to update ocp_hydra_example.yml in this PR, which I think just means adding the property key in outputs. Also, if you change ocp_hydra_example.yml, this line https://github.com/FAIR-Chem/fairchem/blob/main/configs/ocp_hydra_example.yml#L1 should be forces or ocp

otf_graph: bool = True,
pass_through_head_outputs: bool = False,
Collaborator

Should pass_through_head_outputs be head-specific instead of on/off for all heads? Instead of a single bool there would be one per head. Maybe there is an edge case where multiple heads are specified in the config, one of which returns multiple values and another that does not.

Collaborator

Maybe for now, if pass_through_head_outputs: True, we check that there is only 1 head specified? Or just ignore this edge case for the moment.

Collaborator Author

To keep things simple, I'm trying to constrain the output to either dict(str, tensor) or dict(str, dict(str, tensor)) (so either all tensors or all dicts of tensors); otherwise it's too hard to manage.
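
For illustration, a minimal sketch of the two output shapes described in the reply above; the head and property names, and the tensor shapes, are made up for the example:

import torch

# All tensors: dict(str, Tensor)
flat_outputs: dict[str, torch.Tensor] = {
    "energy": torch.zeros(1),
    "forces": torch.zeros(10, 3),
}

# All dicts of tensors: dict(str, dict(str, Tensor)), one sub-dict per head
nested_outputs: dict[str, dict[str, torch.Tensor]] = {
    "energy_head": {"energy": torch.zeros(1)},
    "force_head": {"forces": torch.zeros(10, 3)},
}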

@rayg1234 rayg1234 dismissed stale reviews from wood-b and misko via 9298191 August 19, 2024 18:20

def forward(self, data: Batch):
# get device from input; at least one input must be a tensor to figure out its device
Collaborator

This looks good to me.

nit: Is it possible to just do device = data.pos.device? That might be a bit safer, since it won't have to iterate over all values on each batch.

My only worry here is performance, but I don't think it should be too bad if the length of data.values() is reasonable ... and the runtime of a single batch is >500 ms?

Collaborator Author

I think this line should be negligible; I'd rather not assume anything about what's in the data batch, but I can also just save the device and only do this once.
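
For illustration, a hedged sketch of the caching idea described in the reply above; the class and attribute names are assumptions, and the batch is simplified to a plain dict of tensors rather than a PyG Batch:

from __future__ import annotations

import torch

class DeviceCachingSketch(torch.nn.Module):
    # Hypothetical illustration: detect the input device from the first batch
    # and reuse it for later batches instead of rescanning every time.
    def __init__(self) -> None:
        super().__init__()
        self._device: torch.device | None = None

    def forward(self, data: dict[str, torch.Tensor]) -> torch.device:
        if self._device is None:
            # At least one input must be a tensor to figure out its device.
            self._device = next(
                v.device for v in data.values() if isinstance(v, torch.Tensor)
            )
        return self._device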

@misko misko self-requested a review August 19, 2024 22:12
Collaborator

@misko misko left a comment


LGTM 🎉

@rayg1234 rayg1234 added this pull request to the merge queue Aug 19, 2024
Merged via the queue into main with commit 1f0f631 Aug 20, 2024
8 checks passed
@rayg1234 rayg1234 deleted the rgao_fuse_hydras_0 branch August 20, 2024 00:17
Labels
enhancement (New feature or request), minor (Minor version release)
Projects
None yet
3 participants