Skip to content

Commit

Permalink
Change default logit scale in contrastive loss with temperature from …
Browse files Browse the repository at this point in the history
…parameter to float (#510)

Summary:
GitHub actions for #506 show failures in `test_contrastive_loss_with_temperature.py` even though the changes in that PR do not touch any contrastive loss components. Running e.g. `python -m pytest -v tests/modules/losses/test_contrastive_loss_with_temperature.py`, even on that PR, there are no failures. But when we run the full test suite, two of the test cases in `test_contrastive_loss_with_temperature.py` fail.

This is because of how we define the default value of `logit_scale` in `ContrastiveLossWithTemperature`. We set the default to an `nn.Parameter`, which is initialized the first time the class gets imported. But then this parameter is already defined outside of the test class and so we lose isolation of our test cases.

The fix is to use a float as the default instead. Since this gets cast to an `nn.Parameter` on init anyways, there will be no difference from the user's perspective. But this way we isolate the parameter to an instance of the class instead of creating a global parameter on import.

Tested on top of #506. Before the change:

```
python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_local_loss - AssertionError: actual: 2.032681941986084, expected: 9.8753
FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_loss_with_ce_kwargs - AssertionError: actual: 2.1044366359710693, expected: 10.2524
================================================================== 2 failed, 6 passed, 2 skipped in 3.00s ===================================================================
```

After the change:

```
python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py
...
======================================================================= 8 passed, 2 skipped in 2.87s ========================================================================
```

Pull Request resolved: #510

Reviewed By: kartikayk

Differential Revision: D50974788

Pulled By: ebsmothers

fbshipit-source-id: 6b1c2ed98583a0efd4a41894ef7c151189d51f31
  • Loading branch information
ebsmothers authored and facebook-github-bot committed Nov 9, 2023
1 parent b933b8e commit e6b92b5
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def contrastive_loss_with_temperature(
)


DEFAULT_LOGIT_SCALE = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
DEFAULT_LOGIT_SCALE = math.log(1 / 0.07)


class ContrastiveLossWithTemperature(nn.Module):
Expand Down

0 comments on commit e6b92b5

Please sign in to comment.