From 3410f23dfacf96eda1cb4e7e1f6f6139cca29ada Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 27 Sep 2024 13:42:13 -0400 Subject: [PATCH] Handle complex element type in torch.vtensor conversion (#175) Signed-off-by: Boian Petkantchin --- shark_turbine/dynamo/type_conversion.py | 2 +- tests/dynamo/type_conversion_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/shark_turbine/dynamo/type_conversion.py b/shark_turbine/dynamo/type_conversion.py index 8206e10f..e829bafc 100644 --- a/shark_turbine/dynamo/type_conversion.py +++ b/shark_turbine/dynamo/type_conversion.py @@ -32,7 +32,7 @@ # 1. Local name (int, float, vtensor) # 2. Parameter block ("<...>"), including the delimitters # 3. Inner parameter block (no delimitters) -DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$") +DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch\.([^<]+)(<(.*)>)?$") # Decomposes a vtensor parameter block into a dimension list and dtype. Groups: # 1. Dimension list diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index dfc3de25..617c5d05 100644 --- a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -32,6 +32,7 @@ def testValueTensors(self): self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>") self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor") self._compareNative("!torch.vtensor<[],f32>", "tensor") + self._compareNative("!torch.vtensor<[],complex>", "tensor>") def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True): with self.conv._context: