Skip to content

Commit

Permalink
change order of itm loss init
Browse files Browse the repository at this point in the history
ghstack-source-id: 9b2a69707e2e132df0eae5745e0df5a93a1581a3
Pull Request resolved: #131
  • Loading branch information
ankitade committed Jul 4, 2022
1 parent 6fac11d commit 6f78bfd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
6 changes: 3 additions & 3 deletions test/models/flava/test_flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
21.4791,
21.3670,
places=4,
)

Expand All @@ -107,7 +107,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
8.9674,
8.6285,
places=4,
)

Expand All @@ -132,7 +132,7 @@ def test_forward_pretraining(self):
sum(
value if value is not None else 0 for value in output.losses.values()
).item(),
10.0305,
11.0002,
places=4,
)

Expand Down
9 changes: 4 additions & 5 deletions torchmultimodal/modules/losses/flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,10 @@ def __init__(
**kwargs: Any,
):
super().__init__()

self.itm_loss = ITMLoss(
hidden_size=hidden_size,
ignore_index=ignore_index,
)
self.contrastive_loss = FLAVAGlobalContrastiveLoss(
logit_scale=logit_scale,
image_embedding_size=hidden_size,
Expand Down Expand Up @@ -344,10 +347,6 @@ def __init__(
),
}
)
self.itm_loss = ITMLoss(
hidden_size=hidden_size,
ignore_index=ignore_index,
)

self.mim_weight = mim_weight
self.mlm_weight = mlm_weight
Expand Down

0 comments on commit 6f78bfd

Please sign in to comment.