[MXNET-1438] Adding SDML loss function #17298
Conversation
@anjishnu Please address the sanity errors: http://jenkins.mxnet-ci.amazon-ml.com/job/mxnet-validation/job/sanity/job/PR-17298/1/display/redirect . Also, can you randomize your unit test to ensure that we're covering more numerically different cases?
@haojin2 Sure, will address the sanity cases. Can you give an example of a unit test that is appropriately randomized, so I can base mine on that?
It looks a little tricky to port this into the 'fit' and 'score' paradigm, since this is a retrieval-specific loss function that uses the other elements in a batch as implicit negative samples, and I'm not sure how cleanly it fits into the Module API for this kind of test. Especially since the loss computation needs to know the shape of the minibatch, which doesn't seem to be possible in the symbol API. The loss only guarantees that associated pairs will be closer in the chosen metric space after learning than non-associated pairs. Maybe I can write something equivalent using the Gluon API, to train a small network and ensure it learns the right associations. I'll come up with a proposal shortly.
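For readers of this thread, here is a minimal imperative sketch of the batchwise idea being described (hypothetical standalone code, not the PR's exact implementation; the function name `sdml_loss_sketch` and the `smoothing_parameter` argument are illustrative). Row i of `x1` is paired with row i of `x2`, and the remaining rows act as implicit negatives, which is why a concrete batch size is needed to build the smoothed label matrix:

```python
import mxnet as mx
from mxnet import nd

def sdml_loss_sketch(x1, x2, smoothing_parameter=0.3):
    """Row i of x1 is the positive partner of row i of x2;
    every other row in the batch serves as an implicit negative."""
    batch_size, dim = x1.shape
    # pairwise squared euclidean distances between the two batches
    a = x1.expand_dims(1).broadcast_to((batch_size, batch_size, dim))
    b = x2.expand_dims(0).broadcast_to((batch_size, batch_size, dim))
    distances = ((a - b) ** 2).sum(axis=2)
    # smoothed targets: mass (1 - s) on the true pair, s spread uniformly
    # over the (batch_size - 1) in-batch negatives -- building this matrix
    # is what requires knowing the batch size up front
    eye = nd.eye(batch_size)
    labels = eye * (1 - smoothing_parameter) \
        + (1 - eye) * smoothing_parameter / (batch_size - 1)
    # softmax over negative distances turns each row into a distribution
    log_probabilities = nd.log_softmax(-distances, axis=1)
    kl = mx.gluon.loss.KLDivLoss(from_logits=True)
    # KLDivLoss averages over the label axis, so scale back by batch_size
    return kl(log_probabilities, labels) * batch_size
```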
@anjishnu I was actually only asking for the input data to be randomized, with the original version of the test code untouched.
@haojin2 If I randomized the input data in the original test code, the losses would have different values on each run (SDML loss imposes a distribution over the relative distances of data points in a minibatch), so I would no longer be able to compare the output against precomputed loss values; the original unit test procedure cannot be reused. That's why I added a test that fits a toy model to some toy data instead. The current test runs in ~50 ms on my machine on CPU. Would love to hear your thoughts on how to improve on this.
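For concreteness, a rough sketch of the kind of loss-reduction test being described here (hypothetical code, assuming the loss ends up exposed as `gluon.loss.SDMLLoss`, which is what this PR adds): train a small network on randomly generated paired data and assert that the loss goes down, rather than comparing against precomputed values:

```python
import mxnet as mx
from mxnet import nd, gluon, autograd

def test_sdml_loss_reduction():
    N, dim, epochs = 32, 10, 20
    # random anchors plus slightly perturbed positives
    x1 = nd.random.uniform(-1, 1, shape=(N, dim))
    x2 = x1 + nd.random.uniform(-0.1, 0.1, shape=(N, dim))

    net = gluon.nn.Dense(dim)
    net.initialize(mx.init.Xavier())
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': 0.01})
    loss_fn = gluon.loss.SDMLLoss()

    initial = loss_fn(net(x1), net(x2)).mean().asscalar()
    for _ in range(epochs):
        with autograd.record():
            loss = loss_fn(net(x1), net(x2))
        loss.backward()
        trainer.step(N)
    # the trained embeddings should pull associated pairs closer
    assert loss_fn(net(x1), net(x2)).mean().asscalar() < initial
```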
@haojin2 Are there any serious concerns with the new unit test?
@anjishnu I'm only concerned about whether there's any flakiness in such a test. To verify, please try
@haojin2 Changing the hyperparameters (e.g. increasing the dimensionality, lowering N) does make it more robust (tried up to MXNET_TEST_COUNT=2000 runs), but there's always a chance of failure. Other options:
Which of these would be your preferred approach? I have tested them all up to MXNET_TEST_COUNT=2000.
@haojin2 Could you take a look at the latest version? The new unit test (which checks for loss reduction) passes MXNET_TEST_COUNT=2000.
@anjishnu It seems ok. I've re-triggered your failed website build; I think it's good for merge after that passes.
@haojin2 ok, thanks! sounds good! |
@anjishnu Merged, thx for the contribution. |
* Added loss function
* cleaning out loss function
* added loss fn
* added loss function and unit test
* fixed linter issues and updated unit test to train net and eval
* fixed documentation
* made unit tests more robust
distances = self._compute_distances(x1, x2)
log_probabilities = F.log_softmax(-distances, axis=1)
# multiply by the number of labels to obtain the correct loss (gluon kl_loss averages instead of summing)
return self.kl_loss(log_probabilities, labels.as_in_context(distances.context)) * batch_size
This doesn't work in the symbol API. 1.x will need to be fixed, while 2.0 will be switched to deferred compute mode.
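A hypothetical illustration of the reviewer's point: with NDArray inputs the batch size is a concrete Python integer, but a Symbol carries no concrete shape while the graph is being built, so shape-dependent constructions like the label matrix above cannot be expressed on the symbolic path:

```python
import mxnet as mx
from mxnet import nd, sym

# Imperative: shapes are concrete, so batch-size-dependent ops are fine.
x = nd.ones((4, 8))
labels = nd.eye(x.shape[0])

# Symbolic: a Variable has no concrete shape at graph-construction time;
# shapes are only resolved later, e.g. via infer_shape with bound input
# shapes, so nd.eye(batch_size)-style code can't run during graph building.
xs = sym.Variable('x')
arg_shapes, out_shapes, _ = xs.infer_shape(x=(4, 8))
print(out_shapes)  # [(4, 8)] -- known only after input shapes are supplied
```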
Description
Adds the SDML (Smoothed Deep Metric Learning) loss function to Gluon, along with a unit test that trains a small network and verifies that the loss decreases.