diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..11547a2312 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -121,7 +121,24 @@ void setup_input_tensors( // Shape tensor inputs are casted to int64 explicitly. // Refer to // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435 - auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64); + at::Tensor cloned_input; + + // Check if it's a scalar tensor (0-dimensional) + if (inputs[i].dim() == 0 && inputs[i].numel() == 1) { + // It's a scalar tensor, create a proper tensor from the scalar value + int64_t scalar_value = inputs[i].item(); + LOG_DEBUG("Input " << i << " is a scalar tensor with value: " << scalar_value); + cloned_input = torch::tensor({scalar_value}, torch::kInt64); + LOG_DEBUG("cloned_input dim: " << cloned_input.dim() << " ; numel: " << cloned_input.numel()); + } else { + // It's a regular tensor + LOG_DEBUG( + "Input " << i << " is a regular tensor" + << " inputs[i]: " << inputs[i]); + cloned_input = inputs[i].clone(); + } + auto input_cpu = cloned_input.contiguous().cpu().to(torch::kInt64); + std::vector inputs_cpu_vec( input_cpu.data_ptr(), input_cpu.data_ptr() + input_cpu.numel()); inputShapeTensorValues.emplace_back(inputs_cpu_vec); diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index fc76b20141..dda5929d4b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -430,7 +430,7 @@ def create_output_allocator(self) -> None: def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]: - shape_changed = self.validate_input_shapes(inputs) + shape_changed = self.validate_input_shapes(contiguous_inputs) ( need_cudagraphs_record, can_use_pre_allocated_outputs,