diff --git a/shark_turbine/dynamo/type_conversion.py b/shark_turbine/dynamo/type_conversion.py index 8206e10f..e829bafc 100644 --- a/shark_turbine/dynamo/type_conversion.py +++ b/shark_turbine/dynamo/type_conversion.py @@ -32,7 +32,7 @@ # 1. Local name (int, float, vtensor) # 2. Parameter block ("<...>"), including the delimitters # 3. Inner parameter block (no delimitters) -DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$") +DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch\.([^<]+)(<(.*)>)?$") # Decomposes a vtensor parameter block into a dimension list and dtype. Groups: # 1. Dimension list diff --git a/tests/dynamo/type_conversion_test.py b/tests/dynamo/type_conversion_test.py index dfc3de25..617c5d05 100644 --- a/tests/dynamo/type_conversion_test.py +++ b/tests/dynamo/type_conversion_test.py @@ -32,6 +32,7 @@ def testValueTensors(self): self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>") self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor") self._compareNative("!torch.vtensor<[],f32>", "tensor") + self._compareNative("!torch.vtensor<[],complex>", "tensor>") def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True): with self.conv._context: