Skip to content

Commit

Permalink
Handle complex element type in torch.vtensor conversion (#175)
Browse files Browse the repository at this point in the history
Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
  • Loading branch information
sogartar authored Sep 27, 2024
1 parent a1da408 commit 3410f23
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion shark_turbine/dynamo/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# 1. Local name (int, float, vtensor)
# 2. Parameter block ("<...>"), including the delimitters
# 3. Inner parameter block (no delimitters)
DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$")
DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch\.([^<]+)(<(.*)>)?$")

# Decomposes a vtensor parameter block into a dimension list and dtype. Groups:
# 1. Dimension list
Expand Down
1 change: 1 addition & 0 deletions tests/dynamo/type_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def testValueTensors(self):
self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>")
self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor<?x?xf32>")
self._compareNative("!torch.vtensor<[],f32>", "tensor<f32>")
self._compareNative("!torch.vtensor<[],complex<f32>>", "tensor<complex<f32>>")

def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True):
with self.conv._context:
Expand Down

0 comments on commit 3410f23

Please sign in to comment.