Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change default logit scale in contrastive loss with temperature from …
…parameter to float (#510) Summary: GitHub actions for #506 show failures in `test_contrastive_loss_with_temperature.py` even though the changes in that PR do not touch any contrastive loss components. Running e.g. `python -m pytest -v tests/modules/losses/test_contrastive_loss_with_temperature.py`, even on that PR, there are no failures. But when we run the full test suite, two of the test cases in `test_contrastive_loss_with_temperature.py` fail. This is because of how we define the default value of `logit_scale` in `ContrastiveLossWithTemperature`. We set the default to an `nn.Parameter`, which is initialized the first time the class gets imported. But then this parameter is already defined outside of the test class and so we lose isolation of our test cases. The fix is to use a float as the default instead. Since this gets cast to an `nn.Parameter` on init anyways, there will be no difference from the user's perspective. But this way we isolate the parameter to an instance of the class instead of creating a global parameter on import. Tested on top of #506. Before the change: ``` python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py ... FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_local_loss - AssertionError: actual: 2.032681941986084, expected: 9.8753 FAILED tests/modules/losses/test_contrastive_loss_with_temperature.py::TestContrastiveLossWithTemperature::test_loss_with_ce_kwargs - AssertionError: actual: 2.1044366359710693, expected: 10.2524 ================================================================== 2 failed, 6 passed, 2 skipped in 3.00s =================================================================== ``` After the change: ``` python -m pytest -v tests/models/coca/test_coca_model.py tests/modules/losses/test_contrastive_loss_with_temperature.py ... ======================================================================= 8 passed, 2 skipped in 2.87s ======================================================================== ``` Pull Request resolved: #510 Reviewed By: kartikayk Differential Revision: D50974788 Pulled By: ebsmothers fbshipit-source-id: 6b1c2ed98583a0efd4a41894ef7c151189d51f31
- Loading branch information