diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e479dd097d09..08bf5d517c8b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2260,7 +2260,12 @@ def broadcast_tensors(self, inputs, input_types): tensor_list = inputs[0] import torch - res_shape = list(torch.broadcast_shapes(*[self.infer_shape(t) for t in tensor_list])) + infer_shape_value = [self.infer_shape(t) for t in tensor_list] + # "torch.broadcast_shapes" is available after PyTorch 1.8.0 + if hasattr(torch, "broadcast_shapes"): + res_shape = list(torch.broadcast_shapes(*infer_shape_value)) + else: + res_shape = list(torch.broadcast_tensors(*map(torch.empty, infer_shape_value))[0].shape) return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list] def Bool(self, inputs, input_types):