Skip to content

Commit

Permalink
fix dnl_head export onnx inference difference type Cast error (#1161)
Browse files Browse the repository at this point in the history
* fix export onnx inference difference type Cast error

* fix export onnx inference difference type Cast error.

* use yapf format

* use same device type with pairwise_weight
  • Loading branch information
sshuair committed Mar 1, 2022
1 parent 69d5cc5 commit 9947a39
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions mmseg/models/decode_heads/dnl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ def embedded_gaussian(self, theta_x, phi_x):
pairwise_weight = torch.matmul(theta_x, phi_x)
if self.use_scale:
# theta_x.shape[-1] is `self.inter_channels`
pairwise_weight /= theta_x.shape[-1]**0.5
pairwise_weight /= self.temperature
pairwise_weight /= torch.tensor(
theta_x.shape[-1],
dtype=torch.float,
device=pairwise_weight.device)**torch.tensor(
0.5, device=pairwise_weight.device)
pairwise_weight /= torch.tensor(
self.temperature, device=pairwise_weight.device)
pairwise_weight = pairwise_weight.softmax(dim=-1)
return pairwise_weight

Expand Down

0 comments on commit 9947a39

Please sign in to comment.