Skip to content

Commit

Permalink
Handle complex element type in torch.vtensor conversion (iree-org#175)
Browse files Browse the repository at this point in the history
Signed-off-by: Boian Petkantchin <boian.petkantchin@amd.com>
Signed-off-by: Ian <ian.nordeng@amd.com>
  • Loading branch information
sogartar authored and IanNod committed Sep 30, 2024
1 parent f64be82 commit bdc9c4d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion shark_turbine/dynamo/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# 1. Local name (int, float, vtensor)
# 2. Parameter block ("<...>"), including the delimitters
# 3. Inner parameter block (no delimitters)
DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch.([^<]+)(<([^>]*)>)?$")
DECOMPOSE_TORCH_TYPE_PATTERN = re.compile(r"^!torch\.([^<]+)(<(.*)>)?$")

# Decomposes a vtensor parameter block into a dimension list and dtype. Groups:
# 1. Dimension list
Expand Down
1 change: 1 addition & 0 deletions tests/dynamo/type_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def testValueTensors(self):
self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>")
self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor<?x?xf32>")
self._compareNative("!torch.vtensor<[],f32>", "tensor<f32>")
self._compareNative("!torch.vtensor<[],complex<f32>>", "tensor<complex<f32>>")

def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True):
with self.conv._context:
Expand Down

0 comments on commit bdc9c4d

Please sign in to comment.