
How to keep the temperature positive in .set_temperature()? #31

eugene-yh opened this issue Apr 25, 2022 · 1 comment

eugene-yh commented Apr 25, 2022

In the paper, it is stated that the temperature T has to be a positive number. In the code, however, although the temperature is initialized to a positive value (i.e., self.temperature = nn.Parameter(torch.ones(1) * 1.5)), it seems to me that nothing in .set_temperature() ensures that we do not end up with a negative temperature.

Did I miss something? Or can it be mathematically proven that the gradient will never push the temperature to the negative side as long as it is initialized to a positive value? If not, should we initialize with something like self.temperature = nn.Parameter(torch.ones(1) * 1.5) ** 2 to ensure that self.temperature stays positive?
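
For reference, the situation being described looks roughly like the sketch below. This is a minimal illustration, not this repository's exact code; the names ModelWithTemperature and temperature_scale are assumptions based on the usual temperature-scaling wrapper.

import torch
from torch import nn

class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Initialized to a positive value, but nothing constrains the
        # parameter itself once optimization starts.
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def temperature_scale(self, logits):
        # The raw parameter is used directly, so the optimizer in
        # .set_temperature() is free to push it below zero.
        return logits / self.temperature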

@pdejorge

@eugene-yh Intuitively, I think it would be strange for the temperature to take negative values, since that would invert the predictions of the network (i.e., the most likely class would become the least likely).
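
A toy example (not repository code) makes this concrete:

import torch

logits = torch.tensor([2.0, 0.5, -1.0])
print(torch.softmax(logits / 1.5, dim=0))   # index 0 is the most likely class
print(torch.softmax(logits / -1.5, dim=0))  # ordering is flipped: index 2 becomes the most likely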

That being said, I did in fact encounter a negative temperature in one case. I found that applying torch.abs(self.temperature) inside the closure function worked well. For instance:

def closure():
    optimizer.zero_grad()
    scaled_logits = logits / torch.abs(self.temperature)  # Ensure temperature stays positive in optimization.

    if metric == 'ece':
        loss = ece_criterion(scaled_logits, labels)
    elif metric == 'nll':
        loss = nll_criterion(scaled_logits, labels)
    else:
        raise NotImplementedError()

    loss.backward()
    return loss
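
For context, such a closure is typically driven by an LBFGS step along these lines (a sketch; the learning rate and iteration count are illustrative, not taken from this repository):

optimizer = torch.optim.LBFGS([self.temperature], lr=0.01, max_iter=50)
optimizer.step(closure)

# Note: if torch.abs() is used inside the closure, the same torch.abs()
# should be applied wherever the fitted temperature is used afterwards,
# since the stored parameter itself may converge to a negative value.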
