diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 8cffcb1ea935..5328e8730cc3 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -51,6 +51,7 @@
     Attribute,
     Block,
     Context,
+    DenseElementsAttr,
     DenseResourceElementsAttr,
     FloatAttr,
     BF16Type,
@@ -207,28 +208,28 @@
 }
 
 
-def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str:
-  """Returns sparse tensor encoding for the given sparse layout as string.
-
-  The method currently just supports 2-dim sparse formats. This should be
-  generalized to the torch.sparse encodings for prefix dense batch dimensions
-  and suffix dense subtensor dimensions. Since MLIR supports a superset of what
-  is currently implememented in torch.sparse, this should not a be problem.
-  """
-
-  # TODO: any rank
-  if len(shape) != 2:
-    raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
-
-  if sparse_layout is torch.sparse_coo:
-    return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>'
-  if sparse_layout is torch.sparse_csr:
-    return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>'
-  if sparse_layout is torch.sparse_csc:
-    return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>'
-  # TODO: block format (derive block size!)
-  raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
+def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str:
+    """Returns sparse tensor encoding for the given sparse layout as string.
+
+    The method currently just supports 2-dim sparse formats. This should be
+    generalized to the torch.sparse encodings for prefix dense batch dimensions
+    and suffix dense subtensor dimensions. Since MLIR supports a superset of what
+    is currently implemented in torch.sparse, this should not be a problem.
+    """
+
+    # TODO: any rank
+    if len(shape) != 2:
+        raise RuntimeError(f"Unsupported sparse rank {len(shape)}")
+
+    if sparse_layout is torch.sparse_coo:
+        return "#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>"
+    if sparse_layout is torch.sparse_csr:
+        return "#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>"
+    if sparse_layout is torch.sparse_csc:
+        return "#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>"
+    # TODO: block format (derive block size!)
+    raise RuntimeError(f"Unsupported sparse layout {sparse_layout}")
 
 
 def is_symbolic(obj: Any) -> bool:
@@ -477,15 +478,20 @@ def format_asm_shape(self, shape: torch.Size) -> str:
     """Return IrType for !torch.vtensor with the given shape and dtype"""
 
-    def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None):
+    def get_vtensor_type(
+        self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None
+    ):
         shape_asm = self.format_asm_shape(shape)
         mlir_dtype = str(self.dtype_to_type(dtype))
         if sparse_layout is not None:
-          sparsity = sparsity_encoding(shape, sparse_layout)
-          return IrType.parse(
-              f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c)
+            sparsity = sparsity_encoding(shape, sparse_layout)
+            return IrType.parse(
+                f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>",
+                context=self._c,
+            )
         return IrType.parse(
-            f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c)
+            f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
+        )
 
     def node_val_to_type(self, node: torch_fx.Node) -> IrType:
         try:
@@ -521,7 +527,9 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType:
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )
 
-    def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType:
+    def tensor_metadata_to_type(
+        self, tm: TensorMetadata, sparse_layout: torch.layout = None
+    ) -> IrType:
         tm_shape = tuple(
             item.node if is_symbolic(item) else item for item in list(tm.shape)
         )
@@ -686,9 +694,11 @@ def _import_symbolic_torch_op(
         # operations on symbolic arguments as regular python expressions rather than as torch ops
         if is_builtin_function_or_method(target):
             arg_types = [
-                arg.meta["val"].node.pytype
-                if isinstance(arg, torch.fx.Node)
-                else type(arg)
+                (
+                    arg.meta["val"].node.pytype
+                    if isinstance(arg, torch.fx.Node)
+                    else type(arg)
+                )
                 for arg in node.args
             ]
             is_int = [item == int for item in arg_types]
@@ -1018,7 +1028,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
         tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
         return tensor_type
     except KeyError:
-        raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
+        raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
 
 
 def _make_vtensor_literal_op(
@@ -1038,15 +1048,28 @@ def _make_vtensor_literal_op(
     # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
     # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
     np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
-    bytes_view = memoryview(np_tensor)
-    tensor_type = create_mlir_tensor_type(tensor)
-    shape_desc = "_".join([str(d) for d in tensor.shape])
-    blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
-    elements_attr = DenseResourceElementsAttr.get_from_buffer(
-        bytes_view,
-        blob_name,
-        tensor_type,
-    )
+    # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
+    # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
+    # 0d tensors.
+    if np_tensor.size == 1:
+        try:
+            dtype = tensor.dtype
+            element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+        except KeyError:
+            raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+        elements_attr = DenseElementsAttr.get(
+            type=element_type, array=np_tensor, shape=np_tensor.shape
+        )
+    else:
+        bytes_view = memoryview(np_tensor)
+        tensor_type = create_mlir_tensor_type(tensor)
+        shape_desc = "_".join([str(d) for d in tensor.shape])
+        blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+        elements_attr = DenseResourceElementsAttr.get_from_buffer(
+            bytes_view,
+            blob_name,
+            tensor_type,
+        )
     mapping.value = elements_attr
 else:
     elements_attr = mapping.value
@@ -1105,8 +1128,7 @@ def lookup(self, t: type) -> Any:
 
 # Opaque value to indicate something is empty. Used in cases where 'None'
 # may have a different meaning.
-class EmptyType:
-    ...
+class EmptyType: ...
 
 
 Empty = EmptyType()
diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index 62d3b1203e03..acd2a559fa52 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -47,7 +47,7 @@ def run(f):
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
 # CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
 # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
-# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_1_torch.float32> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
+# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
 # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
 # CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]]
 # CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]]
@@ -58,7 +58,6 @@ def run(f):
 # CHECK: dialect_resources:
 # CHECK-DAG: torch_tensor_1_4_torch.float32
 # CHECK-DAG: torch_tensor_3_1_torch.float32
-# CHECK-DAG: torch_tensor_1_1_torch.float32
 def test_import_frozen_exported_program():
     # Tests the basic structural premises of import_frozen_exported_program,
     # namely that free tensors (buffers) and parameters are treated as
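
For reference, a quick sanity check of the encodings produced by `sparsity_encoding` above (a sketch, assuming torch-mlir is installed so the importer module is importable; the CSR tensor here is made up for illustration):

```python
import torch

from torch_mlir.extras.fx_importer import sparsity_encoding

# A 2-dim CSR tensor maps to the (i:dense, j:compressed) level format.
csr = torch.eye(4).to_sparse_csr()
print(sparsity_encoding(csr.shape, csr.layout))
# -> #sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>

# Ranks other than 2 are rejected until the TODOs above are addressed.
try:
    sparsity_encoding(torch.Size([2, 3, 4]), torch.sparse_coo)
except RuntimeError as e:
    print(e)  # Unsupported sparse rank 3
```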
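And a minimal standalone sketch of the splat-versus-resource choice `_make_vtensor_literal_op` now makes, written against the upstream MLIR Python bindings (assuming they are installed as `mlir`; torch-mlir re-exports the same classes from its own `ir` module, and the values and blob name below are illustrative):

```python
import numpy as np
from mlir.ir import (
    Context,
    DenseElementsAttr,
    DenseResourceElementsAttr,
    F32Type,
    RankedTensorType,
)

with Context():
    f32 = F32Type.get()

    # One-element constant: DenseElementsAttr yields a splat attribute,
    # which prints inline and is cheap for folders to inspect.
    one = np.array([[0.5]], dtype=np.float32)
    splat = DenseElementsAttr.get(one, type=f32, shape=one.shape)
    print(splat)  # dense<5.000000e-01> : tensor<1x1xf32>

    # Larger constant: keep the payload out of line as a named resource blob.
    many = np.arange(12, dtype=np.float32).reshape(3, 4)
    tensor_type = RankedTensorType.get(list(many.shape), f32)
    resource = DenseResourceElementsAttr.get_from_buffer(
        memoryview(many), "torch_tensor_3_4_torch.float32", tensor_type
    )
    print(resource)  # dense_resource<torch_tensor_3_4_torch.float32> : tensor<3x4xf32>
```

This mirrors the importer's new branch: `np_tensor.size == 1` (including 0d tensors) goes through `DenseElementsAttr.get`, and everything else through `DenseResourceElementsAttr.get_from_buffer`.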