[Flava] Add ckpt loading and accuracy metric to finetuning #119

Status: Closed (5 commits)
examples/flava/configs/finetuning/qnli.yaml (6 additions, 1 deletion)
@@ -4,7 +4,7 @@ training:
_target_: flava.definitions.TrainingArguments
lightning:
max_steps: 33112
- gpus: -1
+ gpus: 1
progress_bar_refresh_rate: 50
val_check_interval: 1000
num_sanity_val_steps: 0
@@ -16,6 +16,8 @@ training:
every_n_train_steps: 1000
save_on_train_epoch_end: true
verbose: true
+ monitor: validation/accuracy/classification
+ mode: max
lightning_load_from_checkpoint: null
seed: -1
batch_size: 32
@@ -45,3 +47,6 @@ datasets:
- ["sentence", "sentence2"]
datamodule_extra_kwargs:
text_columns: ["sentence1", "sentence2"]
+
+ model:
+   pretrained_model_key: flava_full
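
For reference, the lightning_checkpoint section (including the new monitor and mode keys) is forwarded unchanged to PyTorch Lightning's ModelCheckpoint by the finetune.py change below, so the config above is equivalent to constructing the callback directly. A minimal sketch:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Keep the checkpoint with the highest validation accuracy ("max"),
    # checkpointing every 1000 training steps and at the end of each epoch.
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=1000,
        save_on_train_epoch_end=True,
        verbose=True,
        monitor="validation/accuracy/classification",
        mode="max",
    )

The new model.pretrained_model_key field is read by flava_model_for_classification (see torchmultimodal/models/flava/flava_model.py below) to initialize the classifier from pretrained FLAVA weights.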
examples/flava/finetune.py (17 additions, 8 deletions)
@@ -8,10 +8,10 @@
from flava.data.datamodules import VLDataModule
from flava.definitions import FLAVAArguments
from flava.model import FLAVAClassificationLightningModule
+ from flava.utils import build_config, build_datamodule_kwargs
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything, Trainer
- from pytorch_lightning.callbacks import LearningRateMonitor
- from utils import build_config, build_datamodule_kwargs
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

AVAIL_GPUS = 1
SEED = -1
@@ -55,14 +55,23 @@ def main():
**config.model,
)

+ callbacks = [
+     LearningRateMonitor(logging_interval="step"),
+ ]
+
+ if config.training.lightning_checkpoint is not None:
+     callbacks.append(
+         ModelCheckpoint(
+             **OmegaConf.to_container(config.training.lightning_checkpoint)
+         )
+     )

trainer = Trainer(
- **OmegaConf.to_container(config.training.lightning),
- callbacks=[
-     LearningRateMonitor(logging_interval="step"),
- ],
+ **OmegaConf.to_container(config.training.lightning), callbacks=callbacks
)
- trainer.fit(model, datamodule=datamodule)
- trainer.validate(model, datamodule=datamodule)
+ ckpt_path = config.training.lightning_load_from_checkpoint
+ trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
+ trainer.validate(datamodule=datamodule)


if __name__ == "__main__":
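To make the checkpoint wiring concrete, here is a small self-contained sketch of the same save-and-resume pattern, assuming a PyTorch Lightning version where Trainer.fit accepts ckpt_path (1.5 or later). ToyClassifier and the checkpoint paths involved are stand-ins, not part of this example:

    import torch
    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from torch.utils.data import DataLoader, TensorDataset


    class ToyClassifier(LightningModule):
        """Minimal stand-in for FLAVAClassificationLightningModule."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.layer(x), y)
            self.log("train/losses/classification", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    def make_loader():
        x = torch.randn(64, 8)
        y = torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=16)


    # First run: ModelCheckpoint writes .ckpt files as training progresses.
    ckpt_cb = ModelCheckpoint(every_n_train_steps=2, save_on_train_epoch_end=True)
    trainer = Trainer(max_steps=4, callbacks=[ckpt_cb])
    trainer.fit(ToyClassifier(), train_dataloaders=make_loader())

    # Second run: ckpt_path restores model weights plus optimizer and loop state,
    # which is what lightning_load_from_checkpoint feeds into trainer.fit above
    # (pretrained_model_key, by contrast, only seeds the weights).
    resumed = Trainer(max_steps=8)
    resumed.fit(ToyClassifier(), train_dataloaders=make_loader(), ckpt_path=ckpt_cb.best_model_path)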
examples/flava/model.py (23 additions, 5 deletions)
@@ -8,6 +8,7 @@

import torch
from pytorch_lightning import LightningModule
+ from torchmetrics import Accuracy
from torchmultimodal.models.flava.flava_model import (
flava_model_for_classification,
flava_model_for_pretraining,
@@ -139,18 +140,33 @@ def __init__(
self.warmup_steps = warmup_steps
self.max_steps = max_steps
self.adam_betas = adam_betas
+ self.metrics = Accuracy()

def training_step(self, batch, batch_idx):
- output = self._step(batch, batch_idx)
+ output, accuracy = self._step(batch, batch_idx)
self.log("train/losses/classification", output.loss, prog_bar=True, logger=True)
+ self.log(
+     "train/accuracy/classification",
+     accuracy,
+     prog_bar=True,
+     logger=True,
+     sync_dist=True,
+ )

return output.loss

def validation_step(self, batch, batch_idx):
- output = self._step(batch, batch_idx)
+ output, accuracy = self._step(batch, batch_idx)
self.log(
"validation/losses/classification", output.loss, prog_bar=True, logger=True
)
+ self.log(
+     "validation/accuracy/classification",
+     accuracy,
+     prog_bar=True,
+     logger=True,
+     sync_dist=True,
+ )

return output.loss

@@ -164,15 +180,17 @@ def _step(self, batch, batch_idx):
else:
raise RuntimeError("Batch needs to have either or both 'image' and 'text'.")

+ labels = batch["labels"]
output = self.model(
image=batch.get("image", None),
text=batch.get("text", None),
required_embedding=required_embedding,
- labels=batch.get("labels", None),
+ labels=labels,
)

- # TODO: Add accuracy metric to this later.
- return output
+ accuracy = self.metrics(output.logits, labels)
+
+ return output, accuracy

def configure_optimizers(self):
return get_optimizers_for_lightning(
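The accuracy values logged above come from a torchmetrics.Accuracy module called directly on the classification logits and integer labels. A self-contained sketch of that pattern with made-up tensors (note that newer torchmetrics releases may require a task argument, e.g. Accuracy(task="multiclass", num_classes=2), while the bare constructor used in this PR matches the releases current at the time):

    import torch
    from torchmetrics import Accuracy

    # A torchmetrics module keeps running state: calling it updates that state
    # and returns the accuracy for the current batch, which is the value logged
    # in training_step/validation_step above.
    metric = Accuracy()

    logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [1.2, 0.3]])  # (batch, num_classes)
    labels = torch.tensor([0, 1, 1])

    batch_acc = metric(logits, labels)  # argmax over classes -> 2/3 correct
    epoch_acc = metric.compute()        # accuracy over all updates so far
    metric.reset()                      # typically done at epoch boundaries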
test/models/flava/test_flava.py (1 addition, 1 deletion)
@@ -27,7 +27,7 @@ def setUp(self):

@torch.no_grad()
def test_forward_classification(self):
- flava = flava_model_for_classification(NUM_CLASSES)
+ flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
image = torch.rand((2, 3, 224, 224))

torchmultimodal/models/flava/flava_model.py (9 additions, 1 deletion)
@@ -209,6 +209,7 @@ def flava_model_for_classification(
classifier_activation: Callable[..., nn.Module] = nn.ReLU,
classifier_normalization: Optional[Callable[..., nn.Module]] = None,
loss_fn: Optional[Callable[..., Tensor]] = None,
+ pretrained_model_key: Optional[str] = "flava_full",
Review comment (Contributor):
Just wondering, why do we want to add this as a param to flava_model_for_classification? It feels to me like this class should not care about the pretrained weights. I like TorchVision's approach for handling this; maybe we can do something similar?

Reply (Contributor Author):
I am just doing this for "short-term uniformity" with flava_for_pretraining. But yeah, I agree, TorchVision's approach is nicer and something we can potentially follow.
**flava_model_kwargs: Any,
):
model = flava_model(**flava_model_kwargs)
@@ -224,7 +225,14 @@
if loss_fn is None:
loss_fn = nn.CrossEntropyLoss()

- return FLAVAForClassification(model=model, classifier=classifier, loss=loss_fn)
+ classification_model = FLAVAForClassification(
+     model=model, classifier=classifier, loss=loss_fn
+ )
+ if pretrained_model_key is not None:
+     classification_model.load_model(
+         FLAVA_FOR_PRETRAINED_MAPPING[pretrained_model_key], strict=False
+     )
+ return classification_model


def to_2tuple(x):
Expand Down
torchmultimodal/utils/common.py (2 additions, 1 deletion)
@@ -146,6 +146,7 @@ def load_model(
pretrained_url: Optional[str],
load_state_dict: bool = True,
state_dict_key: Optional[str] = None,
+ strict: bool = True,
):
assert isinstance(
self, torch.nn.Module
@@ -160,7 +161,7 @@
state_dict = state_dict[state_dict_key]

if load_state_dict:
- self.load_state_dict(state_dict)
+ self.load_state_dict(state_dict, strict=strict)
return state_dict


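The new strict flag is passed straight through to torch.nn.Module.load_state_dict. With strict=False, keys present only in the checkpoint (such as pretraining heads) or only in the model (such as a freshly initialized classifier) are reported rather than raising, which is what lets flava_model_for_classification reuse the pretraining checkpoint. A small sketch of that behaviour with toy modules:

    import torch.nn as nn


    class Pretrained(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)
            self.pretraining_head = nn.Linear(4, 8)  # absent from the finetuning model


    class ForClassification(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)
            self.classifier = nn.Linear(4, 2)  # absent from the checkpoint


    state_dict = Pretrained().state_dict()
    model = ForClassification()

    # strict=True would raise: the checkpoint has no "classifier.*" keys and the
    # model has no "pretraining_head.*" keys. strict=False loads the overlapping
    # encoder weights and reports the rest instead of failing.
    result = model.load_state_dict(state_dict, strict=False)
    print(result.missing_keys)     # ['classifier.weight', 'classifier.bias']
    print(result.unexpected_keys)  # ['pretraining_head.weight', 'pretraining_head.bias']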