diff --git a/python/tvm/relax/relax_to_pyfunc_converter.py b/python/tvm/relax/relax_to_pyfunc_converter.py
new file mode 100644
index 000000000000..3de27d78c863
--- /dev/null
+++ b/python/tvm/relax/relax_to_pyfunc_converter.py
@@ -0,0 +1,1104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax to Python Function Converter.
+
+This module provides functionality to convert Relax functions to Python functions
+that can be executed directly in a Python/PyTorch environment.
+"""
+
+from typing import Any, Dict, List, Union
+
+import torch
+import torch.nn.functional as F
+
+import tvm
+from tvm import relax
+from tvm.ir import IRModule, Op
+
+
+class RelaxToPyFuncConverter:
+    """Converter that works with an IRModule to convert Relax functions to Python functions.
+
+    This converter transforms Relax functions into Python functions that can be executed
+    directly in a Python/PyTorch environment. The conversion maps Relax operators to
+    corresponding PyTorch APIs and handles special cases such as call_tir and call_dps_packed.
+    """
+
+    def __init__(self, ir_module: IRModule):
+        """Initialize the converter with an IRModule.
+
+        Args:
+            ir_module: The IRModule containing Relax functions to convert
+        """
+        self.ir_module = ir_module
+        self.operator_map = self._get_op_map()
+        # Cache for RelaxExpressionConverter instances to avoid recreating them
+        self._converter_cache = {}
+        # Cache for operator mappings to avoid repeated lookups
+        self._op_cache = {}
+
+    def convert(self, relax_function_names: Union[str, List[str]]) -> IRModule:
+        """Convert the specified Relax functions to Python functions.
+ + Args: + relax_function_names: Name(s) of Relax functions to convert + + Returns: + Updated IRModule with converted Python functions stored in pyfuncs + + Example: + >>> converter = RelaxToPyFuncConverter(ir_mod) + >>> # Convert a single function + >>> converted_ir_mod = converter.convert("my_relax_func") + >>> # Convert multiple functions + >>> converted_ir_mod = converter.convert(["func1", "func2"]) + """ + if isinstance(relax_function_names, str): + relax_function_names = [relax_function_names] + + # Create a copy of the current IRModule + new_ir_mod = self.ir_module.clone() + + # Initialize pyfuncs if not exists + if not hasattr(new_ir_mod, "pyfuncs"): + new_ir_mod.pyfuncs = {} + + # Get Relax function names from IRModule + relax_func_names = [] + for global_var, func in self.ir_module.functions_items(): + if isinstance(func, relax.Function): + relax_func_names.append(global_var.name_hint) + + # Convert each Relax function + for func_name in relax_function_names: + if func_name not in relax_func_names: + raise ValueError(f"Relax function '{func_name}' not found in IRModule") + + # Get the Relax function + relax_func = None + for global_var, func in self.ir_module.functions_items(): + if global_var.name_hint == func_name and isinstance(func, relax.Function): + relax_func = func + break + + if relax_func is None: + raise ValueError(f"Could not find Relax function '{func_name}'") + + # Convert to Python function + py_func = self._convert_relax_func_to_python(relax_func, func_name) + + # Store in pyfuncs + new_ir_mod.pyfuncs[func_name] = py_func + + return new_ir_mod + + def _convert_relax_func_to_python(self, relax_func: relax.Function, func_name: str) -> callable: + """Convert a single Relax function to a Python function with caching.""" + # Get function parameters + params = relax_func.params + + # Create the Python function + def converted_function(*args, **_kwargs): + """Converted Python function from Relax function.""" + # Handle arguments + if len(args) != len(params): + raise ValueError(f"Expected {len(params)} arguments, got {len(args)}") + + # Use cached converter or create new one + if func_name not in self._converter_cache: + self._converter_cache[func_name] = RelaxExpressionConverter( + self.operator_map, self.ir_module, self._op_cache + ) + + # Execute the converted function body + converter = self._converter_cache[func_name] + converter.current_params = params + return converter.convert_expr(relax_func.body, args) + + # Set function metadata + converted_function.__name__ = func_name + converted_function.__doc__ = f"Converted Python function from Relax function: {func_name}" + + return converted_function + + @staticmethod + def _get_op_map() -> Dict[str, str]: + """Get the mapping from Relax operators to PyTorch operators.""" + return { + # Binary operations + "relax.add": "torch.add", + "relax.subtract": "torch.sub", + "relax.multiply": "torch.mul", + "relax.divide": "torch.div", + "relax.power": "torch.pow", + "relax.maximum": "torch.maximum", + "relax.minimum": "torch.minimum", + "relax.floor_divide": "torch.floor_divide", + "relax.mod": "torch.fmod", + "relax.floor_mod": "torch.remainder", + "relax.log_add_exp": "torch.logaddexp", + # Bitwise operations + "relax.bitwise_and": "torch.bitwise_and", + "relax.bitwise_or": "torch.bitwise_or", + "relax.bitwise_xor": "torch.bitwise_xor", + "relax.left_shift": "torch.left_shift", + "relax.right_shift": "torch.right_shift", + # Unary operations + "relax.abs": "torch.abs", + "relax.negative": "torch.neg", + "relax.exp": 
"torch.exp", + "relax.log": "torch.log", + "relax.sqrt": "torch.sqrt", + "relax.rsqrt": "torch.rsqrt", + "relax.sin": "torch.sin", + "relax.cos": "torch.cos", + "relax.tanh": "torch.tanh", + "relax.sigmoid": "torch.sigmoid", + "relax.square": "torch.square", + "relax.sign": "torch.sign", + "relax.floor": "torch.floor", + "relax.ceil": "torch.ceil", + "relax.round": "torch.round", + "relax.trunc": "torch.trunc", + "relax.clip": "torch.clamp", + "relax.bitwise_not": "torch.bitwise_not", + # Trigonometric functions + "relax.acos": "torch.acos", + "relax.asin": "torch.asin", + "relax.atan": "torch.atan", + "relax.cosh": "torch.cosh", + "relax.sinh": "torch.sinh", + "relax.tan": "torch.tan", + "relax.acosh": "torch.acosh", + "relax.asinh": "torch.asinh", + "relax.atanh": "torch.atanh", + # Special functions + "relax.erf": "torch.erf", + "relax.isfinite": "torch.isfinite", + "relax.isinf": "torch.isinf", + "relax.isnan": "torch.isnan", + # Neural network operations + "relax.nn.relu": "F.relu", + "relax.nn.relu6": "F.relu6", + "relax.nn.gelu": "F.gelu", + "relax.nn.gelu_tanh": "F.gelu", + "relax.nn.softmax": "F.softmax", + "relax.nn.log_softmax": "F.log_softmax", + "relax.nn.dropout": "F.dropout", + "relax.nn.batch_norm": "F.batch_norm", + "relax.nn.layer_norm": "F.layer_norm", + "relax.nn.group_norm": "F.group_norm", + "relax.nn.instance_norm": "F.instance_norm", + "relax.nn.rms_norm": "F.layer_norm", # Approximate mapping + "relax.nn.linear": "F.linear", + "relax.nn.conv1d": "F.conv1d", + "relax.nn.conv2d": "F.conv2d", + "relax.nn.conv3d": "F.conv3d", + "relax.nn.conv1d_transpose": "F.conv_transpose1d", + "relax.nn.conv2d_transpose": "F.conv_transpose2d", + "relax.nn.conv3d_transpose": "F.conv_transpose3d", + "relax.nn.max_pool1d": "F.max_pool1d", + "relax.nn.max_pool2d": "F.max_pool2d", + "relax.nn.max_pool3d": "F.max_pool3d", + "relax.nn.avg_pool1d": "F.avg_pool1d", + "relax.nn.avg_pool2d": "F.avg_pool2d", + "relax.nn.avg_pool3d": "F.avg_pool3d", + "relax.nn.adaptive_avg_pool1d": "F.adaptive_avg_pool1d", + "relax.nn.adaptive_avg_pool2d": "F.adaptive_avg_pool2d", + "relax.nn.adaptive_avg_pool3d": "F.adaptive_avg_pool3d", + "relax.nn.leakyrelu": "F.leaky_relu", + "relax.nn.prelu": "F.prelu", + "relax.nn.selu": "F.selu", + "relax.nn.silu": "F.silu", + "relax.nn.softplus": "F.softplus", + "relax.nn.attention": "F.scaled_dot_product_attention", # Approximate mapping + "relax.nn.cross_entropy_with_logits": "F.cross_entropy", + "relax.nn.nll_loss": "F.nll_loss", + "relax.nn.pad": "F.pad", + "relax.nn.pixel_shuffle": "F.pixel_shuffle", + # Tensor operations + "relax.matmul": "torch.matmul", + "relax.linear": "F.linear", + "relax.einsum": "torch.einsum", + "relax.outer": "torch.outer", + "relax.reshape": "reshape", # Special handling needed + "relax.permute_dims": "permute_dims", # Special handling needed + "relax.expand_dims": "expand_dims", # Special handling needed + "relax.squeeze": "squeeze", # Special handling needed + "relax.concat": "concat", # Special handling needed + "relax.split": "split", # Special handling needed + "relax.stack": "stack", # Special handling needed + "relax.tile": "tile", # Special handling needed + "relax.repeat": "repeat", # Special handling needed + "relax.broadcast_to": "torch.broadcast_to", + "relax.flatten": "torch.flatten", + "relax.flip": "flip", # Special handling needed + "relax.roll": "torch.roll", + "relax.rot90": "torch.rot90", + "relax.meshgrid": "torch.meshgrid", + "relax.one_hot": "F.one_hot", + "relax.layout_transform": "torch.permute", # Approximate 
mapping + # Indexing operations + "relax.take": "take", # Special handling needed + "relax.gather_elements": "torch.gather", + "relax.gather_nd": "torch.gather", + "relax.scatter_elements": "torch.scatter", + "relax.scatter_nd": "torch.scatter", + "relax.index_put": "torch.index_put", + "relax.index_tensor": "torch.index_select", + "relax.strided_slice": "torch.slice", + "relax.dynamic_strided_slice": "torch.slice", + "relax.slice_scatter": "torch.scatter", + # Reduction operations + "relax.sum": "sum", # Special handling needed + "relax.mean": "mean", # Special handling needed + "relax.max": "max", # Special handling needed + "relax.min": "min", # Special handling needed + "relax.prod": "torch.prod", + "relax.std": "std", # Special handling needed + "relax.variance": "variance", # Special handling needed + "relax.cumsum": "torch.cumsum", + "relax.cumprod": "torch.cumprod", + "relax.argmax": "torch.argmax", + "relax.argmin": "torch.argmin", + # Comparison operations + "relax.equal": "torch.eq", + "relax.not_equal": "torch.ne", + "relax.greater": "torch.gt", + "relax.greater_equal": "torch.ge", + "relax.less": "torch.lt", + "relax.less_equal": "torch.le", + # Logical operations + "relax.logical_and": "torch.logical_and", + "relax.logical_or": "torch.logical_or", + "relax.logical_not": "torch.logical_not", + "relax.logical_xor": "torch.logical_xor", + # Creation operations + "relax.zeros": "torch.zeros", + "relax.ones": "torch.ones", + "relax.full": "torch.full", + "relax.full_like": "torch.full_like", + "relax.zeros_like": "torch.zeros_like", + "relax.ones_like": "torch.ones_like", + "relax.arange": "torch.arange", + "relax.eye": "torch.eye", + "relax.eye_like": "torch.eye", + "relax.tril": "torch.tril", + "relax.triu": "torch.triu", + "relax.hamming_window": "torch.hamming_window", + # Search operations + "relax.where": "torch.where", + "relax.bucketize": "torch.bucketize", + "relax.nonzero": "torch.nonzero", + "relax.unique": "torch.unique", + # Sorting operations + "relax.sort": "torch.sort", + "relax.argsort": "torch.argsort", + "relax.topk": "torch.topk", + # Sampling operations + "relax.multinomial_from_uniform": "torch.multinomial", + # Ternary operations + "relax.ewise_fma": "torch.fma", # Approximate mapping + # Data type operations + "relax.astype": "torch.to", + "relax.wrap_param": "torch.tensor", + # Mask operations + "relax.masked_fill": "torch.masked_fill", + # Quantization operations + "relax.quantize": "torch.quantize_per_tensor", # Approximate mapping + "relax.dequantize": "torch.dequantize", # Approximate mapping + # Special operations (handled separately) + "relax.call_tir": "call_tir", + "relax.call_tir_inplace": "call_tir_inplace", + "relax.call_dps_packed": "call_dps_packed", + "relax.call_pure_packed": "call_pure_packed", + "relax.call_tir_with_grad": "call_tir_with_grad", + "relax.call_builtin_with_ctx": "call_builtin_with_ctx", + "relax.call_inplace_packed": "call_inplace_packed", + "relax.invoke_closure": "invoke_closure", + "relax.invoke_pure_closure": "invoke_pure_closure", + "relax.make_closure": "make_closure", + "relax.null_value": "null_value", + "relax.print": "print", + "relax.shape_of": "shape_of", + "relax.shape_to_tensor": "shape_to_tensor", + "relax.tensor_to_shape": "tensor_to_shape", + "relax.to_vdevice": "to_vdevice", + "relax.hint_on_device": "hint_on_device", + "relax.assert_op": "assert_op", + } + + +class RelaxExpressionConverter: + """Converter that transforms Relax expressions to Python/PyTorch code.""" + + def __init__( + self, + operator_map: 
Dict[str, str], + ir_module: IRModule = None, + op_cache: Dict[str, str] = None, + ): + """Initialize the expression converter. + + Args: + operator_map: Mapping from Relax operators to PyTorch operators + ir_module: The IRModule containing TIR functions to compile + op_cache: Shared cache for operator mappings to avoid repeated lookups + """ + self.operator_map = operator_map + self.variable_map: Dict[str, Any] = {} + self.current_params: List[relax.Var] = [] + self.ir_module = ir_module + # Use shared operator cache or create new one + self._op_cache = op_cache if op_cache is not None else {} + + def convert_expr(self, expr: relax.Expr, args: List[Any]) -> Any: + """Convert a Relax expression to Python/PyTorch equivalent.""" + if isinstance(expr, relax.Var): + return self._convert_var(expr, args) + elif isinstance(expr, relax.Call): + return self._convert_call(expr, args) + elif isinstance(expr, relax.Constant): + return self._convert_constant(expr) + elif isinstance(expr, relax.SeqExpr): + return self._convert_seq_expr(expr, args) + elif isinstance(expr, relax.Tuple): + return self._convert_tuple(expr, args) + elif isinstance(expr, relax.TupleGetItem): + return self._convert_tuple_get_item(expr, args) + elif isinstance(expr, relax.If): + return self._convert_if(expr, args) + elif isinstance(expr, relax.ShapeExpr): + return self._convert_shape_expr(expr) + else: + # Fallback for unknown expression types + return f"" + + def _convert_var(self, var: relax.Var, args: List[Any]) -> Any: + """Convert a Relax variable to Python equivalent.""" + if hasattr(var, "name_hint"): + var_name = var.name_hint + + # Check if it's a function parameter + for i, param in enumerate(self.current_params): + if hasattr(param, "name_hint") and param.name_hint == var_name: + return args[i] + + # Check if it's a bound variable + if var_name in self.variable_map: + return self.variable_map[var_name] + + # Return placeholder for unbound variables + return f"" + return f"" + + def _convert_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax call to Python/PyTorch equivalent.""" + op = call.op + + # Handle different types of calls + if isinstance(op, relax.GlobalVar): + # Function call + return self._convert_function_call(call, args) + elif isinstance(op, Op): + # Operator call + return self._convert_operator_call(call, args) + elif isinstance(op, relax.ExternFunc): + # External function call (like call_tir, call_dps_packed) + return self._convert_extern_func_call(call, args) + else: + return f"" + + def _convert_function_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax function call.""" + func_name = call.op.name_hint + call_args = [self.convert_expr(arg, args) for arg in call.args] + + # Handle special cases + if func_name in ["call_tir", "call_tir_inplace"]: + return self._convert_call_tir(call, args) + elif func_name in ["call_dps_packed", "call_pure_packed"]: + return self._convert_call_dps_packed(call, args) + else: + # Regular function call + return f"" + + def _convert_operator_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert a Relax operator call to PyTorch equivalent.""" + op_name = call.op.name + call_args = [self.convert_expr(arg, args) for arg in call.args] + + # Use cached operator mapping or look it up + if op_name not in self._op_cache: + self._op_cache[op_name] = self.operator_map.get(op_name) + pytorch_op = self._op_cache[op_name] + if pytorch_op: + try: + # Handle special operations + if pytorch_op == "call_tir": + return 
self._convert_call_tir(call, args) + elif pytorch_op == "call_tir_inplace": + return self._convert_call_tir(call, args) + elif pytorch_op == "call_dps_packed": + return self._convert_call_dps_packed(call, args) + elif pytorch_op == "call_pure_packed": + return self._convert_call_dps_packed(call, args) + elif pytorch_op == "expand_dims": + return self._convert_expand_dims(call, args) + elif pytorch_op in ["sum", "mean", "max", "min", "std", "variance"]: + return self._convert_reduction_op(call, args, pytorch_op) + elif pytorch_op == "squeeze": + return self._convert_squeeze(call, args) + elif pytorch_op in ["concat", "split", "stack"]: + return self._convert_tensor_ops(call, args, pytorch_op) + elif pytorch_op == "reshape": + return self._convert_reshape(call, args) + elif pytorch_op == "permute_dims": + return self._convert_permute_dims(call, args) + elif pytorch_op == "take": + return self._convert_take(call, args) + elif pytorch_op == "flip": + return self._convert_flip(call, args) + elif pytorch_op == "tile": + return self._convert_tile(call, args) + elif pytorch_op == "repeat": + return self._convert_repeat(call, args) + # Handle special cases for PyTorch operations + elif pytorch_op.startswith("F."): + return self._handle_functional_operation(pytorch_op, call, call_args) + elif pytorch_op.startswith("torch."): + # Regular PyTorch operation + func_name = pytorch_op[6:] # Remove "torch." prefix + func = getattr(torch, func_name) + return func(*call_args) + else: + # Direct function reference - use getattr for safer access + if pytorch_op.startswith("torch."): + module = torch + func_name = pytorch_op[6:] # Remove "torch." prefix + elif pytorch_op.startswith("F."): + module = F + func_name = pytorch_op[2:] # Remove "F." prefix + else: + return ( + f"" + ) + + func = getattr(module, func_name, None) + if func is None: + return ( + f"" + ) + return func(*call_args) + except (AttributeError, TypeError, ValueError) as error: + # This allows the test framework to catch and handle the errors appropriately + if pytorch_op.startswith("torch.") or pytorch_op.startswith("F."): + raise error + # Fallback to string representation for non-PyTorch operations + return f"" + else: + # Unknown operator + return f"" + + def _handle_functional_operation( + self, pytorch_op: str, call: relax.Call, call_args: List[Any] + ) -> Any: + """Handle PyTorch functional operations with special parameter handling.""" + # Neural network function + func_name = pytorch_op[2:] # Remove "F." 
prefix + func = getattr(F, func_name) + + # Special handling for functions that need dim parameter + if func_name in ["softmax", "log_softmax"]: + # Extract axis from call.attrs and convert to dim + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + return func(call_args[0], dim=axis) + else: + # Default to last dimension if no axis specified + return func(call_args[0], dim=-1) + else: + return func(*call_args) + + def _convert_extern_func_call(self, call: relax.Call, args: List[Any]) -> Any: + """Convert an external function call.""" + func_name = call.op.global_symbol + call_args = [self.convert_expr(arg, args) for arg in call.args] + + if func_name in ["call_tir", "call_tir_inplace"]: + return self._convert_call_tir(call, args) + elif func_name in ["call_dps_packed", "call_pure_packed"]: + return self._convert_call_dps_packed(call, args) + else: + return f"" + + def _convert_call_tir(self, call: relax.Call, args: List[Any]) -> Any: + """Convert call_tir to Python equivalent with DLPack conversion.""" + # Extract TIR function name and arguments + tir_func = call.args[0] + tir_args = call.args[1] if len(call.args) > 1 else [] + out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + + # Get function name + if isinstance(tir_func, relax.GlobalVar): + func_name = tir_func.name_hint + else: + # Convert the GlobalVar expression + func_name = self.convert_expr(tir_func, args) + if isinstance(func_name, str) and func_name.startswith("<"): + # If it's a placeholder, extract the name + func_name = str(tir_func) + + # Convert arguments to PyTorch tensors + converted_args = [self.convert_expr(arg, args) for arg in tir_args] + + try: + # First, try to get the TIR function from the current IRModule + tir_function = None + if self.ir_module: + # Look for the TIR function in the current IRModule + for global_var, func in self.ir_module.functions.items(): + if global_var.name_hint == func_name and hasattr(func, "body"): + try: + # Compile the TIR function + target = tvm.target.Target("llvm") + with tvm.target.Target(target): + tir_function = tvm.compile(func, target=target) + break + except (RuntimeError, ValueError, TypeError) as compile_e: + print( + f"Warning: Failed to compile TIR function {func_name}: {compile_e}" + ) + continue + + # If not found in current module, try global registry + if tir_function is None: + tir_function = tvm.get_global_func(func_name) + + if tir_function is None: + return ( + f"" + ) + + # Convert PyTorch tensors to TVM NDArrays via DLPack + tvm_args = [] + for arg in converted_args: + if isinstance(arg, torch.Tensor): + # Convert PyTorch tensor to TVM NDArray via DLPack + tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg)) + tvm_args.append(tvm_arg) + else: + tvm_args.append(arg) + + # For call_tir, we need to allocate output tensor + output_shape = None + if out_sinfo and hasattr(out_sinfo, "shape"): + output_shape = out_sinfo.shape + elif converted_args: + # Use the shape of the first input tensor + first_arg = converted_args[0] + if isinstance(first_arg, torch.Tensor): + output_shape = first_arg.shape + + if output_shape is None: + return f"" + + # Allocate output tensor + output_tensor = tvm.nd.array(tvm.nd.empty(output_shape, dtype="float32")) + tvm_args.append(output_tensor) + + # Call the TIR function + tir_function(*tvm_args) + + # The result is in the output_tensor we allocated + # 
Convert result back to PyTorch tensor via DLPack + return torch.from_dlpack(output_tensor.to_dlpack()) + + except (RuntimeError, ValueError, TypeError) as error: + return f"" + + def _convert_call_dps_packed(self, call: relax.Call, args: List[Any]) -> Any: + """Convert call_dps_packed to Python equivalent with DLPack conversion.""" + # Extract packed function name and arguments + packed_func = call.args[0] + packed_args = call.args[1] if len(call.args) > 1 else [] + _out_sinfo = call.attrs.get("out_sinfo") if call.attrs else None + + # Get function name + if isinstance(packed_func, relax.GlobalVar): + func_name = packed_func.name_hint + elif isinstance(packed_func, relax.ExternFunc): + func_name = packed_func.global_symbol + else: + func_name = str(packed_func) + + # Convert arguments to PyTorch tensors + converted_args = [self.convert_expr(arg, args) for arg in packed_args] + + try: + # Get the packed function from TVM + packed_function = tvm.get_global_func(func_name) + if packed_function is None: + return f"" + + # Convert PyTorch tensors to TVM NDArrays via DLPack + tvm_args = [] + for arg in converted_args: + if isinstance(arg, torch.Tensor): + # Convert PyTorch tensor to TVM NDArray via DLPack + tvm_arg = tvm.nd.from_dlpack(torch.to_dlpack(arg)) + tvm_args.append(tvm_arg) + else: + tvm_args.append(arg) + + # Call the packed function + result = packed_function(*tvm_args) + + # Convert result back to PyTorch tensor via DLPack + if isinstance(result, tvm.nd.NDArray): + return torch.from_dlpack(result.to_dlpack()) + else: + return result + + except (RuntimeError, ValueError, TypeError) as error: + return f"" + + def _convert_constant(self, const: relax.Constant) -> Any: + """Convert a Relax constant to Python equivalent.""" + if hasattr(const, "data"): + data = const.data + # Convert TVM NDArray to Python scalar if it's a scalar + if hasattr(data, "numpy"): + numpy_data = data.numpy() + if numpy_data.size == 1: + return float(numpy_data.item()) + else: + # For multi-element arrays, convert to PyTorch tensor + return torch.from_numpy(numpy_data) + elif hasattr(data, "item"): + # Single element tensor + return data.item() + else: + return data + return f"" + + def _convert_seq_expr(self, seq: relax.SeqExpr, args: List[Any]) -> Any: + """Convert a Relax sequence expression.""" + # Convert blocks + for block in seq.blocks: + if hasattr(block, "bindings"): + for binding in block.bindings: + if isinstance(binding, relax.VarBinding): + var_name = binding.var.name_hint + value = self.convert_expr(binding.value, args) + self.variable_map[var_name] = value + + # Convert body + return self.convert_expr(seq.body, args) + + def _convert_tuple(self, tuple_expr: relax.Tuple, args: List[Any]) -> Any: + """Convert a Relax tuple to Python tuple.""" + elements = [self.convert_expr(elem, args) for elem in tuple_expr.fields] + return tuple(elements) + + def _convert_tuple_get_item(self, get_item: relax.TupleGetItem, args: List[Any]) -> Any: + """Convert a Relax tuple get item to Python equivalent.""" + tuple_expr = self.convert_expr(get_item.tuple_value, args) + index = get_item.index + return f"" + + def _convert_if(self, if_expr: relax.If, args: List[Any]) -> Any: + """Convert a Relax if expression to Python equivalent.""" + condition = self.convert_expr(if_expr.cond, args) + true_branch = self.convert_expr(if_expr.true_branch, args) + false_branch = self.convert_expr(if_expr.false_branch, args) + return f"" + + def _convert_expand_dims(self, call: relax.Call, args: List[Any]) -> Any: + """Convert 
expand_dims to torch.unsqueeze with proper axis handling.""" + if len(call.args) < 1: + return "" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get the axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, "__iter__") and not isinstance(axis, str): + # It's an array/list, take the first element + axis = list(axis)[0] if len(axis) > 0 else None + + # Handle TVM types + if hasattr(axis, "value"): + # It's a TVM IntImm or similar, get the value + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is None: + return "" + + # Use torch.unsqueeze with the correct axis + return torch.unsqueeze(tensor_arg, dim=axis) + + def _convert_reduction_op(self, call: relax.Call, args: List[Any], op_name: str) -> Any: + """Convert reduction operations with axis and keepdims parameters.""" + if len(call.args) < 1: + return f"<{op_name}_error: insufficient arguments>" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get axis and keepdims from call.attrs + axis = None + keepdims = False + + if call.attrs: + if hasattr(call.attrs, "axis") and call.attrs.axis is not None: + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, "__iter__") and not isinstance(axis, str): + # It's an array/list, convert to list of ints + axis = [ + int(item.value) if hasattr(item, "value") else int(item) for item in axis + ] + elif hasattr(axis, "value"): + # It's a TVM IntImm, get the value + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if hasattr(call.attrs, "keepdims"): + keepdims = bool(call.attrs.keepdims) + + # Get the PyTorch function + func = getattr(torch, op_name) + + # Call with appropriate parameters + if axis is not None: + # For max and min, PyTorch returns (values, indices) tuple when dim is specified + if op_name in ["max", "min"]: + if isinstance(axis, list) and len(axis) == 1: + axis = axis[0] + elif isinstance(axis, list) and len(axis) > 1: + axis = axis[0] + result = func(tensor_arg, axis, keepdim=keepdims) + if isinstance(result, tuple): + return result[0] + else: + return result + else: + return func(tensor_arg, dim=axis, keepdim=keepdims) + else: + return func(tensor_arg) + + def _convert_squeeze(self, call: relax.Call, args: List[Any]) -> Any: + """Convert squeeze to torch.squeeze with proper axis handling.""" + if len(call.args) < 1: + return "" + + # Convert the tensor argument + tensor_arg = self.convert_expr(call.args[0], args) + + # Get axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis") and call.attrs.axis is not None: + axis = call.attrs.axis + # Handle different types of axis + if hasattr(axis, "__iter__") and not isinstance(axis, str): + axis = [int(item.value) if hasattr(item, "value") else int(item) for item in axis] + elif hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + # Call torch.squeeze with appropriate parameters + if axis is not None: + return torch.squeeze(tensor_arg, dim=axis) + else: + return torch.squeeze(tensor_arg) + + def _convert_tensor_ops(self, call: relax.Call, args: List[Any], op_name: str) -> Any: + """Convert tensor operations like concat, split, stack.""" + if len(call.args) < 1: + return f"<{op_name}_error: insufficient arguments>" + + # Convert arguments + converted_args = 
[self.convert_expr(arg, args) for arg in call.args] + + if op_name == "concat": + # torch.cat(tensors, dim=0) + # In Relax, concat takes a tuple of tensors as first argument + if len(converted_args) == 1 and isinstance(converted_args[0], tuple): + # This is a tuple of tensors + tensors = converted_args[0] + else: + # Direct tensor arguments + tensors = converted_args + axis = 0 + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + return torch.cat(tensors, dim=axis) + + elif op_name == "split": + # torch.split(tensor, split_size_or_sections, dim=0) + tensor = converted_args[0] + split_size = converted_args[1] if len(converted_args) > 1 else 1 + axis = 0 + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + # Handle indices_or_sections parameter + if call.attrs and hasattr(call.attrs, "indices_or_sections"): + indices_or_sections = call.attrs.indices_or_sections + if hasattr(indices_or_sections, "value"): + indices_or_sections = int(indices_or_sections.value) + elif isinstance(indices_or_sections, (int, float)): + indices_or_sections = int(indices_or_sections) + + # If indices_or_sections is an integer, it means split into N equal parts + if isinstance(indices_or_sections, int): + total_size = tensor.shape[axis] + split_size = total_size // indices_or_sections + return torch.split(tensor, split_size, dim=axis) + else: + # If it's a list, use it directly + return torch.split(tensor, indices_or_sections, dim=axis) + else: + return torch.split(tensor, split_size, dim=axis) + + elif op_name == "stack": + # torch.stack(tensors, dim=0) + if len(converted_args) == 1 and isinstance(converted_args[0], tuple): + tensors = converted_args[0] + else: + tensors = converted_args + axis = 0 + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + return torch.stack(tensors, dim=axis) + + else: + return f"<{op_name}_error: unsupported operation>" + + def _convert_reshape(self, call: relax.Call, args: List[Any]) -> Any: + """Convert reshape operation.""" + if len(call.args) < 2: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + shape_arg = call.args[1] + + # Convert shape argument to Python tuple + if isinstance(shape_arg, relax.ShapeExpr): + if hasattr(shape_arg, "values"): + shape = tuple( + int(v.value) if hasattr(v, "value") else int(v) for v in shape_arg.values + ) + else: + shape = (int(shape_arg),) + elif isinstance(shape_arg, relax.Constant): + # Constant tensor case + shape_data = shape_arg.data.numpy() + shape = tuple(int(v) for v in shape_data) + else: + # Try to convert as expression + converted_shape = self.convert_expr(shape_arg, args) + if isinstance(converted_shape, (list, tuple)): + shape = tuple(int(v) for v in converted_shape) + else: + shape = (int(converted_shape),) + + return torch.reshape(tensor_arg, shape) + + def _convert_permute_dims(self, call: relax.Call, args: List[Any]) -> Any: + """Convert permute_dims operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract axes from call.attrs + if call.attrs and hasattr(call.attrs, "axes"): + axes = call.attrs.axes + # Handle TVM Array type + if hasattr(axes, 
"__iter__") and not isinstance(axes, str): + # Convert TVM Array or Python list/tuple to tuple of ints + axes = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in axes) + elif isinstance(axes, (list, tuple)): + axes = tuple(int(v) for v in axes) + else: + axes = (int(axes),) + else: + return "" + + return torch.permute(tensor_arg, axes) + + def _convert_take(self, call: relax.Call, args: List[Any]) -> Any: + """Convert take operation.""" + if len(call.args) < 2: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + indices_arg = self.convert_expr(call.args[1], args) + + # Extract axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + # Use advanced indexing for specific axis + if axis == 0: + return tensor_arg[indices_arg] + else: + # For other axes, we need to use torch.index_select + return torch.index_select(tensor_arg, dim=axis, index=indices_arg) + else: + # No axis specified, use torch.take (flattens the tensor) + return torch.take(tensor_arg, indices_arg) + + def _convert_flip(self, call: relax.Call, args: List[Any]) -> Any: + """Convert flip operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract axis from call.attrs + axis = None + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + # Convert single axis to list for torch.flip + dims = [axis] + else: + # Default: flip all dimensions + dims = list(range(tensor_arg.dim())) + + return torch.flip(tensor_arg, dims=dims) + + def _convert_tile(self, call: relax.Call, args: List[Any]) -> Any: + """Convert tile operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract repeats from call.attrs + if call.attrs and hasattr(call.attrs, "repeats"): + repeats = call.attrs.repeats + # Handle TVM Array type + if hasattr(repeats, "__iter__") and not isinstance(repeats, str): + repeats = tuple(int(v.value) if hasattr(v, "value") else int(v) for v in repeats) + elif isinstance(repeats, (list, tuple)): + repeats = tuple(int(v) for v in repeats) + else: + repeats = (int(repeats),) + else: + return "" + + return torch.tile(tensor_arg, dims=repeats) + + def _convert_repeat(self, call: relax.Call, args: List[Any]) -> Any: + """Convert repeat operation.""" + if len(call.args) < 1: + return "" + + tensor_arg = self.convert_expr(call.args[0], args) + + # Extract repeats and axis from call.attrs + repeats = 1 + axis = None + + if call.attrs and hasattr(call.attrs, "repeats"): + repeats = call.attrs.repeats + if hasattr(repeats, "value"): + repeats = int(repeats.value) + elif isinstance(repeats, (int, float)): + repeats = int(repeats) + + if call.attrs and hasattr(call.attrs, "axis"): + axis = call.attrs.axis + if hasattr(axis, "value"): + axis = int(axis.value) + elif isinstance(axis, (int, float)): + axis = int(axis) + + if axis is not None: + return torch.repeat_interleave(tensor_arg, repeats=repeats, dim=axis) + else: + return torch.repeat_interleave(tensor_arg, repeats=repeats) + + def _convert_shape_expr(self, shape_expr: relax.ShapeExpr) -> Any: + """Convert a Relax shape expression to Python equivalent.""" + if hasattr(shape_expr, "values"): + return f"" + 
return f"" + + +def convert_relax_to_pyfunc( + ir_module: IRModule, relax_function_names: Union[str, List[str]] +) -> IRModule: + """Convert Relax functions to Python functions. + + Args: + ir_module: The IRModule containing Relax functions + relax_function_names: Name(s) of Relax functions to convert + + Returns: + IRModule with converted Python functions stored in pyfuncs + + Example: + >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, "my_function") + >>> converted_ir_mod = convert_relax_to_pyfunc(ir_mod, ["func1", "func2"]) + """ + converter = RelaxToPyFuncConverter(ir_module) + return converter.convert(relax_function_names) diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py new file mode 100644 index 000000000000..6dce3093156f --- /dev/null +++ b/tests/python/relax/test_relax_to_pyfunc_converter.py @@ -0,0 +1,866 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Comprehensive test cases for Relax to PyFunc converter. +Tests all major features including basic operations, call_tir, call_dps_packed, and symbolic shapes. 
+""" + + +import pytest +import torch +import torch.nn.functional as F +import numpy as np + + +import tvm +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script import relax as R +from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter + + +@I.ir_module +class ComprehensiveTestModule: + """Test module covering all converter features.""" + + @T.prim_func + def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + """TIR function for addition.""" + x = T.match_buffer(var_x, (5,), "float32") + y = T.match_buffer(var_y, (5,), "float32") + out = T.match_buffer(var_out, (5,), "float32") + for i in range(5): + out[i] = x[i] + y[i] + + @T.prim_func + def mul_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle): + """TIR function for multiplication.""" + x = T.match_buffer(var_x, (3, 4), "float32") + y = T.match_buffer(var_y, (3, 4), "float32") + out = T.match_buffer(var_out, (3, 4), "float32") + for i in range(3): + for j in range(4): + out[i, j] = x[i, j] * y[i, j] + + @R.function + def simple_add( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.add(x, y) + + @R.function + def with_relu(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.nn.relu(x) + + @R.function + def with_call_tir( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + cls = ComprehensiveTestModule + return R.call_tir(cls.add_tir, (x, y), out_sinfo=R.Tensor((5,), "float32")) + + @R.function + def with_call_dps_packed(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.call_dps_packed( + "my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), "float32") + ) + + @R.function + def complex_function( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + added = R.add(x, y) + relued = R.nn.relu(added) + cls = ComprehensiveTestModule + tir_result = R.call_tir(cls.add_tir, (relued, y), out_sinfo=R.Tensor((5,), "float32")) + return R.nn.relu(tir_result) + + @R.function + def symbolic_add( + x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32") + ) -> R.Tensor(("n",), "float32"): + return R.add(x, y) + + @R.function + def symbolic_matmul( + x: R.Tensor(("batch", "m", "k"), "float32"), y: R.Tensor(("batch", "k", "n"), "float32") + ) -> R.Tensor(("batch", "m", "n"), "float32"): + return R.matmul(x, y) + + @R.function + def symbolic_expand_dims( + x: R.Tensor(("batch", "seq_len"), "float32") + ) -> R.Tensor(("batch", "seq_len", 1), "float32"): + return R.expand_dims(x, axis=2) + + @R.function + def multi_ops( + x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32") + ) -> R.Tensor((3, 4), "float32"): + added = R.add(x, y) + multiplied = R.multiply(added, y) + powered = R.power(multiplied, R.const(2.0)) + maxed = R.maximum(powered, x) + return maxed + + @R.function + def reduction_ops(x: R.Tensor((5,), "float32")) -> R.Tensor((), "float32"): + sum_val = R.sum(x) + mean_val = R.mean(x) + max_val = R.max(x) + return R.add(R.add(sum_val, mean_val), max_val) + + @R.function + def comparison_ops( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): + eq_val = R.equal(x, y) + gt_val = R.greater(x, y) + return R.logical_and(eq_val, gt_val) + + @R.function + def test_reshape(x: R.Tensor((2, 3), "float32")) -> R.Tensor((6,), "float32"): + return R.reshape(x, (6,)) + + @R.function + def test_permute_dims(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((4, 2, 
3), "float32"): + return R.permute_dims(x, axes=[2, 0, 1]) + + @R.function + def test_concat( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((4, 3), "float32"): + return R.concat((x, y), axis=0) + + @R.function + def test_split(x: R.Tensor((4, 3), "float32")) -> R.Tuple: + return R.split(x, indices_or_sections=2, axis=0) + + @R.function + def test_stack( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 2, 3), "float32"): + return R.stack((x, y), axis=1) + + @R.function + def test_take( + x: R.Tensor((3, 4), "float32"), indices: R.Tensor((2,), "int64") + ) -> R.Tensor((2,), "float32"): + return R.take(x, indices, axis=0) + + @R.function + def test_flip(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + return R.flip(x, axis=1) + + @R.function + def test_tile(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 6), "float32"): + return R.tile(x, (2, 2)) + + @R.function + def test_repeat(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 3), "float32"): + return R.repeat(x, repeats=2, axis=0) + + @R.function + def test_expand_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3, 1), "float32"): + return R.expand_dims(x, axis=2) + + @R.function + def test_squeeze(x: R.Tensor((2, 3, 1), "float32")) -> R.Tensor((2, 3), "float32"): + return R.squeeze(x, axis=2) + + @R.function + def test_sum_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.sum(x, axis=0) + + @R.function + def test_max_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.max(x, axis=0) + + +def create_mock_packed_function(): + """Create a mock packed function for testing.""" + + def mock_softmax(x, axis): + """Mock softmax function that just returns the input.""" + return x + + # Register the function globally + tvm.register_func("my_softmax", mock_softmax) + + +class TestRelaxToPyFuncConverter: + """Comprehensive test class for Relax to PyFunc converter.""" + + @classmethod + def setup_class(cls): + """Set up test fixtures.""" + cls.ir_mod = ComprehensiveTestModule + cls.converter = RelaxToPyFuncConverter(cls.ir_mod) + create_mock_packed_function() + + def test_basic_operations(self): + """Test basic arithmetic operations.""" + converted_ir_mod = self.converter.convert(["simple_add", "with_relu"]) + + # Test simple_add + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["simple_add"](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + + # Test with_relu + x_neg = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["with_relu"](x_neg) + expected = torch.nn.functional.relu(x_neg) + assert torch.allclose(result, expected) + + def test_call_tir(self): + """Test call_tir functionality with DLPack conversion.""" + converted_ir_mod = self.converter.convert(["with_call_tir"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["with_call_tir"](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + assert result.shape == expected.shape + + def test_call_dps_packed(self): + """Test call_dps_packed functionality.""" + converted_ir_mod = self.converter.convert(["with_call_dps_packed"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + result = 
converted_ir_mod.pyfuncs["with_call_dps_packed"](x) + expected = x + assert torch.allclose(result, expected) + + def test_complex_function(self): + """Test complex function with multiple operations.""" + converted_ir_mod = self.converter.convert(["complex_function"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["complex_function"](x, y) + + # Expected: relu(add(relu(add(x, y)), y)) + step1 = torch.add(x, y) + step2 = torch.nn.functional.relu(step1) + step3 = torch.add(step2, y) # TIR call + expected = torch.nn.functional.relu(step3) + + assert torch.allclose(result, expected) + + def test_symbolic_shapes(self): + """Test symbolic shape handling.""" + converted_ir_mod = self.converter.convert( + ["symbolic_add", "symbolic_matmul", "symbolic_expand_dims"] + ) + + # Test symbolic_add + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["symbolic_add"](x, y) + expected = torch.add(x, y) + assert torch.allclose(result, expected) + + # Test symbolic_matmul + x = torch.randn(2, 3, 4, dtype=torch.float32) # (batch=2, m=3, k=4) + y = torch.randn(2, 4, 5, dtype=torch.float32) # (batch=2, k=4, n=5) + result = converted_ir_mod.pyfuncs["symbolic_matmul"](x, y) + expected = torch.matmul(x, y) + assert torch.allclose(result, expected) + assert result.shape == (2, 3, 5) + + # Test symbolic_expand_dims + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["symbolic_expand_dims"](x) + expected = torch.unsqueeze(x, dim=2) + assert torch.allclose(result, expected) + assert result.shape == (2, 2, 1) + + def test_multi_operations(self): + """Test multiple operations in sequence.""" + converted_ir_mod = self.converter.convert(["multi_ops"]) + + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + dtype=torch.float32, + ) + y = torch.tensor( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=torch.float32 + ) + + result = converted_ir_mod.pyfuncs["multi_ops"](x, y) + + # Expected: maximum(power(multiply(add(x, y), y), 2), x) + step1 = torch.add(x, y) + step2 = torch.mul(step1, y) + step3 = torch.pow(step2, 2.0) + expected = torch.maximum(step3, x) + + assert torch.allclose(result, expected) + + def test_reduction_operations(self): + """Test reduction operations.""" + converted_ir_mod = self.converter.convert(["reduction_ops"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["reduction_ops"](x) + + # Expected: sum(x) + mean(x) + max(x) + expected = torch.sum(x) + torch.mean(x) + torch.max(x) + + assert torch.allclose(result, expected) + assert result.shape == () + + def test_comparison_operations(self): + """Test comparison operations.""" + converted_ir_mod = self.converter.convert(["comparison_ops"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([1.0, 2.5, 3.0, 4.5, 5.0], dtype=torch.float32) + + result = converted_ir_mod.pyfuncs["comparison_ops"](x, y) + + # Expected: logical_and(equal(x, y), greater(x, y)) + eq_val = torch.eq(x, y) + gt_val = torch.gt(x, y) + expected = torch.logical_and(eq_val, gt_val) + + assert torch.allclose(result, expected) + assert result.dtype == torch.bool + + def test_operator_mapping_completeness(self): + """Test that operator mapping is 
comprehensive.""" + operator_map = RelaxToPyFuncConverter._get_op_map() + + # Check that we have a good number of operators + assert len(operator_map) > 100, f"Expected >100 operators, got {len(operator_map)}" + + # Check key operator categories + binary_ops = [ + op + for op in operator_map.keys() + if op.startswith("relax.") and not op.startswith("relax.nn.") + ] + nn_ops = [op for op in operator_map.keys() if op.startswith("relax.nn.")] + + assert len(binary_ops) > 20, f"Expected >20 binary ops, got {len(binary_ops)}" + assert len(nn_ops) > 30, f"Expected >30 nn ops, got {len(nn_ops)}" + + # Check specific important operators + important_ops = [ + "relax.add", + "relax.multiply", + "relax.nn.relu", + "relax.nn.softmax", + "relax.matmul", + "relax.reshape", + "relax.sum", + "relax.mean", + ] + + for op in important_ops: + assert op in operator_map, f"Missing important operator: {op}" + + def test_error_handling(self): + """Test error handling for invalid inputs.""" + converted_ir_mod = self.converter.convert(["simple_add"]) + + # Test with wrong number of arguments + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + + with pytest.raises(ValueError, match="Expected 2 arguments"): + converted_ir_mod.pyfuncs["simple_add"](x) # Missing second argument + + # Test with incompatible shapes - this should raise a RuntimeError + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + y = torch.tensor([1.0, 2.0], dtype=torch.float32) # Different shape + + # This should raise a RuntimeError because shapes don't match + with pytest.raises(RuntimeError, match="The size of tensor a"): + converted_ir_mod.pyfuncs["simple_add"](x, y) + + def test_conversion_metadata(self): + """Test that conversion preserves metadata correctly.""" + converted_ir_mod = self.converter.convert(["simple_add"]) + + # Check that pyfuncs attribute exists + assert hasattr(converted_ir_mod, "pyfuncs") + assert "simple_add" in converted_ir_mod.pyfuncs + + # Check function metadata + pyfunc = converted_ir_mod.pyfuncs["simple_add"] + assert hasattr(pyfunc, "__name__") + assert hasattr(pyfunc, "__doc__") + assert pyfunc.__name__ == "simple_add" + + def test_tensor_operations(self): + """Test tensor manipulation operations.""" + converted_ir_mod = self.converter.convert( + [ + "test_reshape", + "test_permute_dims", + "test_concat", + "test_split", + "test_stack", + "test_take", + "test_flip", + "test_tile", + "test_repeat", + "test_expand_dims", + "test_squeeze", + "test_sum_with_axis", + "test_max_with_axis", + ] + ) + + # Test reshape + x1 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result1 = converted_ir_mod.pyfuncs["test_reshape"](x1) + expected1 = torch.reshape(x1, (6,)) + assert torch.allclose(result1, expected1), "Reshape operation failed" + + # Test permute_dims + x2 = torch.randn(2, 3, 4) + result2 = converted_ir_mod.pyfuncs["test_permute_dims"](x2) + expected2 = torch.permute(x2, (2, 0, 1)) + assert torch.allclose(result2, expected2), "Permute_dims operation failed" + + # Test concat + x3 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + y3 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + result3 = converted_ir_mod.pyfuncs["test_concat"](x3, y3) + expected3 = torch.cat([x3, y3], dim=0) + assert torch.allclose(result3, expected3), "Concat operation failed" + + # Test split + x4 = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], + dtype=torch.float32, + ) + result4 = 
converted_ir_mod.pyfuncs["test_split"](x4) + expected4 = torch.split(x4, 2, dim=0) + assert len(result4) == len(expected4), "Split operation failed - wrong number of tensors" + for r, e in zip(result4, expected4): + assert torch.allclose(r, e), "Split operation failed - tensor mismatch" + + # Test stack + x5 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + y5 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32) + result5 = converted_ir_mod.pyfuncs["test_stack"](x5, y5) + expected5 = torch.stack([x5, y5], dim=1) + assert torch.allclose(result5, expected5), "Stack operation failed" + + # Test take + x6 = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + dtype=torch.float32, + ) + indices = torch.tensor([0, 2], dtype=torch.int64) + result6 = converted_ir_mod.pyfuncs["test_take"](x6, indices) + expected6 = x6[indices] + assert torch.allclose(result6, expected6), "Take operation failed" + + # Test flip + x7 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result7 = converted_ir_mod.pyfuncs["test_flip"](x7) + expected7 = torch.flip(x7, dims=[1]) + assert torch.allclose(result7, expected7), "Flip operation failed" + + # Test tile + x8 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result8 = converted_ir_mod.pyfuncs["test_tile"](x8) + expected8 = torch.tile(x8, (2, 2)) + assert torch.allclose(result8, expected8), "Tile operation failed" + + # Test repeat + x9 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result9 = converted_ir_mod.pyfuncs["test_repeat"](x9) + expected9 = torch.repeat_interleave(x9, repeats=2, dim=0) + assert torch.allclose(result9, expected9), "Repeat operation failed" + + # Test expand_dims + x10 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result10 = converted_ir_mod.pyfuncs["test_expand_dims"](x10) + expected10 = torch.unsqueeze(x10, dim=2) + assert torch.allclose(result10, expected10), "Expand_dims operation failed" + + # Test squeeze + x11 = torch.tensor([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float32) + result11 = converted_ir_mod.pyfuncs["test_squeeze"](x11) + expected11 = torch.squeeze(x11, dim=2) + assert torch.allclose(result11, expected11), "Squeeze operation failed" + + # Test sum with axis + x12 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result12 = converted_ir_mod.pyfuncs["test_sum_with_axis"](x12) + expected12 = torch.sum(x12, dim=0) + assert torch.allclose(result12, expected12), "Sum with axis operation failed" + + # Test max with axis + x13 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + result13 = converted_ir_mod.pyfuncs["test_max_with_axis"](x13) + expected13 = torch.max(x13, dim=0)[0] # torch.max returns (values, indices) + assert torch.allclose(result13, expected13), "Max with axis operation failed" + + +@I.ir_module +class ExtendedOperatorsModule: + """Extended test module with additional operators not covered in ComprehensiveTestModule.""" + + # Unary operations not covered in ComprehensiveTestModule + @R.function + def test_abs(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.abs(x) + + @R.function + def test_neg(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.negative(x) + + @R.function + def test_exp(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.exp(x) + + @R.function + def test_log(x: R.Tensor((5,), "float32")) -> 
R.Tensor((5,), "float32"): + return R.log(x) + + @R.function + def test_sqrt(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sqrt(x) + + @R.function + def test_sin(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sin(x) + + @R.function + def test_cos(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.cos(x) + + @R.function + def test_tanh(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.tanh(x) + + @R.function + def test_sigmoid(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.sigmoid(x) + + # Comparison operations not covered in ComprehensiveTestModule + @R.function + def test_less( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): + return R.less(x, y) + + @R.function + def test_not_equal( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "bool"): + return R.not_equal(x, y) + + # Binary operations not covered in ComprehensiveTestModule + @R.function + def test_multiply( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.multiply(x, y) + + @R.function + def test_divide( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.divide(x, y) + + @R.function + def test_power( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.power(x, y) + + @R.function + def test_maximum( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.maximum(x, y) + + @R.function + def test_minimum( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.minimum(x, y) + + @R.function + def test_subtract( + x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32") + ) -> R.Tensor((5,), "float32"): + return R.subtract(x, y) + + # Additional tensor operations with different parameters + @R.function + def test_transpose_2d(x: R.Tensor((2, 4), "float32")) -> R.Tensor((4, 2), "float32"): + return R.permute_dims(x, axes=[1, 0]) + + @R.function + def test_mean_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.mean(x, axis=0) + + @R.function + def test_min_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"): + return R.min(x, axis=0) + + # Neural network operations not covered in ComprehensiveTestModule + @R.function + def test_gelu_nn(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"): + return R.nn.gelu(x) + + @R.function + def test_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"): + return R.nn.softmax(x, axis=1) + + @R.function + def test_log_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"): + return R.nn.log_softmax(x, axis=1) + + # Advanced tensor operations with different parameters + @R.function + def test_tile_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 9), "float32"): + return R.tile(x, (2, 3)) + + @R.function + def test_repeat_axis(x: R.Tensor((3,), "float32")) -> R.Tensor((6,), "float32"): + return R.repeat(x, repeats=2, axis=0) + + +class TestExtendedOperators: + """Test class for extended operator coverage.""" + + @classmethod + def setup_class(cls): + """Set up test fixtures.""" + cls.ir_mod = ExtendedOperatorsModule + cls.converter = RelaxToPyFuncConverter(cls.ir_mod) + + def test_unary_operations(self): + """Test unary operations.""" + converted_ir_mod = 
self.converter.convert( + ["test_abs", "test_neg", "test_exp", "test_log", "test_sqrt"] + ) + + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32) + + # Test abs + result = converted_ir_mod.pyfuncs["test_abs"](x) + expected = torch.abs(x) + assert torch.allclose(result, expected) + + # Test negative + result = converted_ir_mod.pyfuncs["test_neg"](x) + expected = torch.neg(x) + assert torch.allclose(result, expected) + + # Test exp + result = converted_ir_mod.pyfuncs["test_exp"](x) + expected = torch.exp(x) + assert torch.allclose(result, expected) + + # Test log (with positive values) + x_pos = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["test_log"](x_pos) + expected = torch.log(x_pos) + assert torch.allclose(result, expected) + + # Test sqrt + result = converted_ir_mod.pyfuncs["test_sqrt"](x_pos) + expected = torch.sqrt(x_pos) + assert torch.allclose(result, expected) + + def test_trigonometric_operations(self): + """Test trigonometric operations.""" + converted_ir_mod = self.converter.convert( + ["test_sin", "test_cos", "test_tanh", "test_sigmoid"] + ) + + x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0], dtype=torch.float32) + + # Test sin + result = converted_ir_mod.pyfuncs["test_sin"](x) + expected = torch.sin(x) + assert torch.allclose(result, expected) + + # Test cos + result = converted_ir_mod.pyfuncs["test_cos"](x) + expected = torch.cos(x) + assert torch.allclose(result, expected) + + # Test tanh + result = converted_ir_mod.pyfuncs["test_tanh"](x) + expected = torch.tanh(x) + assert torch.allclose(result, expected) + + # Test sigmoid + result = converted_ir_mod.pyfuncs["test_sigmoid"](x) + expected = torch.sigmoid(x) + assert torch.allclose(result, expected) + + def test_comparison_operations(self): + """Test comparison operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert(["test_less", "test_not_equal"]) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) + + # Test less + result = converted_ir_mod.pyfuncs["test_less"](x, y) + expected = torch.lt(x, y) + assert torch.equal(result, expected) + + # Test not equal + result = converted_ir_mod.pyfuncs["test_not_equal"](x, y) + expected = torch.ne(x, y) + assert torch.equal(result, expected) + + def test_binary_operations(self): + """Test binary operations.""" + converted_ir_mod = self.converter.convert( + [ + "test_multiply", + "test_divide", + "test_power", + "test_maximum", + "test_minimum", + "test_subtract", + ] + ) + + x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32) + y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32) + + # Test multiply + result = converted_ir_mod.pyfuncs["test_multiply"](x, y) + expected = torch.mul(x, y) + assert torch.allclose(result, expected) + + # Test divide + result = converted_ir_mod.pyfuncs["test_divide"](x, y) + expected = torch.div(x, y) + assert torch.allclose(result, expected) + + # Test power + result = converted_ir_mod.pyfuncs["test_power"](x, y) + expected = torch.pow(x, y) + assert torch.allclose(result, expected) + + # Test maximum + result = converted_ir_mod.pyfuncs["test_maximum"](x, y) + expected = torch.maximum(x, y) + assert torch.allclose(result, expected) + + # Test minimum + result = converted_ir_mod.pyfuncs["test_minimum"](x, y) + expected = torch.minimum(x, y) + assert torch.allclose(result, expected) + + # Test subtract + result = 
converted_ir_mod.pyfuncs["test_subtract"](x, y) + expected = torch.sub(x, y) + assert torch.allclose(result, expected) + + def test_tensor_operations(self): + """Test tensor operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert(["test_transpose_2d"]) + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32) + + # Test transpose + result = converted_ir_mod.pyfuncs["test_transpose_2d"](x) + expected = torch.transpose(x, 0, 1) + assert torch.allclose(result, expected) + assert result.shape == (4, 2) + + def test_reduction_operations(self): + """Test reduction operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert(["test_mean_axis", "test_min_axis"]) + + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32) + + # Test mean + result = converted_ir_mod.pyfuncs["test_mean_axis"](x) + expected = torch.mean(x, dim=0) + assert torch.allclose(result, expected) + assert result.shape == (3,) + + # Test min + result = converted_ir_mod.pyfuncs["test_min_axis"](x) + expected = torch.min(x, dim=0)[0] + assert torch.allclose(result, expected) + assert result.shape == (3,) + + def test_neural_network_operations(self): + """Test neural network operations not covered in ComprehensiveTestModule.""" + converted_ir_mod = self.converter.convert( + ["test_gelu_nn", "test_softmax_nn", "test_log_softmax_nn"] + ) + + x = torch.tensor( + [[-2.0, -1.0, 0.0, 1.0, 2.0], [0.5, 1.5, 2.5, 3.5, 4.5]], dtype=torch.float32 + ) + + # Test gelu + result = converted_ir_mod.pyfuncs["test_gelu_nn"](x[0]) + expected = F.gelu(x[0]) + assert torch.allclose(result, expected) + + # Test softmax + result = converted_ir_mod.pyfuncs["test_softmax_nn"](x) + expected = F.softmax(x, dim=1) + assert torch.allclose(result, expected) + + # Test log_softmax + result = converted_ir_mod.pyfuncs["test_log_softmax_nn"](x) + expected = F.log_softmax(x, dim=1) + assert torch.allclose(result, expected) + + def test_advanced_tensor_operations(self): + """Test advanced tensor operations with different parameters.""" + converted_ir_mod = self.converter.convert(["test_tile_dims", "test_repeat_axis"]) + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32) + + # Test tile with different dimensions + result = converted_ir_mod.pyfuncs["test_tile_dims"](x) + expected = torch.tile(x, (2, 3)) + assert torch.allclose(result, expected) + assert result.shape == (4, 12) + + # Test repeat with different parameters + x_1d = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = converted_ir_mod.pyfuncs["test_repeat_axis"](x_1d) + expected = torch.repeat_interleave(x_1d, repeats=2, dim=0) + assert torch.allclose(result, expected) + assert result.shape == (6,) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])