[Flava] Add ckpt loading and accuracy metric to finetuning #119
Conversation
ghstack-source-id: 08766d32244fc64c31049cd8003299124c48e66c Pull Request resolved: #119
Codecov Report

@@           Coverage Diff            @@
##   gh/ankitade/7/base    #119   +/- ##
=========================================
  Coverage           ?   88.59%
=========================================
  Files              ?       40
  Lines              ?     2087
  Branches           ?        0
=========================================
  Hits               ?     1849
  Misses             ?      238
  Partials           ?        0
=========================================

Continue to review the full report at Codecov.
@ankitade has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
- Accuracy metric for finetuning
- Add checkpoint saving and best checkpoint loading based on validation accuracy
- Load the pretrained checkpoint by default in the classification model
- Set num gpus to 1 in qnli.yaml

Test plan: python -m flava.finetune config=flava/configs/finetuning/qnli.yaml (val acc: 0.8651; full output in the Test plan section below)

Differential Revision: [D37444938](https://our.internmc.facebook.com/intern/diff/D37444938)
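For context, best-checkpoint saving keyed on validation accuracy is typically wired up in PyTorch Lightning with a ModelCheckpoint callback. A minimal sketch follows; the module and datamodule names are hypothetical stand-ins, not the actual classes in this PR:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Hypothetical stand-ins for the FLAVA finetuning LightningModule and
# QNLI datamodule; the real class names in this repo may differ.
model = FLAVAClassificationModule()
datamodule = QNLIDataModule()

# Keep only the checkpoint with the highest validation accuracy; the
# monitored key matches the metric logged during validation here.
checkpoint_callback = ModelCheckpoint(
    monitor="validation/accuracy/classification",
    mode="max",
    save_top_k=1,
)

trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=datamodule)
```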
ghstack-source-id: 39c8d11af6ffae071159f6594a5b24c6a12909bb Pull Request resolved: #119
@ankitade has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
ghstack-source-id: 2f01a3335be79be13f128455ce898fce8469a185 Pull Request resolved: #119
@@ -209,6 +209,7 @@ def flava_model_for_classification(
     classifier_activation: Callable[..., nn.Module] = nn.ReLU,
     classifier_normalization: Optional[Callable[..., nn.Module]] = None,
     loss_fn: Optional[Callable[..., Tensor]] = None,
+    pretrained_model_key: Optional[str] = "flava_full",
Just wondering, why do we want to add this as a param to `flava_model_for_classification`? Feels to me like this class should not care about the pretrained weights. I like TorchVision's approach for handling this; maybe we can do something similar?
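For reference, TorchVision exposes pretrained weights through explicit weights enums passed to the model builder rather than a constructor default; a minimal sketch with a real TorchVision model:

```python
from torchvision.models import resnet50, ResNet50_Weights

# Weights are an explicit argument of the builder: weights=None gives a
# randomly initialized model, while an enum value selects a specific
# pretrained checkpoint.
model = resnet50(weights=ResNet50_Weights.DEFAULT)
```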
I am just doing this for "short term uniformity" with `flava_for_pretraining`. But yeah, I agree, TorchVision's approach is nicer and something we can potentially follow.
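For now, with the parameter defaulting to "flava_full", constructing the classification model loads pretrained weights unless the key is set to None. A rough usage sketch; the import path is an assumption and may differ across torchmultimodal versions:

```python
from torchmultimodal.models.flava.model import flava_model_for_classification

# Loads the "flava_full" pretrained checkpoint by default (per this PR);
# pass pretrained_model_key=None to skip loading pretrained weights.
# Import path above is an assumption, not confirmed by this PR.
model = flava_model_for_classification(num_classes=2)
```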
@apsdehal I am going to merge this today. Please feel free to leave comments either way and ping me; I can address them in a separate PR if I end up merging before you review.
Test plan
python -m flava.finetune config=flava/configs/finetuning/qnli.yaml
(val acc: 0.8651)
Loaded model weights from checkpoint at /data/home/deankita/torchmultimodal/examples/flava-epoch=03-step=10000.ckpt
/data/home/deankita/miniconda/envs/flava/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:330: PossibleUserWarning: Using `DistributedSampler` with the dataloaders. During `trainer.validate()`, it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case of uneven inputs.
  rank_zero_warn(
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 171/171 [00:54<00:00, 3.15it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Validate metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ validation/accuracy/classification │ 0.8651315569877625 │
│ validation/losses/classification │ 0.4168359339237213 │
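As the PossibleUserWarning above notes, single-device validation avoids DistributedSampler replicating samples to even out batches. A minimal sketch; model and datamodule are the hypothetical stand-ins from the earlier sketch, and the checkpoint path is illustrative:

```python
from pytorch_lightning import Trainer

# One device guarantees each validation sample is evaluated exactly once;
# multi-device runs pad uneven batches by replicating samples.
trainer = Trainer(devices=1)
trainer.validate(
    model,
    datamodule=datamodule,
    ckpt_path="flava-epoch=03-step=10000.ckpt",  # illustrative path
)
```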
Stack from ghstack (oldest at bottom):
Differential Revision: D37444938