Skip to content

Commit

Permalink
[bugfix] [Relay] fix broadcast in PyTorch frontend (apache#14885)
Browse files Browse the repository at this point in the history
* fix broadcast_tensors

* Update pytorch.py

* Update test_forward.py

* Update test_forward.py
  • Loading branch information
jikechao authored and mei-ye committed Jun 1, 2023
1 parent e0bdd13 commit 46fc528
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2260,7 +2260,12 @@ def broadcast_tensors(self, inputs, input_types):
tensor_list = inputs[0]
import torch

res_shape = list(torch.broadcast_shapes(*[self.infer_shape(t) for t in tensor_list]))
infer_shape_value = [self.infer_shape(t) for t in tensor_list]
# "torch.broadcast_shapes" is available after PyTorch 1.8.0
if hasattr(torch, "broadcast_shapes"):
res_shape = list(torch.broadcast_shapes(*infer_shape_value))
else:
res_shape = list(torch.broadcast_tensors(*map(torch.empty, infer_shape_value))[0].shape)
return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]

def Bool(self, inputs, input_types):
Expand Down

0 comments on commit 46fc528

Please sign in to comment.