Runtime error when using ArcFace without a miner #359

@gkouros

I tried to use ArcFace loss without a miner (passing an empty dictionary) in the TwoStreamMetricLoss.ipynb example notebook on Colab, but training fails with the following runtime error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-25-565c85f60968> in <module>()
      1 # In the embeddings plots, the small dots represent the 1st stream, and the larger dots represent the 2nd stream
----> 2 trainer.train(num_epochs=num_epochs)

7 frames
/usr/local/lib/python3.7/dist-packages/pytorch_metric_learning/trainers/base_trainer.py in train(self, start_epoch, num_epochs)
     85             pbar = tqdm.tqdm(range(self.iterations_per_epoch))
     86             for self.iteration in pbar:
---> 87                 self.forward_and_backward()
     88                 self.end_of_iteration_hook(self)
     89                 pbar.set_description("total_loss=%.5f" % self.losses["total_loss"])

/usr/local/lib/python3.7/dist-packages/pytorch_metric_learning/trainers/base_trainer.py in forward_and_backward(self)
    113         self.zero_grad()
    114         self.update_loss_weights()
--> 115         self.calculate_loss(self.get_batch())
    116         self.loss_tracker.update(self.loss_weights)
    117         self.backward()

/usr/local/lib/python3.7/dist-packages/pytorch_metric_learning/trainers/twostream_metric_loss.py in calculate_loss(self, curr_batch)
     16         indices_tuple = self.maybe_mine_embeddings(embeddings, labels)
     17         self.losses["metric_loss"] = self.maybe_get_metric_loss(
---> 18             embeddings, labels, indices_tuple
     19         )
     20 

/usr/local/lib/python3.7/dist-packages/pytorch_metric_learning/trainers/twostream_metric_loss.py in maybe_get_metric_loss(self, embeddings, labels, indices_tuple)
     37             all_embeddings = torch.cat(embeddings, dim=0)
     38             return self.loss_funcs["metric_loss"](
---> 39                 all_embeddings, all_labels, indices_tuple
     40             )
     41         return 0

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.7/dist-packages/pytorch_metric_learning/losses/base_metric_loss_function.py in forward(self, embeddings, labels, indices_tuple)
     32         c_f.check_shapes(embeddings, labels)
     33         labels = c_f.to_device(labels, embeddings)
---> 34         loss_dict = self.compute_loss(embeddings, labels, indices_tuple)
     35         self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
     36         return self.reducer(loss_dict, embeddings, labels)

/usr/local/lib/python3.7/dist-packages/pytorch_metric_learning/losses/large_margin_softmax_loss.py in compute_loss(self, embeddings, labels, indices_tuple)
    102         dtype, device = embeddings.dtype, embeddings.device
    103         self.cast_types(dtype, device)
--> 104         miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=dtype)
    105         mask = self.get_target_mask(embeddings, labels)
    106         cosine = self.get_cosine(embeddings)

/usr/local/lib/python3.7/dist-packages/pytorch_metric_learning/utils/loss_and_miner_utils.py in convert_to_weights(indices_tuple, labels, dtype)
    208     indices, counts = torch.unique(torch.cat(indices_tuple, dim=0), return_counts=True)
    209     counts = c_f.to_dtype(counts, dtype=dtype) / torch.sum(counts)
--> 210     weights[indices] = counts / torch.max(counts)
    211     return weights

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

It seems the labels are on the GPU while indices_tuple is on the CPU. I'm not sure whether this is a bug or whether I missed something.
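To illustrate what I suspect is happening, here is a minimal sketch of a device-safe version of the weighting step from the traceback. It is not the library's code: the function name `convert_to_weights_on_device` is hypothetical, and the zero initialization and the "no tuples means all weights are 1" branch are my assumptions. Only the three lines matching 208-210 of the traceback are taken from the source. The idea is to move every index tensor onto the device of `labels` before concatenating and indexing.

```python
import torch


def convert_to_weights_on_device(indices_tuple, labels, dtype):
    """Hypothetical device-safe variant of the step shown in the traceback."""
    # Assumption: one weight per batch element, created on the labels' device.
    weights = torch.zeros(labels.shape[0], dtype=dtype, device=labels.device)

    # Assumption: with no mined tuples, every element gets full weight.
    if indices_tuple is None or all(len(x) == 0 for x in indices_tuple):
        return weights + 1

    # Move each index tensor to the labels' device so all operands agree.
    indices_tuple = tuple(x.to(labels.device) for x in indices_tuple)

    # Count how often each batch index appears across the tuple.
    indices, counts = torch.unique(
        torch.cat(indices_tuple, dim=0), return_counts=True
    )
    counts = counts.to(dtype) / torch.sum(counts)

    # Normalize so the most frequent index has weight 1.
    weights[indices] = counts / torch.max(counts)
    return weights


labels = torch.tensor([0, 0, 1, 1])
triplets = (torch.tensor([0, 1]), torch.tensor([1, 0]), torch.tensor([2, 3]))
weights = convert_to_weights_on_device(triplets, labels, dtype=torch.float32)
```

On a machine with a GPU, the same call with `labels` on `cuda:0` and the index tensors left on the CPU should no longer mix devices, since the indices are moved first. I have not verified this against the trainer itself.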

Any help is appreciated. :)
