
Commit

Merge remote-tracking branch 'origin/main' into dliddell-aten-index-select-2
Dave Liddell committed Feb 6, 2024
2 parents 510b71f + faf7d4a commit b5b1133
Showing 2 changed files with 62 additions and 41 deletions.
100 changes: 61 additions & 39 deletions python/torch_mlir/extras/fx_importer.py
@@ -51,6 +51,7 @@
     Attribute,
     Block,
     Context,
+    DenseElementsAttr,
     DenseResourceElementsAttr,
     FloatAttr,
     BF16Type,
@@ -207,28 +208,28 @@
 }


-def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str:
-  """Returns sparse tensor encoding for the given sparse layout as string.
-
-  The method currently just supports 2-dim sparse formats. This should be
-  generalized to the torch.sparse encodings for prefix dense batch dimensions
-  and suffix dense subtensor dimensions. Since MLIR supports a superset of what
-  is currently implememented in torch.sparse, this should not a be problem.
-  """
-
-  # TODO: any rank
-  if len(shape) != 2:
-    raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
-
-  if sparse_layout is torch.sparse_coo:
-    return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>'
-  if sparse_layout is torch.sparse_csr:
-    return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>'
-  if sparse_layout is torch.sparse_csc:
-    return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>'
-  # TODO: block format (derive block size!)
-
-  raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
+def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str:
+    """Returns sparse tensor encoding for the given sparse layout as string.
+
+    The method currently just supports 2-dim sparse formats. This should be
+    generalized to the torch.sparse encodings for prefix dense batch dimensions
+    and suffix dense subtensor dimensions. Since MLIR supports a superset of what
+    is currently implememented in torch.sparse, this should not a be problem.
+    """
+
+    # TODO: any rank
+    if len(shape) != 2:
+        raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
+
+    if sparse_layout is torch.sparse_coo:
+        return "#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>"
+    if sparse_layout is torch.sparse_csr:
+        return "#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>"
+    if sparse_layout is torch.sparse_csc:
+        return "#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>"
+    # TODO: block format (derive block size!)
+
+    raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")


 def is_symbolic(obj: Any) -> bool:
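For context, a minimal usage sketch of the sparsity_encoding helper reformatted above, assuming it can be imported from torch_mlir.extras.fx_importer as it exists at this commit; it maps a 2-D sparse tensor's layout to the matching #sparse_tensor.encoding string:

import torch

from torch_mlir.extras.fx_importer import sparsity_encoding

# A 4x4 identity stored as CSR; its layout is torch.sparse_csr.
csr = torch.eye(4).to_sparse_csr()
print(sparsity_encoding(csr.shape, csr.layout))
# -> #sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>

# COO instead maps to the compressed(nonunique)/singleton level types.
coo = torch.eye(4).to_sparse_coo()
print(sparsity_encoding(coo.shape, coo.layout))
# -> #sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>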
@@ -477,15 +478,20 @@ def format_asm_shape(self, shape: torch.Size) -> str:

     """Return IrType for !torch.vtensor with the given shape and dtype"""

-    def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None):
+    def get_vtensor_type(
+        self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None
+    ):
         shape_asm = self.format_asm_shape(shape)
         mlir_dtype = str(self.dtype_to_type(dtype))
         if sparse_layout is not None:
-          sparsity = sparsity_encoding(shape, sparse_layout)
-          return IrType.parse(
-            f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c)
+            sparsity = sparsity_encoding(shape, sparse_layout)
+            return IrType.parse(
+                f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>",
+                context=self._c,
+            )
         return IrType.parse(
-            f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c)
+            f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
+        )

     def node_val_to_type(self, node: torch_fx.Node) -> IrType:
         try:
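As an illustrative aside (not taken from the diff), the type string that get_vtensor_type assembles before handing it to IrType.parse looks like the following for a hypothetical 4x4 f32 tensor; the sparsity suffix only appears when a sparse layout is passed:

# Sketch of the string get_vtensor_type builds; the values below are made up.
shape_asm = "4,4"
mlir_dtype = "f32"
sparsity = "#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>"
print(f"!torch.vtensor<[{shape_asm}],{mlir_dtype}>")             # dense case
print(f"!torch.vtensor<[{shape_asm}],{mlir_dtype},{sparsity}>")  # sparse case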
@@ -521,7 +527,9 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType:
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )

-    def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType:
+    def tensor_metadata_to_type(
+        self, tm: TensorMetadata, sparse_layout: torch.layout = None
+    ) -> IrType:
         tm_shape = tuple(
             item.node if is_symbolic(item) else item for item in list(tm.shape)
         )
@@ -686,9 +694,11 @@ def _import_symbolic_torch_op(
         # operations on symbolic arguments as regular python expressions rather than as torch ops
         if is_builtin_function_or_method(target):
             arg_types = [
-                arg.meta["val"].node.pytype
-                if isinstance(arg, torch.fx.Node)
-                else type(arg)
+                (
+                    arg.meta["val"].node.pytype
+                    if isinstance(arg, torch.fx.Node)
+                    else type(arg)
+                )
                 for arg in node.args
             ]
             is_int = [item == int for item in arg_types]
@@ -1018,7 +1028,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
         tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
         return tensor_type
     except KeyError:
-        raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
+        raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")


 def _make_vtensor_literal_op(
@@ -1038,15 +1048,28 @@ def _make_vtensor_literal_op(
         # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
         # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
         np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
-        bytes_view = memoryview(np_tensor)
-        tensor_type = create_mlir_tensor_type(tensor)
-        shape_desc = "_".join([str(d) for d in tensor.shape])
-        blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
-        elements_attr = DenseResourceElementsAttr.get_from_buffer(
-            bytes_view,
-            blob_name,
-            tensor_type,
-        )
+        # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
+        # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
+        # 0d tensors.
+        if np_tensor.size == 1:
+            try:
+                dtype = tensor.dtype
+                element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+            except KeyError:
+                raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+            elements_attr = DenseElementsAttr.get(
+                type=element_type, array=np_tensor, shape=np_tensor.shape
+            )
+        else:
+            bytes_view = memoryview(np_tensor)
+            tensor_type = create_mlir_tensor_type(tensor)
+            shape_desc = "_".join([str(d) for d in tensor.shape])
+            blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+            elements_attr = DenseResourceElementsAttr.get_from_buffer(
+                bytes_view,
+                blob_name,
+                tensor_type,
+            )
         mapping.value = elements_attr
     else:
         elements_attr = mapping.value
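A standalone sketch of the new single-element branch, assuming the upstream MLIR Python bindings are importable as mlir.ir (the importer reaches the same classes through torch-mlir's bundled ir module): a one-element numpy array becomes a splat DenseElementsAttr, while everything else keeps going through DenseResourceElementsAttr.get_from_buffer as before.

import numpy as np
from mlir.ir import Context, DenseElementsAttr, F32Type  # assumption: upstream bindings

with Context():
    np_tensor = np.array([[0.5]], dtype=np.float32)
    # Mirrors the np_tensor.size == 1 branch above: splats are representable
    # by DenseElementsAttr but not by DenseResourceElementsAttr.
    attr = DenseElementsAttr.get(
        type=F32Type.get(), array=np_tensor, shape=np_tensor.shape
    )
    print(attr)  # e.g. dense<5.000000e-01> : tensor<1x1xf32>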
@@ -1105,8 +1128,7 @@ def lookup(self, t: type) -> Any:

 # Opaque value to indicate something is empty. Used in cases where 'None'
 # may have a different meaning.
-class EmptyType:
-    ...
+class EmptyType: ...


 Empty = EmptyType()
3 changes: 1 addition & 2 deletions test/python/fx_importer/basic_test.py
@@ -47,7 +47,7 @@ def run(f):
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
 # CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
 # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
-# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_1_torch.float32> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
 # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
 # CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]]
 # CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]]
@@ -58,7 +58,6 @@ def run(f):
 # CHECK: dialect_resources:
 # CHECK-DAG: torch_tensor_1_4_torch.float32
 # CHECK-DAG: torch_tensor_3_1_torch.float32
-# CHECK-DAG: torch_tensor_1_1_torch.float32
 def test_import_frozen_exported_program():
     # Tests the basic structural premises of import_frozen_exported_program,
     # namely that free tensors (buffers) and parameters are treated as
