diff --git a/cmake/modules/contrib/CLML.cmake b/cmake/modules/contrib/CLML.cmake index 118091696a9f..21621bf34c3d 100644 --- a/cmake/modules/contrib/CLML.cmake +++ b/cmake/modules/contrib/CLML.cmake @@ -16,10 +16,10 @@ # under the License. if(USE_CLML) - file(GLOB CLML_RELAY_CONTRIB_SRC src/relay/backend/contrib/clml/*.cc) + file(GLOB CLML_RELAX_CONTRIB_SRC src/relax/backend/contrib/clml/*.cc) file(GLOB CLML_RUNTIME_MODULE src/runtime/contrib/clml/clml_runtime.cc) include_directories(SYSTEM "3rdparty/OpenCL-Headers") - list(APPEND COMPILER_SRCS ${CLML_RELAY_CONTRIB_SRC}) + list(APPEND COMPILER_SRCS ${CLML_RELAX_CONTRIB_SRC}) if(NOT USE_CLML_GRAPH_EXECUTOR) list(APPEND COMPILER_SRCS ${CLML_RUNTIME_MODULE}) endif() diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py index 2a64ffe27b30..f414ae0c5468 100644 --- a/python/tvm/relax/backend/__init__.py +++ b/python/tvm/relax/backend/__init__.py @@ -16,7 +16,7 @@ # under the License. """Relax backends""" -from . import contrib, cpu_generic, cuda, gpu_generic, metal, rocm +from . import contrib, cpu_generic, cuda, gpu_generic, metal, rocm, adreno from .dispatch_sampling import DispatchSampling from .dispatch_sort_scan import DispatchSortScan from .pattern_registry import get_pattern, get_patterns_with_prefix diff --git a/python/tvm/relax/backend/adreno/__init__.py b/python/tvm/relax/backend/adreno/__init__.py new file mode 100644 index 000000000000..b3364f2f4b4a --- /dev/null +++ b/python/tvm/relax/backend/adreno/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax Adreno backend compilation pipeline and other passes.""" +from .pipeline import ( + finalize_passes, + get_default_pipeline, + dataflow_lower_passes, + legalize_passes, + library_dispatch_passes, +) diff --git a/python/tvm/relax/backend/adreno/clml.py b/python/tvm/relax/backend/adreno/clml.py new file mode 100644 index 000000000000..e50ac0dc1d03 --- /dev/null +++ b/python/tvm/relax/backend/adreno/clml.py @@ -0,0 +1,618 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, pointless-exception-statement +"""Pattern table for CLML backend""" +import tvm +from tvm import relax, IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm.relax import transform +from tvm.relax.expr_functor import PyExprMutator, mutator +from tvm.relax.expr import TupleGetItem, VarBinding +from tvm.relax.dpl.pattern import ( + is_const, + is_op, + is_tuple_get_item, + wildcard, +) +from tvm.relax.transform import PatternCheckContext +from ..pattern_registry import register_patterns + + +@mutator +class AppendReshapeToBNRewriter(PyExprMutator): + """ + Append Reshape Operator to BatchNorm Pass Rewriter Pass + + - Automatically appends a reshape operation after BatchNorm operators + - Resolves fusion issues for custom backends where BatchNorm output + might explicitly access the first elment of the Tuple + + Algo: + Identifies BatchNorm operators in the computational graph + When BatchNorm's first output is accessed via TupleGetItem + Automatically inserts a reshape operation to match input shape + + """ + + def __init__(self, mod): + super().__init__(mod) + self.bn_vars = {} + + def visit_tuple_getitem_(self, op: TupleGetItem): + tuple_value = op.tuple_value + reshape_op = tvm.ir.Op.get("relax.reshape") + + if isinstance(tuple_value, relax.Var) and tuple_value in self.bn_vars: + bn_call = self.bn_vars[tuple_value] + if op.index == 0: + bn_out = relax.TupleGetItem(bn_call, 0) + input_shape = bn_call.args[0].struct_info.shape + return relax.Call(reshape_op, [bn_out, input_shape]) + + return super().visit_tuple_getitem_(op) + + def visit_var_binding_(self, binding: VarBinding): + if isinstance(binding.value, relax.Call) and binding.value.op.name == "relax.nn.batch_norm": + self.bn_vars[binding.var] = binding.value + return super().visit_var_binding_(binding) + + +@transform.function_pass(opt_level=0, name="AppendReshapeToBN") +class AppendReshapeToBNRewriterPass: + def transform_function( + self, func: relax.Function, mod: IRModule, _ctx: tvm.transform.PassContext + ) -> relax.Function: + updated_func = AppendReshapeToBNRewriter(mod).visit_expr(func) + updated_func = relax.analysis.remove_all_unused(updated_func) + return updated_func + + +def clml_sdk_version(): + """Utility function to get clml version""" + + return int(tvm.support.libinfo().get("TVM_CLML_VERSION", 2)) + + +def is_clml_runtime_enabled(): + """Check if the CLML graph runtime is present. + + Returns + ------- + ret: bool + True if present, False if not. 
+ """ + check_enabled = tvm.get_global_func("relax.op.is_openclml_runtime_enabled", True) + if check_enabled: + return check_enabled() + return False + + +def _check_default(context: PatternCheckContext) -> bool: + return True + + +def clml_pattern_table(): + """Get the CLML pattern table.""" + + def _check_conv2d(context: PatternCheckContext) -> bool: + if "root" in context.annotated_expr: + root_call = context.annotated_expr["root"] + if root_call.op.name == "relax.nn.conv2d": + input_layout = root_call.attrs.data_layout + weight_layout = root_call.attrs.kernel_layout + if input_layout != "NCHW" or weight_layout != "OIHW": + return False + if root_call.op.name == "relax.nn.conv2d_transpose": + input_layout = root_call.attrs.data_layout + weight_layout = root_call.attrs.kernel_layout + if input_layout != "NCHW" or weight_layout != "OIHW": + return False + + if "data" in context.annotated_expr: + input_expr = context.annotated_expr["data"] + input_dtype = input_expr.struct_info.dtype + if input_dtype not in ["float32", "float16"]: + return False + + if "weight" in context.annotated_expr: + weight_expr = context.annotated_expr["weight"] + weight_dtype = weight_expr.struct_info.dtype + if weight_dtype not in ["float32", "float16"]: + return False + + return True + + def populate_patterns(patterns, name, op, annotations, *args): + ret = {} + for k, v in patterns.items(): + ret_ann = v["annotation"].copy() + ret_ann.update(annotations) + ret[name + "." + k] = {"pattern": op(v["pattern"], *args), "annotation": ret_ann.copy()} + + return ret + + def conv_pattern(): + """Create a convolution pattern.""" + data = wildcard() + weight = wildcard() + bias = is_const() + bn_scale = is_const() + bn_bias = is_const() + bn_mean = is_const() + bn_var = is_const() + + annotations = { + "data": data, + "weight": weight, + } + + patterns = {} + patterns["nn.conv2d"] = { + "pattern": is_op("relax.nn.conv2d")(data, weight), + "annotation": annotations.copy(), + } + + pad_annotations = annotations.copy() + patterns["pad.nn.conv2d"] = { + "pattern": is_op("relax.nn.conv2d")(is_op("relax.nn.pad")(data), weight), + "annotation": pad_annotations, + } + + patterns["nn.conv2d_transpose"] = { + "pattern": is_op("relax.nn.conv2d_transpose")(data, weight), + "annotation": annotations.copy(), + } + patterns.update( + populate_patterns(patterns, "bias", is_op("relax.add"), {"bias": bias}, bias) + ) + patterns.update( + populate_patterns( + patterns, + "bn", + is_op("relax.nn.batch_norm"), + { + "bn_scale": bn_scale, + "bn_bias": bn_bias, + "bn_mean": bn_mean, + "bn_var": bn_var, + }, + bn_scale, + bn_bias, + bn_mean, + bn_var, + ) + ) + tuple_patterns = {} + for k, v in patterns.items(): + tuple_annotation = v["annotation"].copy() + tuple_patterns["tuple" + "." + k] = { + "pattern": is_tuple_get_item(v["pattern"], 0), + "annotation": tuple_annotation, + } + patterns.update(tuple_patterns) + + relu_patterns = populate_patterns(patterns, "relu", is_op("relax.nn.relu"), {}) + clip_patterns = populate_patterns(patterns, "clip", is_op("relax.clip"), {}) + patterns.update(relu_patterns) + patterns.update(clip_patterns) + + conv_patterns = [] + for k, v in patterns.items(): + ret_annotations = v["annotation"] + ret_annotations["root"] = v["pattern"] + conv_patterns.append( + ("openclml." 
+ (k), v["pattern"], ret_annotations.copy(), _check_conv2d) + ) + return conv_patterns[::-1] + + def _check_maxpool2d(context: PatternCheckContext) -> bool: + root = context.annotated_expr.get("root") + if not root or not isinstance(root, relax.Call): + return False + + if root.op.name != "relax.nn.max_pool2d": + return False + + if "data" not in context.annotated_expr: + return False + + data = context.annotated_expr["data"] + input_shape = data.struct_info.shape + + if len(input_shape) != 4: + return False + + if any(dim <= 0 for dim in input_shape): + return False + + pool_size = root.attrs.pool_size + if len(pool_size) != 2: + return False + if any(size <= 0 for size in pool_size): + return False + + strides = root.attrs.strides + if len(strides) != 2: + return False + if any(stride <= 0 for stride in strides): + return False + + dilation = root.attrs.dilation + if len(dilation) != 2: + return False + if any(d <= 0 for d in dilation): + return False + + padding = root.attrs.padding + if len(padding) != 4: + return False + if any(p < 0 for p in padding): + return False + + return True + + def maxpool_pattern(): + + """Create Pool Pattern""" + data = wildcard() + annotations = { + "data": data, + } + patterns = {} + patterns["nn.max_pool2d"] = { + "pattern": is_op("relax.nn.max_pool2d")(data), + "annotation": annotations.copy(), + } + + pool_patterns = [] + for k, v in patterns.items(): + ret_annotations = v["annotation"] + ret_annotations["root"] = v["pattern"] + pool_patterns.append( + ("openclml." + (k), v["pattern"], ret_annotations.copy(), _check_maxpool2d) + ) + return pool_patterns + + def _check_avgpool2d(context: PatternCheckContext) -> bool: + root = context.annotated_expr.get("root") + if not root or not isinstance(root, relax.Call): + return False + + if root.op.name != "relax.nn.avg_pool2d": + return False + + if "data" not in context.annotated_expr: + return False + + data = context.annotated_expr["data"] + input_shape = data.struct_info.shape + + if len(input_shape) != 4: + return False + + if any(dim <= 0 for dim in input_shape): + return False + + pool_size = root.attrs.pool_size + if len(pool_size) != 2: + return False + if any(size <= 0 for size in pool_size): + return False + + strides = root.attrs.strides + if len(strides) != 2: + return False + if any(stride <= 0 for stride in strides): + return False + + padding = root.attrs.padding + if len(padding) != 4: + return False + if any(p < 0 for p in padding): + return False + + return True + + def avgpool_pattern(): + + data = wildcard() + annotations = { + "data": data, + } + patterns = {} + patterns["nn.avg_pool2d"] = { + "pattern": is_op("relax.nn.avg_pool2d")(data), + "annotation": annotations.copy(), + } + + pool_patterns = [] + for k, v in patterns.items(): + ret_annotations = v["annotation"] + ret_annotations["root"] = v["pattern"] + pool_patterns.append( + ("openclml." 
+ (k), v["pattern"], ret_annotations.copy(), _check_avgpool2d) + ) + return pool_patterns + + def _check_global_avgpool(context: PatternCheckContext) -> bool: + + root = context.annotated_expr.get("root") + if not root or not isinstance(root, relax.Call): + return False + + if root.op.name != "relax.mean": + return False + + if "data" not in context.annotated_expr: + return False + + data = context.annotated_expr["data"] + input_shape = data.struct_info.shape + + if len(input_shape) != 4: + return False + + if input_shape[1] <= 0 or input_shape[2] <= 0 or input_shape[3] <= 0: + return False + + if not hasattr(root.attrs, "axis"): + return False + + axis = root.attrs.axis + if not (len(axis) == 2 and axis[0] == 2 and axis[1] == 3): + return False + + return True + + def global_avgpool_pattern(): + + """Create Pool Pattern""" + data = wildcard() + pattern = is_op("relax.mean")(data).has_attr({"axis": [2, 3]}) + + annotations = { + "data": data, + "root": pattern, + } + + return [ + ("openclml.nn.global_avg_pool2d", pattern, annotations, _check_global_avgpool), + ] + + def _check_reshape(context: PatternCheckContext) -> bool: + + root = context.annotated_expr.get("root") + if not root or not isinstance(root, relax.Call): + return False + + if root.op.name != "relax.reshape": + return False + + shape_arg = root.args[1] + if not isinstance(shape_arg, relax.Expr): + return False + + return True + + def reshape_pattern(): + + """Create Reshape Pattern""" + + pattern = is_op("relax.reshape")(wildcard(), wildcard()) + annotations = { + "root": pattern, + } + return [("openclml.reshape", pattern, annotations, _check_reshape)] + + def _check_batchnorm(context: PatternCheckContext) -> bool: + root = context.annotated_expr.get("root") + if not root or not isinstance(root, relax.Call): + return False + + if root.op.name != "relax.reshape": + return False + + required_params = ["moving_var", "gamma", "moving_mean", "beta"] + for param in required_params: + if param not in context.annotated_expr: + return False + + params = { + "moving_var": context.annotated_expr["moving_var"], + "gamma": context.annotated_expr["gamma"], + "moving_mean": context.annotated_expr["moving_mean"], + "beta": context.annotated_expr["beta"], + } + + for param in params.values(): + if not isinstance(param, relax.expr.Constant): + return False + + base_shape = None + for param in params.values(): + shape = param.struct_info.shape + dtype = param.struct_info.dtype + + if dtype not in {"float32"}: + return False + + # Initialize base_shape if not set + if base_shape is None: + base_shape = shape + continue + + # All parameters should have same shape + if len(shape) != len(base_shape): + return False + if any(s1 != s2 for s1, s2 in zip(shape, base_shape)): + return False + + return True + + def batch_norm_pattern(): + """Create a batch norm pattern.""" + data = wildcard() + bn_scale = is_const() + bn_bias = is_const() + bn_mean = is_const() + bn_var = is_const() + + pattern = is_op("relax.nn.batch_norm")(data, bn_scale, bn_bias, bn_mean, bn_var) + pattern = is_tuple_get_item(pattern, 0) + pattern = is_op("relax.reshape")(pattern, wildcard()) + + annotations = { + "gamma": bn_scale, + "beta": bn_bias, + "moving_mean": bn_mean, + "moving_var": bn_var, + "root": pattern, + } + + return [ + ("openclml.nn.batch_norm", pattern, annotations, _check_batchnorm), + ] + + def _check_binary_op(context: PatternCheckContext) -> bool: + def _check_arg(input_expr): + input_dtype = input_expr.struct_info.dtype + input_shape = 
input_expr.struct_info.shape + if len(input_shape) == 0: + return False + + # Avoid any operators with dtype Int64 + if input_dtype == "int64": + return False + + # No support for batch> 1 + if input_shape[0] > 1: + return False + + return True + + def compare_shapes(lhs_shape, rhs_shape): + if len(lhs_shape) != len(rhs_shape): + return False + for lhs_dim, rhs_dim in zip(lhs_shape, rhs_shape): + if lhs_dim != rhs_dim: + return False + return True + + lhs_shape = None + rhs_shape = None + if "lhs" in context.annotated_expr: + lhs = context.annotated_expr["lhs"] + lhs_shape = lhs.struct_info.shape + if not _check_arg(lhs): + return False + + if "rhs" in context.annotated_expr: + rhs = context.annotated_expr["rhs"] + rhs_shape = rhs.struct_info.shape + if not _check_arg(rhs): + return False + + # Checking for BinaryOps ( False for unaryOp ) + if ( + "lhs" in context.annotated_expr + and "rhs" in context.annotated_expr + and not compare_shapes(lhs_shape, rhs_shape) + ): + + return False + + return True + + def binary_op_pattern(): + """Create a binary op pattern.""" + + def make_pattern(op): + lhs = wildcard() + rhs = wildcard() + pattern = is_op(op)(lhs, rhs) + annotations = {"lhs": lhs, "rhs": rhs} + return ("openclml." + op, pattern, annotations, _check_binary_op) + + binary_ops = [ + "relax.add", + "relax.subtract", + "relax.multiply", + "relax.divide", + "relax.maximum", + "relax.minimum", + ] + + return [make_pattern(op) for op in binary_ops] + + def unary_op_pattern(): + """Create a unary op pattern.""" + + def make_pattern(op): + lhs = wildcard() + pattern = is_op(op)(lhs) + annotations = {"lhs": lhs} + return ("openclml." + op, pattern, annotations, _check_binary_op) + + unary_ops = [ + "relax.nn.softmax", + "relax.nn.relu", + "relax.clip", + ] + + return [make_pattern(op) for op in unary_ops] + + return [ + *conv_pattern(), + *batch_norm_pattern(), + *binary_op_pattern(), + *unary_op_pattern(), + *maxpool_pattern(), + *avgpool_pattern(), + *global_avgpool_pattern(), + *reshape_pattern(), + ] + + +clml_patterns = clml_pattern_table() +register_patterns(clml_patterns) + + +@module_pass(opt_level=0, name="OpenCLMLOffLoad") +class OpenCLMLOffLoad: + """The pass sequence used for CLML offload""" + + def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule: + """The transform""" + + clml_layouts = { + "relax.nn.conv2d": ["NCHW", "OIHW"], + "relax.nn.conv2d_transpose": ["NCHW", "OIHW"], + } + seq = tvm.transform.Sequential( + [ + transform.ConvertLayout(clml_layouts), + transform.Normalize(), + transform.FoldBatchnormToConv2D(), + AppendReshapeToBNRewriterPass(), + transform.FoldConstant(), + transform.FuseOpsByPattern(clml_pattern_table()), + transform.MergeCompositeFunctions(), + transform.RunCodegen(), + ], + ) + mod = seq(mod) + return mod diff --git a/python/tvm/relax/backend/adreno/pipeline.py b/python/tvm/relax/backend/adreno/pipeline.py new file mode 100644 index 000000000000..612b8ce7011d --- /dev/null +++ b/python/tvm/relax/backend/adreno/pipeline.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax Adreno GPU backend compilation pipeline and other passes.""" +import tvm +from tvm import dlight as dl +from tvm import relax + + +def library_dispatch_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default library dispatch passes for Adreno GPU backend.""" + if "clml" in target.keys: + return [relax.backend.adreno.clml.OpenCLMLOffLoad()] + else: + return [] + + +def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default legalization passes for Adreno GPU backend.""" + return [ + relax.transform.DecomposeOpsForInference(), + relax.transform.FoldConstant(), + relax.transform.LegalizeOps(), + relax.transform.AnnotateTIROpPattern(), + relax.transform.FoldConstant(), + relax.transform.FuseOps(), + relax.transform.FuseTIR(), + relax.transform.DeadCodeElimination(), + dl.ApplyDefaultSchedule( + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + ] + + +def dataflow_lower_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default dataflow lowering passes for Adreno GPU backend.""" + return relax.backend.gpu_generic.dataflow_lower_passes(target) + + +def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default finalization passes for Adreno GPU backend.""" + return relax.backend.gpu_generic.finalize_passes(target) + + +def get_default_pipeline(target: tvm.target.Target): + """Return the default compilation pipeline for Adreno GPU.""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext): + with target: + seq = tvm.transform.Sequential( + library_dispatch_passes(target) + + legalize_passes(target) + + dataflow_lower_passes(target) + + finalize_passes(target) + ) + mod = seq(mod) + return mod + + return _pipeline diff --git a/python/tvm/relax/backend/gpu_generic/__init__.py b/python/tvm/relax/backend/gpu_generic/__init__.py index 9c5e65fb49b6..ea2d2a2afb5a 100644 --- a/python/tvm/relax/backend/gpu_generic/__init__.py +++ b/python/tvm/relax/backend/gpu_generic/__init__.py @@ -19,5 +19,6 @@ finalize_passes, get_default_pipeline, legalize_passes, + dataflow_lower_passes, library_dispatch_passes, ) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index ebb61ad3e609..ffb38cdd9370 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -250,6 +250,8 @@ def library_dispatch_passes(target: tvm.target.Target): return backend.gpu_generic.library_dispatch_passes(target) if target.kind.name == "llvm": return backend.cpu_generic.library_dispatch_passes(target) + if target.kind.name == "opencl" and "adreno" in target.keys: + return backend.adreno.library_dispatch_passes(target) # Todo(tvm-team): support gpu-generic raise ValueError(f"Target {target} is not yet supported by library dispatch passes.") @@ -264,6 +266,8 @@ def legalize_passes(target: tvm.target.Target): return backend.gpu_generic.legalize_passes(target) if target.kind.name == "llvm": return backend.cpu_generic.legalize_passes(target) + if target.kind.name == 
"opencl" and "adreno" in target.keys: + return backend.adreno.legalize_passes(target) # Todo(tvm-team): support gpu-generic raise ValueError(f"Target {target} is not yet supported by library dispatch passes.") @@ -278,6 +282,8 @@ def dataflow_lower_passes(target: tvm.target.Target): return backend.gpu_generic.dataflow_lower_passes(target) if target.kind.name == "llvm": return backend.cpu_generic.dataflow_lower_passes(target) + if target.kind.name == "opencl" and "adreno" in target.keys: + return backend.adreno.dataflow_lower_passes(target) # Todo(tvm-team): support gpu-generic raise ValueError(f"Target {target} is not yet supported by dataflow lowering passes.") @@ -292,6 +298,8 @@ def finalize_passes(target: tvm.target.Target): return backend.gpu_generic.finalize_passes(target) if target.kind.name == "llvm": return backend.cpu_generic.finalize_passes(target) + if target.kind.name == "opencl" and "adreno" in target.keys: + return backend.adreno.finalize_passes(target) # Todo(tvm-team): support gpu-generic raise ValueError(f"Target {target} is not yet supported by finalization passes.") @@ -306,6 +314,8 @@ def get_default_pipeline(target: tvm.target.Target): return backend.gpu_generic.get_default_pipeline(target) if target.kind.name == "llvm": return backend.cpu_generic.get_default_pipeline(target) + if target.kind.name == "opencl" and "adreno" in target.keys: + return backend.adreno.get_default_pipeline(target) # Todo(tvm-team): support gpu-generic raise ValueError( f"Target {target} is not yet supported by default pipeline. " diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 16e4800ca33d..ffdf31975a70 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -94,6 +94,7 @@ from .lazy_transform_params import LazyTransformParams from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage from .optimize_layout_transform import OptimizeLayoutTransform +from .fold_batch_norm_to_conv2d_for_inference import FoldBatchnormToConv2D from .remove_redundant_reshape import RemoveRedundantReshape # Import to register the legalization functions. diff --git a/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py b/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py new file mode 100644 index 000000000000..9680b540cffa --- /dev/null +++ b/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local +"""Relax Fold Batchnorm into Conv2D.""" +from tvm.ir.module import IRModule +from tvm.ir.transform import PassContext +from tvm.relax import Expr +from tvm.relax.dpl import is_op, rewrite_call, wildcard, is_const, TupleGetItemPattern +from tvm import relax, tir + +from . import function_pass + + +@function_pass(opt_level=0) +class FoldBatchnormToConv2D: + """ + Fuse Batchnorm to its previous Conv2D + This optimization is a special case of FoldScaleAxis that folds scale into conv2d weights. + This pass can be removed when FoldScaleAcis enhances to support this case. + """ + + def __init__(self): + self.input = wildcard() + self.weight = is_const() + self.pattern_conv2d = is_op("relax.nn.conv2d")(self.input, self.weight) + self.bn_weight = is_const() + self.bias = is_const() + self.mean = is_const() + self.variance = is_const() + self.pattern_bn = is_op("relax.nn.batch_norm")( + self.pattern_conv2d, self.bn_weight, self.bias, self.mean, self.variance + ) + + self.pattern = TupleGetItemPattern(self.pattern_bn, 0) + + def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule: + """ + Tranformation function for pattern Conv2D+BatchNorm+TupleGetItem pattern + Parameters + ---------- + func: Expr + The relax function to be optimized + mod: IRModule + The ir module + ctx: PassContext + Relax pass context + """ + + self.mod = mod + updated_call = func + + # Skip primitive functions + if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0: + return updated_call + + def rewriter(expr, matches): + conv_input = matches[self.input] + conv_weight = matches[self.weight] + bn_weight = matches[self.bn_weight] + bn_bias = matches[self.bias] + bn_mean = matches[self.mean] + bn_variance = matches[self.variance] + conv_op = matches[self.pattern_conv2d] + bn_op = matches[self.pattern_bn] + conv_attrs = conv_op.attrs + bn_attrs = bn_op.attrs + + bn_variance = relax.op.add( + bn_variance, relax.PrimValue(tir.FloatImm("float32", bn_attrs["epsilon"])) + ) + dino = relax.op.sqrt(bn_variance) + wt = relax.op.divide(bn_weight, dino) + bs = relax.op.subtract(bn_bias, relax.op.multiply(bn_mean, wt)) + if conv_attrs["kernel_layout"] == "OIHW": + wt = relax.op.reshape(wt, shape=(bn_weight.struct_info.shape[0], 1, 1, 1)) + elif conv_attrs["kernel_layout"] == "IOHW": + wt = wt.reshape(1, bn_weight.struct_info.shape[0], 1, 1) + else: + return expr + wt_conv = relax.op.multiply(conv_weight, wt) + bs_args = relax.op.reshape(bs, shape=(1, bn_bias.struct_info.shape[0], 1, 1)) + + conv_out = relax.Call(conv_op.op, (conv_input, wt_conv), conv_attrs) + return relax.op.add(conv_out, bs_args) + + updated_call = rewrite_call(self.pattern, rewriter, func) + + return updated_call diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 1bb883e840cc..9288eb3f97bf 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -70,6 +70,7 @@ riscv_cpu, hexagon, stm32, + adreno, ) from .virtual_device import VirtualDevice from .compilation_config import make_compilation_config diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 81baa57f9eec..d78561eadfd4 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -830,7 +830,7 @@ def stm32(series="unknown", options=None): return Target(" ".join(["c"] + opts)) -def adreno(model="unknown", options=None): +def adreno(model="unknown", options=None, clml=False): """Returns a Qualcomm GPU 
target. Parameters ---------- @@ -839,7 +839,10 @@ def adreno(model="unknown", options=None): options : str or list of str Additional options """ - opts = ["-device=adreno", "-model=%s" % model] + if clml: + opts = ["-device=adreno", "--keys=adreno,opencl,gpu,clml", "-model=%s" % model] + else: + opts = ["-device=adreno", "--keys=adreno,opencl,gpu", "-model=%s" % model] opts = _merge_opts(opts, options) return Target(" ".join(["opencl"] + opts)) diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc new file mode 100644 index 000000000000..8480ca379a38 --- /dev/null +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/backend/contrib/clml/codegen.cc + * \brief Implementation of the OpenCLML JSON serializer. + */ +#include +#include +#include + +#include +#include +#include + +#include "../../../transform/utils.h" +#include "../codegen_json/codegen_json.h" +#include "../utils.h" + +namespace tvm { +namespace relax { +namespace contrib { + +/*! \brief Attributes to store the compiler options for OpenCLML. */ +struct OpenCLMLCompilerConfigNode : public tvm::AttrsNode { + Integer clml_version; + + TVM_DECLARE_ATTRS(OpenCLMLCompilerConfigNode, "relax.ext.attrs.OpenCLMLCompilerConfigNode") { + TVM_ATTR_FIELD(clml_version) + .describe("OpenCLML version as (major, minor, patch).") + .set_default(Integer(3)); + } +}; + +class OpenCLMLCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(OpenCLMLCompilerConfig, Attrs, + OpenCLMLCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(OpenCLMLCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relax.ext.clml.options", OpenCLMLCompilerConfig); + +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; +using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr; +using OpAttrExtractor = backend::contrib::OpAttrExtractor; +using JSONSerializer = backend::contrib::JSONSerializer; + +class OpenCLMLJSONSerializer; + +/*! + * \brief Collect the constants and attributes from all operator calls in the body + * of a "Composite" function. 
+ */ +class CollectCLMLFromCompositeFunctionBody : public ExprVisitor { + public: + explicit CollectCLMLFromCompositeFunctionBody(OpenCLMLJSONSerializer* serializer) + : serializer_(serializer), node_(std::make_shared()) {} + + void VisitExpr_(const ConstantNode* constant_node) final; + void VisitExpr_(const CallNode* call_node) final; + + void SetGenericAttributes(const CallNode* call_node) { + if (backend::IsOp(call_node, "relax.nn.relu")) { + std::vector activation_type = {"relu"}; + std::vector act_attr; + act_attr.emplace_back(activation_type); + node_->SetAttr("activation_type", act_attr); + } + + OpAttrExtractor extractor(node_); + const Object* attr_obj = call_node->attrs.get(); + extractor.Extract(const_cast(attr_obj)); + } + + OpenCLMLJSONSerializer* serializer_; + /*! \brief Accumulated translated arguments. */ + std::vector args_; + /*! + * \brief Temporary node into which we'll accumulate attributes. Ideally this would be the + * final JSONGraphNode however we don't yet know how many inputs that will have. + */ + JSONGraphObjectPtr node_; +}; + +/*! + * \brief Generates an OpenCLMLModule from a relax expression by serializing the expression to a + * json representation. OpenCLML is not required here because use of OpenCLML APIs is deferred until + * runtime. + */ +class OpenCLMLJSONSerializer : public JSONSerializer { + public: + explicit OpenCLMLJSONSerializer(Map constant_names, Map bindings) + : JSONSerializer(constant_names), bindings_(bindings) {} + + /*! + * \brief A series of operators that form a composite + * convolution. Supports nn.conv2d + */ + struct CompositeConvNode { + const CallNode* pad = nullptr; + const CallNode* conv = nullptr; + const CallNode* bn = nullptr; + const CallNode* bias = nullptr; + const CallNode* activation = nullptr; + std::string act_type; + }; + + using JSONSerializer::VisitExpr_; + + std::vector VisitExpr_(const CallNode* call_node) final { + // The call must be to an inline "Composite" function + const auto* fn_var = call_node->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + + auto opt_composite = fn->GetAttr(attr::kComposite); + ICHECK(opt_composite.defined()); + std::string name = opt_composite.value(); + + std::shared_ptr node; + + if (backend::EndsWithPattern(name, "nn.conv2d") || + backend::EndsWithPattern(name, "nn.pad_conv2d") || + backend::EndsWithPattern(name, "nn.pad_conv2d_transpose")) { + node = CreateCompositeConvJSONNode(call_node); + } else { + // Collect the constants and attributes of all operator calls inside the composite body. + CollectCLMLFromCompositeFunctionBody collector(this); + collector.VisitExpr(fn->body); + + // Capture the args to the "Composite" function as inputs for this node. + std::vector inputs; + for (const auto& arg : call_node->args) { + auto res = VisitExpr(arg); + inputs.insert(inputs.end(), res.begin(), res.end()); + } + // Capture constants from the composite function body as additional inputs for this node. + for (const auto& node : collector.args_) { + inputs.emplace_back(node); + } + + // Create the final node. + node = std::make_shared(name, + /*op_type=*/"kernel", inputs, + /*num_output=*/1); + + // Transfer attributes from the collector's node to the final node. + node->CaptureAttrs(*collector.node_); + + // Capture global settings on the JSON node. + SaveGlobalAttributes(node); + + VLOG(1) << name << " has " << node->GetInputs().size() << " inputs"; + } + + return AddNode(node, GetRef(call_node)); + } + + /*! 
+ * \brief Extract convolution nodes from a composite function. + * + * \param cn The call node of the composite function. + * \return Extracted composite convolution nodes. + */ + CompositeConvNode UnpackCompositeConvolution(const CallNode* cn) { + CompositeConvNode nodes{}; + + const auto* fn_var = cn->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + auto opt_composite = fn->GetAttr(attr::kComposite); + ICHECK(opt_composite.defined()); + + nodes.pad = backend::TryGetOpInFunction(fn, "relax.nn.pad"); + nodes.conv = backend::TryGetOpInFunction(fn, "relax.nn.conv2d"); + + if (!nodes.conv) { + nodes.conv = backend::TryGetOpInFunction(fn, "relax.nn.conv2d_transpose"); + } + ICHECK(nodes.conv) << "No Convolution op found in composite function"; + nodes.bn = backend::TryGetOpInFunction(fn, "relax.nn.batch_norm"); + nodes.bias = backend::TryGetOpInFunction(fn, "relax.add"); + nodes.activation = backend::TryGetOpInFunction(fn, "relax.nn.relu"); + nodes.act_type = "relu"; + return nodes; + } + + /*! + * \brief Create a JSON representation of a composite convolution. + * + * \param cn The call to be represented. + * \return A JSON representation of a specific operator. + */ + std::shared_ptr CreateCompositeConvJSONNode(const CallNode* cn) { + CompositeConvNode nodes = UnpackCompositeConvolution(cn); + + const auto* fn_var = cn->op.as(); + ICHECK(fn_var); + const auto fn = Downcast(bindings_[GetRef(fn_var)]); + auto opt_composite = fn->GetAttr(attr::kComposite); + ICHECK(opt_composite.defined()); + std::string name = opt_composite.value(); + + std::vector inputs; + + inputs.push_back(VisitExpr(cn->args[0])[0]); + inputs.push_back(VisitExpr(nodes.conv->args[1])[0]); + if (nodes.bias) { + inputs.push_back(VisitExpr(nodes.bias->args[1])[0]); + } + // Deal with Batchnorm Fusing here + if (nodes.bn) { + inputs.push_back(VisitExpr(nodes.bn->args[1])[0]); + inputs.push_back(VisitExpr(nodes.bn->args[2])[0]); + inputs.push_back(VisitExpr(nodes.bn->args[3])[0]); + inputs.push_back(VisitExpr(nodes.bn->args[4])[0]); + } + + auto json_node = std::make_shared(name, "kernel", inputs, 1); + SetCallNodeAttribute(json_node, nodes.conv); + + if (nodes.bn) { + const auto* bn_attr = nodes.bn->attrs.as(); + std::vector bn_any_attr; + std::vector bn_args = { + std::to_string(bn_attr->axis), std::to_string(bn_attr->epsilon), + std::to_string(bn_attr->center), std::to_string(bn_attr->scale)}; + bn_any_attr.emplace_back(bn_args); + json_node->SetAttr("batchnorm", bn_any_attr); + } + + // Override attributes + if (nodes.pad) { + const auto* pad_attr = nodes.pad->attrs.as(); + ICHECK(pad_attr); + auto p = pad_attr->pad_width; + // Pad layout for TVM: dimension wise pre and post padding. + // CLML takes dimension wise pre-padding followed by dimension wise post-padding for W, H. 
+ std::vector padding = {std::to_string(p[4].as()->value), + std::to_string(p[6].as()->value), + std::to_string(p[5].as()->value), + std::to_string(p[7].as()->value)}; + std::vector padding_attr; + padding_attr.emplace_back(padding); + json_node->SetAttr("padding", padding_attr); + } + + if (nodes.activation) { + std::vector activation_type = {nodes.act_type}; + std::vector act_attr; + act_attr.emplace_back(activation_type); + json_node->SetAttr("activation_type", act_attr); + } + return json_node; + } + + static void SaveGlobalAttributes(std::shared_ptr node) { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relax.ext.clml.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + std::vector clml_version = {std::to_string(cfg.value()->clml_version.IntValue())}; + std::vector clml_version_attr; + clml_version_attr.emplace_back(clml_version); + node->SetAttr("clml_version", clml_version_attr); + } + + private: + /*! \brief The bindings to look up composite functions. */ + Map bindings_; +}; + +void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const ConstantNode* constant_node) { + for (const auto& entry : serializer_->VisitExpr(GetRef(constant_node))) { + args_.emplace_back(entry); + } +} + +void CollectCLMLFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) { + SetGenericAttributes(call_node); + ExprVisitor::VisitExpr_(call_node); +} + +/*! + * \brief Create runtime modules for OpenCLML. + * \param functions The extern functions to be compiled via OpenCLML + * \return Runtime modules. + */ +Array OpenCLMLCompiler(Array functions, + Map /*unused*/, + Map constant_names) { + Array compiled_functions; + for (const auto& func : functions) { + VLOG(1) << "OpenCLML partition:" << std::endl << func; + OpenCLMLJSONSerializer serializer(constant_names, AnalyzeVar2Value(func)); + serializer.serialize(func); + std::string graph_json = serializer.GetJSON(); + auto constant_names = serializer.GetConstantNames(); + const auto* pf = runtime::Registry::Get("runtime.clml_runtime_create"); + ICHECK(pf != nullptr) << "Cannot find OpenCLML runtime module create function."; + std::string func_name = GetExtSymbol(func); + VLOG(1) << "Creating clml runtime::Module for '" << func_name << "'"; + compiled_functions.push_back((*pf)(func_name, graph_json, constant_names)); + } + return compiled_functions; +} + +TVM_REGISTER_GLOBAL("relax.ext.openclml").set_body_typed(OpenCLMLCompiler); + +/*! + * \brief Check whether OpenCLML graph executor is enabled. + * \return True if enabled, False if not. + */ +inline constexpr bool IsOpenCLMLRuntimeEnabled() { +#if TVM_GRAPH_EXECUTOR_CLML + return true; +#else + return false; +#endif // TVM_GRAPH_EXECUTOR_CLML +} + +/*! + * \brief Get OpenCLML version that TVM is built against. + * \return The OpenCLML SDK version. 
+ */ +Integer GetOpenCLMLVersion() { +#if TVM_GRAPH_EXECUTOR_CLML + return Integer(TVM_CLML_VERSION); +#else + return Integer(3); +#endif // TVM_GRAPH_EXECUTOR_CLML +} + +TVM_REGISTER_GLOBAL("relax.is_openclml_runtime_enabled").set_body_typed(IsOpenCLMLRuntimeEnabled); +TVM_REGISTER_GLOBAL("relax.get_openclml_version").set_body_typed(GetOpenCLMLVersion); + +} // namespace contrib +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index b260ea24bed3..8e214809dd51 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -64,6 +64,17 @@ Map ExtractArgIdx(String pattern_name, Function f) { return arg_idx; } +/*! + * \brief Utility function to find the string pattern in string str + * \param str the main string to check the pattern + * \param pattern the pattern to check in the main string + * \return return true if the main string ends with pattern, false otherwise + */ +bool EndsWithPattern(const std::string& str, const std::string& pattern) { + if (str.length() < pattern.length()) return false; + return str.compare(str.length() - pattern.length(), pattern.length(), pattern) == 0; +} + TVM_REGISTER_GLOBAL("relax.contrib.extract_arg_idx").set_body_typed(ExtractArgIdx); } // namespace backend diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index e0195a61950f..aa3928ce026a 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -111,19 +111,31 @@ inline bool IsOp(const CallNode* call, const std::string& op_name) { * The function must contain exactly one call to such op. * \param f The function to look for an op. * \param op_name The name of the op - * \return A call node which calls an op with the given name + * \return A call node which calls an op with the given name or nullptr if not */ -inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { +inline const CallNode* TryGetOpInFunction(Function f, const std::string& op_name) { auto local_bindings = AnalyzeVar2Value(f); for (const auto& entry : local_bindings) { if (auto call = entry.second.as(); call && backend::IsOp(call, op_name)) { return call; } } - LOG(FATAL) << op_name << " not found in the function:\n" << f; return nullptr; } +/*! + * \brief Return a call node within the function which calls an op with the given name + * The function must contain exactly one call to such op. + * \param f The function to look for an op. + * \param op_name The name of the op + * \return A call node which calls an op with the given name + */ +inline const CallNode* GetOpInFunction(Function f, const std::string& op_name) { + const CallNode* op = TryGetOpInFunction(f, op_name); + ICHECK(op) << op_name << " not found in the function:\n" << f; + return op; +} + /*! * \brief Extract indices of the argument patterns in the function parameter list. * Each composite function pattern can register a mapping between variable names and the @@ -149,6 +161,14 @@ std::string to_str(const Type& value) { return os.str(); } +/*! 
+ * \brief Utility function to find the string pattern in string str + * \param str the main string to check the pattern + * \param pattern the pattern to check in the main string + * \return return true if the main string ends with pattern, false otherwise + */ +bool EndsWithPattern(const std::string& str, const std::string& pattern); + } // namespace backend } // namespace relax } // namespace tvm diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index fa7338177cbe..4998f2b476c5 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -157,7 +157,7 @@ class CLMLRuntime : public JSONRuntimeBase { CLML_CALL(clReleaseMLTuningCacheQCOM, this->layer_.tuning_cache); } for (auto it = this->layer_.storage_map.begin(); it != this->layer_.storage_map.end(); it++) { - auto tensor_desc = it->second.first; + auto tensor_desc = it->second.tensor_desc; CLML_CALL(clReleaseMLTensorQCOM, tensor_desc->tensor) if (this->layer_.ddr_storage_ref_map.find(tensor_desc->memory) != this->layer_.ddr_storage_ref_map.end()) { @@ -278,8 +278,8 @@ class CLMLRuntime : public JSONRuntimeBase { int op_index = 0; for (auto it = this->layer_.storage_map.begin(); it != this->layer_.storage_map.end(); it++) { int nid = it->first; - auto clml_desc = it->second.first; - auto node = it->second.second; + auto clml_desc = it->second.tensor_desc; + auto node = it->second.node; if ("kernel" == node.GetOpType()) { CLML_CALL(clEnqueueMLOpQCOM, queue, this->layer_.function[op_index], @@ -431,6 +431,7 @@ class CLMLRuntime : public JSONRuntimeBase { * \return Status of inference. */ void Run() override { + LOG_CLML << "Run Start"; cl_command_queue queue = CLML_QUEUE; std::vector& evts = cws->workspace->GetEventQueue(cws->tentry->device); for (size_t i = 0; i < input_nodes_.size(); ++i) { @@ -453,9 +454,11 @@ class CLMLRuntime : public JSONRuntimeBase { evts.resize(evts.size() + 1); evt = &(evts.back()); } + LOG_CLML << "Enqueue CLML Copy"; CLML_CALL(clEnqueueCopyMLTensorDataQCOM, queue, layer_.in_placeholder[nid]->tensor, layer_.in_placeholder[nid]->memory, layer_.inputs[nid]->tensor, layer_.inputs[nid]->memory, 0, nullptr, evt); + LOG_CLML << "Enqueue CLML Copy Completed"; } else { DLDataType tvm_dtype = const_cast(data_entry_[eid])->dtype; cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); @@ -468,9 +471,11 @@ class CLMLRuntime : public JSONRuntimeBase { } } } + LOG_CLML << "Inputs Set"; int64_t duration = 0; if (cws->is_recordable_queue) { + LOG_CLML << "Execution by Rec Queue"; if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = Registry::Get(std::string("profiling.timer.opencl")); @@ -488,8 +493,10 @@ class CLMLRuntime : public JSONRuntimeBase { 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, nullptr); } } else { + LOG_CLML << "Execution by Normal Queue"; for (size_t i = 0; i < this->layer_.function.size(); ++i) { // Make CLML subgraphs accounted by OpenCLTimerNode. 
+ LOG_CLML << "Run Layer:" << this->layer_.layer_names[i]; if (cws->workspace->IsProfiling(cws->tentry->device)) { Timer t; auto f = Registry::Get(std::string("profiling.timer.opencl")); @@ -514,6 +521,7 @@ class CLMLRuntime : public JSONRuntimeBase { LOG_CLML << "Total Duration for " << clml_symbol << " is:" << duration; } + LOG_CLML << "Run Completed"; for (size_t i = 0; i < outputs_.size(); ++i) { uint32_t eid = EntryID(outputs_[i]); void* data = data_entry_[eid]->data; @@ -548,6 +556,7 @@ class CLMLRuntime : public JSONRuntimeBase { free(tmpptr); } } + LOG_CLML << "Run End"; } private: @@ -611,7 +620,12 @@ class CLMLRuntime : public JSONRuntimeBase { for (size_t nid = 0; nid < nodes_.size(); ++nid) { const auto& node = nodes_[nid]; uint32_t size = 0; - CLML_CALL(clGetMLTensorMemorySizeQCOM, CLML_CTX, layer_.storage_map[nid].first->tensor, + if (this->layer_.storage_map.find(nid) == this->layer_.storage_map.end()) { + // Possible that some nodes are not consumed by any operation + // Example being nn.pad second argument. + continue; + } + CLML_CALL(clGetMLTensorMemorySizeQCOM, CLML_CTX, layer_.storage_map[nid].tensor_desc->tensor, &size); if ((node.GetOpType() == "kernel") || (node.GetOpType() == "input")) { @@ -686,34 +700,57 @@ class CLMLRuntime : public JSONRuntimeBase { const JSONGraphNode node = nodes_[nid]; cl_ml_tensor_usage_qcom usage = CL_TENSOR_USAGE_CNN_QCOM; - if (this->layer_.storage_map.find(nid) == this->layer_.storage_map.end()) { - void* node_data = nullptr; - if (node.GetOpType() == "const") { - uint32_t eid = EntryID(nid, 0); - node_data = data_entry_[eid]->data; - usage = CL_TENSOR_USAGE_PARAMETER_QCOM; + if (this->layer_.storage_map.find(nid) != this->layer_.storage_map.end()) { + if (nullptr != layer_.storage_map[nid].tensor_desc) { + return this->layer_.storage_map[nid].tensor_desc; } + } else { + this->layer_.storage_map.insert({nid, NodeDescriptor()}); + this->layer_.storage_map[nid].node = node; + } - auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, usage, dtype, node_data, shape); - this->layer_.storage_map.insert({nid, std::make_pair(clml_tensor, node)}); + void* node_data = nullptr; + if (node.GetOpType() == "const") { + uint32_t eid = EntryID(nid, 0); + node_data = data_entry_[eid]->data; + usage = CL_TENSOR_USAGE_PARAMETER_QCOM; + ICHECK(CL_TENSOR_USAGE_INVALID_QCOM == this->layer_.storage_map[nid].usage) + << "Parameter have usage reservation !!!"; + } + if (CL_TENSOR_USAGE_INVALID_QCOM != this->layer_.storage_map[nid].usage) { + // Respect special reservation on usage. + usage = this->layer_.storage_map[nid].usage; + } else { + this->layer_.storage_map[nid].usage = usage; + } + if (this->layer_.storage_map[nid].custom_layout) { + // Respect special reservation on layout. 
+ layout = this->layer_.storage_map[nid].layout; + } else { + this->layer_.storage_map[nid].layout = layout; + } - if ("input" == node.GetOpType()) { - this->layer_.inputs.insert({nid, this->layer_.storage_map[nid].first}); - // Input copy placeholder Tensor - if (layout == CL_TENSOR_LAYOUT_OPTIMAL_QCOM) { - this->layer_.in_placeholder.insert( - {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, usage, dtype, - node_data, shape)}); - } else { - this->layer_.in_placeholder.insert( - {nid, MakeCLMLTensorFromJSONNode(node, layout, usage, dtype, node_data, shape)}); - } + auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, usage, dtype, node_data, shape); + + this->layer_.storage_map[nid].tensor_desc = clml_tensor; + this->layer_.storage_map[nid].usage = usage; + this->layer_.storage_map[nid].layout = layout; + LOG_CLML << "Storage Map Alloc:" << nid << " Name:" << node.GetOpName() << " Usage: " << usage + << " Layout:" << layout; + + if ("input" == node.GetOpType()) { + this->layer_.inputs.insert({nid, this->layer_.storage_map[nid].tensor_desc}); + // Input copy placeholder Tensor + if (layout == CL_TENSOR_LAYOUT_OPTIMAL_QCOM) { + this->layer_.in_placeholder.insert( + {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, usage, dtype, + node_data, shape)}); + } else { + this->layer_.in_placeholder.insert( + {nid, MakeCLMLTensorFromJSONNode(node, layout, usage, dtype, node_data, shape)}); } - - return clml_tensor; - } else { - return this->layer_.storage_map[nid].first; } + return clml_tensor; } /*! @@ -730,7 +767,8 @@ class CLMLRuntime : public JSONRuntimeBase { const auto& node = nodes_[nid]; if ("nn.dense" == node.GetOpName()) CreateDenseLayerTensor(&layer_, node, nid); if ("nn.batch_matmul" == node.GetOpName()) CreateBatchMatmulLayerTensor(&layer_, node, nid); - if ("nn.softmax" == node.GetOpName()) CreateSoftmaxLayerTensor(&layer_, node, nid); + if ("nn.softmax" == node.GetOpName() || PatternMatch(node.GetOpName(), "nn.softmax")) + CreateSoftmaxLayerTensor(&layer_, node, nid); } for (nid = 0; nid < nodes_.size(); ++nid) { @@ -739,30 +777,33 @@ class CLMLRuntime : public JSONRuntimeBase { // Layers may request for different layout. Differ the input allocation. 
} else if (node.GetOpType() == "kernel") { auto op_name = node.GetOpName(); - if ("nn.conv2d" == op_name) + if (PatternMatch(op_name, "nn.conv2d") || PatternMatch(op_name, "nn.pad_conv2d")) CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_CONVOLUTION_QCOM, nid); - else if ("nn.depthwise_conv2d" == op_name) + else if (PatternMatch(op_name, "nn.depthwise_conv2d")) CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_DEPTHWISE_QCOM, nid); - else if ("nn.conv2d_transpose" == op_name) + else if (PatternMatch(op_name, "nn.conv2d_transpose")) CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_TRANSPOSE_QCOM, nid); - else if ("nn.relu6" == op_name) + else if ("nn.relu6" == op_name || PatternMatch(op_name, "nn.relu6")) CreateReLULayer(&layer_, node, nid, CL_ACTIVATION_RELU6); - else if ("nn.relu" == op_name) + else if (PatternMatch(op_name, "nn.relu")) CreateReLULayer(&layer_, node, nid, CL_ACTIVATION_RELU); - else if ("nn.batch_norm" == op_name) + else if (PatternMatch(op_name, "nn.batch_norm")) CreateBatchNormLayer(&layer_, node, nid); else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name || - "nn.l2_pool2d" == op_name) + "nn.l2_pool2d" == op_name || PatternMatch(op_name, "nn.max_pool2d") || + PatternMatch(op_name, "nn.avg_pool2d")) CreatePoolingLayer(&layer_, node, nid); - else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" == op_name) + else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" == op_name || + PatternMatch(op_name, "nn.global_avg_pool2d") || + PatternMatch(op_name, "nn.global_max_pool2d")) CreateGlobalPoolingLayer(&layer_, node, nid); - else if ("reshape" == op_name) + else if ("reshape" == op_name || PatternMatch(op_name, "reshape")) CreateReshapeLayer(&layer_, node, nid); else if ("concatenate" == op_name) CreateConcatLayer(&layer_, node, nid); else if ("nn.dense" == op_name) CreateDenseLayer(&layer_, node, nid); - else if ("nn.softmax" == op_name) + else if ("nn.softmax" == op_name || PatternMatch(op_name, "nn.softmax")) CreateSoftMaxLayer(&layer_, node, nid); else if ("nn.pad" == op_name) CreatePadLayer(&layer_, node, nid); @@ -771,7 +812,11 @@ class CLMLRuntime : public JSONRuntimeBase { else if ("clip" == op_name) CreateClipLayer(&layer_, node, nid); else if ("add" == op_name || "subtract" == op_name || "multiply" == op_name || - "minimum" == op_name || "maximum" == op_name || "divide" == op_name) + "minimum" == op_name || "maximum" == op_name || "divide" == op_name || + PatternMatch(op_name, "relax.add") || PatternMatch(op_name, "relax.subtract") || + PatternMatch(op_name, "relax.multiply") || + PatternMatch(op_name, "relax.minimum") || PatternMatch(op_name, "relax.maximum") || + PatternMatch(op_name, "relax.divide")) CreateBinaryLayer(&layer_, node, nid); else if ("nn.depth_to_space" == op_name) CreateDepthToSpaceLayer(&layer_, node, nid); @@ -793,7 +838,7 @@ class CLMLRuntime : public JSONRuntimeBase { nid = outputs_[i].id_; DLDataType tvm_dtype = nodes_[nid].GetOpDataType()[0]; cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); - this->layer_.outputs.push_back(this->layer_.storage_map[nid].first); + this->layer_.outputs.push_back(this->layer_.storage_map[nid].tensor_desc); if (this->layer_.out_shapes.find(nid) != this->layer_.out_shapes.end()) { // Handle customized shapes here this->layer_.out_placeholder.push_back(MakeCLMLTensorFromJSONNode( @@ -814,12 +859,12 @@ class CLMLRuntime : public JSONRuntimeBase { size_t alloc_ddr = 0; size_t alloc_ddr_reuse = 0; for (auto it = 
this->layer_.storage_map.begin(); it != this->layer_.storage_map.end(); it++) { - auto tensor_desc = it->second.first; + auto tensor_desc = it->second.tensor_desc; uint32_t mem_size = 0; result = CL_OUT_OF_HOST_MEMORY; CLML_CALL(clGetMLTensorMemorySizeQCOM, CLML_CTX, tensor_desc->tensor, &mem_size); - JSONGraphNode node = it->second.second; + JSONGraphNode node = it->second.node; void* node_data = nullptr; size_t on_chip_mem_offset = -1; if (layer_.on_chip_alloc_plan.find(it->first) != layer_.on_chip_alloc_plan.end()) { @@ -939,6 +984,7 @@ class CLMLRuntime : public JSONRuntimeBase { std::vector strides = node.GetAttr>("strides"); std::vector dilation = node.GetAttr>("dilation"); std::vector clml_padding = GetVectorValues(padding); + DLDataType tvm_dtype = node.GetOpDataType()[0]; cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, cl_dtype); @@ -946,6 +992,7 @@ class CLMLRuntime : public JSONRuntimeBase { clml_padding.resize(4); std::fill(clml_padding.begin(), clml_padding.end(), 0); } + cl_uint clml_padding_b[CL_ML_TENSOR_MAX_SPATIAL_DIMS_QCOM] = {clml_padding[0], clml_padding[1]}; cl_uint clml_padding_a[CL_ML_TENSOR_MAX_SPATIAL_DIMS_QCOM] = {clml_padding[2], clml_padding[3]}; std::vector v_strides = GetVectorValues(strides); @@ -982,8 +1029,8 @@ class CLMLRuntime : public JSONRuntimeBase { size_t num_inputs = inputs.size(); bool has_bias; bool has_bn; - ICHECK(num_inputs >= 2U && num_inputs <= 7U) - << "Batchnorm fused convolution requires bax 7 arguments"; + ICHECK(num_inputs >= 2 && num_inputs <= 7) + << "Batchnorm fused convolution requires max 7 arguments"; has_bias = (num_inputs == 3) || (num_inputs == 7); has_bn = (num_inputs == 6) || (num_inputs == 7); // Input @@ -1032,6 +1079,12 @@ class CLMLRuntime : public JSONRuntimeBase { int bn_index = has_bias ? 
3 : 2; int axis = std::stoi(node.GetAttr>("batchnorm")[0]); auto bn_dims = GetTensorDims(nodes_[inputs[bn_index].id_]); + float epsilon = std::stof(node.GetAttr>("batchnorm")[1]); + + std::vector opProperties; + opProperties.push_back(CL_ML_BATCH_NORM_OP_EPSILON_QCOM); + opProperties.push_back(*reinterpret_cast(&epsilon)); + opProperties.push_back(CL_ML_OP_PROPERTY_LIST_END_QCOM); std::vector bn_shape = {1, 1, 1, 1}; bn_shape[axis] = bn_dims.n; auto bn_mean = std::make_shared(); @@ -1049,15 +1102,15 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, cl_arithmetic_mode}; if (!has_act) { - CLML_CALL(clCreateMLOpFusedConvolutionBatchNormForwardQCOM, CLML_CTX, nullptr, &conv_desc, - &bn_desc, input->tensor, weight->tensor, bias->tensor, output->tensor, + CLML_CALL(clCreateMLOpFusedConvolutionBatchNormForwardQCOM, CLML_CTX, opProperties.data(), + &conv_desc, &bn_desc, input->tensor, weight->tensor, bias->tensor, output->tensor, bn_mean->tensor, bn_var->tensor, bn_scale->tensor, bn_bias->tensor, &op, layer_.tuning_cache); } else { - CLML_CALL(clCreateMLOpFusedConvolutionBatchNormActivationForwardQCOM, CLML_CTX, nullptr, - &conv_desc, &bn_desc, &act_desc, input->tensor, weight->tensor, bias->tensor, - output->tensor, nullptr, bn_mean->tensor, bn_var->tensor, bn_scale->tensor, - bn_bias->tensor, &op, layer_.tuning_cache); + CLML_CALL(clCreateMLOpFusedConvolutionBatchNormActivationForwardQCOM, CLML_CTX, + opProperties.data(), &conv_desc, &bn_desc, &act_desc, input->tensor, + weight->tensor, bias->tensor, output->tensor, nullptr, bn_mean->tensor, + bn_var->tensor, bn_scale->tensor, bn_bias->tensor, &op, layer_.tuning_cache); } layer->function.push_back(op); } @@ -1176,8 +1229,9 @@ class CLMLRuntime : public JSONRuntimeBase { std::vector clml_padding = GetVectorValues(padding); cl_ml_op_pooling_desc_qcom pool_desc = { - node.GetOpName() == "nn.max_pool2d" ? CL_POOLING_MODE_MAX_QCOM - : CL_POOLING_MODE_AVERAGE_EXCLUDE_PADDING_QCOM, + ((node.GetOpName() == "nn.max_pool2d") || PatternMatch(node.GetOpName(), "nn.max_pool2d")) + ? CL_POOLING_MODE_MAX_QCOM + : CL_POOLING_MODE_AVERAGE_EXCLUDE_PADDING_QCOM, 4, // reserved {clml_padding[0], clml_padding[1]}, {clml_padding[2], clml_padding[3]}, @@ -1221,8 +1275,10 @@ class CLMLRuntime : public JSONRuntimeBase { auto output = MakeCLMLTensorFromJSONEntry(nid, {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]); cl_ml_op_pooling_desc_qcom pool_desc = { - node.GetOpName() == "nn.global_max_pool2d" ? CL_POOLING_MODE_MAX_QCOM - : CL_POOLING_MODE_AVERAGE_EXCLUDE_PADDING_QCOM, + ((node.GetOpName() == "nn.global_max_pool2d") || + PatternMatch(node.GetOpName(), "nn.global_max_pool2d")) + ? CL_POOLING_MODE_MAX_QCOM + : CL_POOLING_MODE_AVERAGE_EXCLUDE_PADDING_QCOM, 4, // reserved {0, 0}, {0, 0}, @@ -1252,7 +1308,6 @@ class CLMLRuntime : public JSONRuntimeBase { * \param node The JSON representation of the operator. * \param nid The node index of JSON graph node, which points to this operator. 
*/ - void CreateSoftmaxLayerTensor(CachedLayer* layer, const JSONGraphNode& node, size_t nid) { cl_ml_tensor_layout_qcom layout; DLDataType tvm_dtype = node.GetOpDataType()[0]; @@ -1664,19 +1719,23 @@ class CLMLRuntime : public JSONRuntimeBase { auto output = MakeCLMLTensorFromJSONEntry(nid, {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); std::string op_name = node.GetOpName(); cl_binary_op_qcom binary_op = CL_TENSOR_OP_ADD_QCOM; - if (op_name == "subtract") + if (op_name == "subtract" || PatternMatch(op_name, "relax.subtract")) binary_op = CL_TENSOR_OP_SUB_QCOM; - else if (op_name == "multiply") + else if (op_name == "multiply" || PatternMatch(op_name, "relax.multiply")) binary_op = CL_TENSOR_OP_MUL_QCOM; - else if (op_name == "divide") + else if (op_name == "divide" || PatternMatch(op_name, "relax.divide")) binary_op = CL_TENSOR_OP_DIV_QCOM; - else if (op_name == "minimum") + else if (op_name == "minimum" || PatternMatch(op_name, "relax.minimum")) binary_op = CL_TENSOR_OP_MIN_QCOM; - else if (op_name == "maximum") + else if (op_name == "maximum" || PatternMatch(op_name, "relax.maximum")) binary_op = CL_TENSOR_OP_MAX_QCOM; + else if (op_name == "add" || PatternMatch(op_name, "relax.add")) + binary_op = CL_TENSOR_OP_ADD_QCOM; + else + LOG(FATAL) << "Undefined binary op:" << op_name; cl_ml_op_binary_desc_qcom add_desc = { binary_op, {{1.0}, CL_FLOAT}, {{1.0}, CL_FLOAT}, {{0.0}, CL_FLOAT}, cl_arithmetic_mode}; - + LOG(INFO) << "Op name - " << op_name; CLML_CALL(clCreateMLOpBinaryQCOM, CLML_CTX, nullptr, &add_desc, input_a->tensor, input_b->tensor, output->tensor, &op, layer_.tuning_cache); ICHECK(op) << op_name << " Node Error"; diff --git a/src/runtime/contrib/clml/clml_runtime.h b/src/runtime/contrib/clml/clml_runtime.h index 9dfde2f7820d..faada2ddeeb5 100644 --- a/src/runtime/contrib/clml/clml_runtime.h +++ b/src/runtime/contrib/clml/clml_runtime.h @@ -230,6 +230,18 @@ class CLMLThreadEntry { static CLMLThreadEntry* ThreadLocal(); }; +/*! + * \brief Node descriptor to hold various information related to a Node. + */ +struct NodeDescriptor { + std::shared_ptr tensor_desc = nullptr; + JSONGraphNode node; + // Check the flag and them pick the layout. + bool custom_layout = false; + cl_ml_tensor_layout_qcom layout; + cl_ml_tensor_usage_qcom usage = CL_TENSOR_USAGE_INVALID_QCOM; +}; + /*! * \brief CLML objects we cache in order to avoid needing to construct * a new layer each time. @@ -249,9 +261,8 @@ struct CachedLayer { std::vector> out_placeholder; /* Tensor shape exception list while returning from CLML Subgraph */ std::map> out_shapes; - /* Map of all tensors which need backing memory allocation */ - std::map, JSONGraphNode>> - storage_map; + /* Map of nodeid and descriptors */ + std::map storage_map; /* Tensor memory descriptors list to set after backing memory allocation */ std::vector tensorMemDescs; cl_ml_tensor_mem_desc_set_qcom descriptorSet; diff --git a/src/runtime/contrib/clml/clml_utils.cc b/src/runtime/contrib/clml/clml_utils.cc index 354bd104b81f..557815dfa172 100644 --- a/src/runtime/contrib/clml/clml_utils.cc +++ b/src/runtime/contrib/clml/clml_utils.cc @@ -240,6 +240,17 @@ std::vector GetVectorValues(const std::vector& val) { return array; } +/*! 
+ * \brief Utility function to find the string pattern in string str + * \param str the main string to check the pattern + * \param pattern the pattern to check in the main string + * \return return true if the main string ends with pattern, false otherwise + */ +bool PatternMatch(const std::string& str, const std::string& pattern) { + if (str.length() < pattern.length()) return false; + return str.compare(str.length() - pattern.length(), pattern.length(), pattern) == 0; +} + } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/clml/clml_utils.h b/src/runtime/contrib/clml/clml_utils.h index 2051793cf18b..0496878840d7 100644 --- a/src/runtime/contrib/clml/clml_utils.h +++ b/src/runtime/contrib/clml/clml_utils.h @@ -68,6 +68,8 @@ std::shared_ptr MakeCLMLTensorFromJSONNode( std::vector GetVectorValues(const std::vector& val); +bool PatternMatch(const std::string& str, const std::string& pattern); + } // namespace contrib } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/backend/clml/conftest.py b/tests/python/relax/backend/clml/conftest.py new file mode 100644 index 000000000000..00bad5da216f --- /dev/null +++ b/tests/python/relax/backend/clml/conftest.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import tvm +import pytest +from tvm import rpc as _rpc + + +@pytest.fixture(scope="session") +def rpc(): + rpc_target = os.getenv("RPC_TARGET", None) + if rpc_target: + connection_type = "tracker" + host = os.getenv("TVM_TRACKER_HOST", "localhost") + port = int(os.getenv("TVM_TRACKER_PORT", 9090)) + target = "opencl" + target_host = "llvm -mtriple=aarch64-linux-gnu" + device_key = os.getenv("RPC_DEVICE_KEY", "android") + cross_compile = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++") + tracker = _rpc.connect_tracker(host, port) + return tracker.request(device_key, priority=1, session_timeout=1000) + else: + return None diff --git a/tests/python/relax/backend/clml/mod_utils.py b/tests/python/relax/backend/clml/mod_utils.py new file mode 100644 index 000000000000..1efbf40c5c40 --- /dev/null +++ b/tests/python/relax/backend/clml/mod_utils.py @@ -0,0 +1,728 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CLML integration operator tests.""" +import pytest +import numpy as np +import tvm +import tvm.testing +import json + +from tvm import relax, rpc +from tvm.script import relax as R +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder +from tvm.relax.backend.adreno import clml + + +def get_relax_conv2d_mod( + data_shape, + weight_shape, + stride, + dilation, + padding, + weight_layout="OIHW", + groups=1, + dtype="float32", + has_bias=False, + has_bn=False, + has_activation=False, + has_pad=False, + is_depthwise=False, +): + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + if has_pad: + p = (0, 0, 0, 0, padding[0], padding[0], padding[1], padding[1]) + orig_data = R.arg("data", R.Tensor(data_shape, dtype)) + data = R.nn.pad(orig_data, pad_width=p, pad_value=0.0) + padding = (0, 0, 0, 0) + else: + data = R.arg("data", R.Tensor(data_shape, dtype)) + weight = R.arg("weight", R.Tensor(weight_shape, dtype)) + if has_bias: + bias = R.arg("bias", R.Tensor((1, weight_shape[0], 1, 1), dtype)) + + is_depthwise = data_shape[1] == weight_shape[0] == groups + + with R.dataflow() as frame: + output = R.emit( + R.nn.conv2d( + data, + weight, + out_dtype=dtype, + strides=stride, + dilation=dilation, + padding=padding, + data_layout="NCHW", + kernel_layout=weight_layout, + groups=groups, + ) + ) + if has_bias: + output = R.emit(output + bias) + if has_bn: + gamma = R.arg("gamma", R.Tensor((weight_shape[0],), dtype)) + beta = R.arg("beta", R.Tensor((weight_shape[0],), dtype)) + mean = R.arg("mean", R.Tensor((weight_shape[0],), dtype)) + variance = R.arg("variance", R.Tensor((weight_shape[0],), dtype)) + output = R.emit( + R.nn.batch_norm(output, gamma, beta, mean, variance, axis=1, epsilon=1e-5)[ + 0 + ] + ) + if has_activation: + output = R.emit(R.nn.relu(output)) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_clml_conv2d_codegen( + data_shape, + weight_shape, + stride, + dilation, + padding, + weight_layout="OIHW", + groups=1, + dtype="float32", + has_bias=False, + has_bn=False, + has_activation=False, + has_pad=False, + is_depthwise=False, +): + kernel_h, kernel_w = weight_shape[2], weight_shape[3] + channels = weight_shape[0] + if len(padding) == 2: + padding = (padding[0], padding[1], padding[0], padding[1]) + output_height = ((data_shape[2] - kernel_h + padding[0] + padding[2]) / stride[0]) + 1 + output_width = ((data_shape[3] - kernel_w + padding[1] + padding[3]) / stride[1]) + 1 + output_shape = (1, channels, int(output_height), int(output_width)) + out_dtype = dtype + is_depthwise = data_shape[1] == channels == groups + + weight_layout = "IOHW" if is_depthwise else "OIHW" + if weight_layout == "OIHW": + weight_shape = (channels, data_shape[1] // groups, kernel_h, kernel_w) + else: + weight_shape = (data_shape[1] // groups, channels, kernel_h, kernel_w) + + if is_depthwise: + name = "openclml.nn.depthwise_conv2d" + else: + name = "openclml.nn.conv2d" 
+ + node = { + "op": "kernel", + "name": "", + "inputs": [], + "attrs": { + "groups": [[str(groups)]], + "num_outputs": "1", + "data_layout": [["NCHW"]], + "kernel_layout": [[weight_layout]], + "dilation": [[str(dilation[0]), str(dilation[1])]], + "out_layout": [["NCHW"]], + "out_dtype": [[out_dtype]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "padding": [[str(p) for p in padding]], + "strides": [[str(s) for s in stride]], + }, + } + + if has_activation: + node["attrs"]["activation_type"] = [["relu"]] + + nodes = [ + { + "op": "input", + "name": "", + "attrs": {"shape": [[list(data_shape)]], "dtype": [[str(dtype)]]}, + }, + ] + + nodes.append( + { + "op": "const", + "name": "", + "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]}, + } + ) + + if has_bias: + bias_dtype = dtype + nodes.append( + { + "op": "const", + "name": "", + "attrs": { + "shape": [[[1, weight_shape[1] if is_depthwise else weight_shape[0], 1, 1]]], + "dtype": [[bias_dtype]], + }, + } + ) + + if has_bn: + bn_shape = [[1, weight_shape[0], 1, 1]] + # conv2d + bn --> conv2d + Add due to OptimizeBatchNorm transformation Pass + nodes.append( + { + "name": "", + "op": "const", + "attrs": {"dtype": [[dtype]], "shape": [[[1, weight_shape[0], 1, 1]]]}, + }, + ) + + input_idx = 0 + for _ in range(len(nodes)): + node["inputs"].append([input_idx, 0, 0]) + input_idx += 1 + node["attrs"]["num_inputs"] = str(len(nodes)) + nodes.append(node) + return nodes + + +def get_relax_conv2d_transpose_mod( + data_shape, + weight_shape, + channels, + stride, + padding, + dtype="float32", +): + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + data = R.arg("data", R.Tensor(data_shape, dtype)) + weight = R.arg("weight", R.Tensor(weight_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit( + R.nn.conv2d_transpose( + data, + weight, + groups=1, + strides=stride, + padding=padding, + kernel_layout="OIHW", + data_layout="NCHW", + ) + ) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_conv2d_transpose_expected_codegen( + dshape, kshape, channels, kernel_size, strides, padding, dilation, dtype, output_shape +): + attrs = { + "data_layout": [["NCHW"]], + "kernel_layout": [["OIHW"]], + "groups": [["1"]], + "clml_version": [["3"]], + "dilation": [[str(p) for p in dilation]], + "num_inputs": "2", + "num_outputs": "1", + "padding": [[str(p) for p in padding]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "strides": [[str(s) for s in strides]], + "out_dtype": [[""]], + "out_layout": [["NCHW"]], + "output_padding": [["0", "0"]], + } + + exp_codegen = [ + { + "op": "input", + "name": "", + "attrs": {"shape": [[list(dshape)]], "dtype": [[str(dtype)]]}, + }, + { + "op": "const", + "name": "", + "attrs": {"shape": [[list(kshape)]], "dtype": [[str(dtype)]]}, + }, + { + "op": "kernel", + "name": "", + "inputs": [[0, 0, 0], [1, 0, 0]], + "attrs": attrs, + }, + ] + return exp_codegen + + +def get_batchnorm_mod(data_shape, channels, axis, epsilon, dtype): + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + data = R.arg("data", R.Tensor(data_shape, dtype)) + gamma = R.arg("gamma", R.Tensor((channels,), dtype)) + beta = R.arg("beta", R.Tensor((channels,), dtype)) + mean = R.arg("moving_mean", R.Tensor((channels,), dtype)) + variance = R.arg("moving_var", R.Tensor((channels,), dtype)) + with R.dataflow() as frame: + output = R.emit( + R.nn.batch_norm(data, 
gamma, beta, mean, variance, axis, epsilon)[0] + ) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_binary_op_mod(a_shape, b_shape, op, dtype): + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + a = R.arg("a", R.Tensor(a_shape, dtype)) + b = R.arg("b", R.Tensor(b_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit(op(a, b)) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + + low, high = 0, 1 + a_data = np.random.uniform(low, high, size=(a_shape)).astype(dtype) + b_data = np.random.uniform(low, high, size=(b_shape)).astype(dtype) + + return (tvm.IRModule({"main": func}), (a_data, b_data)) + + +def get_unary_op_mod(a_shape, op, dtype): + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + a = R.arg("a", R.Tensor(a_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit(op(a)) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + + low, high = 0, 1 + a_data = np.random.uniform(low, high, size=(a_shape)).astype(dtype) + + return (tvm.IRModule({"main": func}), (a_data,)) + + +def get_relax_maxpool_mod( + data_shape, dtype, pool_size, stride=None, dilation=(1, 1), padding=(0, 0), has_pad=False +): + """ + Args: + data_shape (tuple): Input tensor shape + pool_size (tuple): Pooling window size (height, width) + stride (tuple, optional): Stride of pooling operation. Defaults to pool_size. + dilation (tuple, optional): Dilation rate. Defaults to (1, 1). + padding (tuple, optional): Padding for the input tensor. Defaults to (0, 0). + dtype (str, optional): Data type. Defaults to "float32". + has_pad (bool, optional): Whether to apply explicit padding. Defaults to False. 
+ + Returns: + tvm.IRModule: Relax MaxPool module + """ + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + + if has_pad: + p = (0, 0, 0, 0, padding[0], padding[1], padding[0], padding[1]) + orig_data = R.arg("data", R.Tensor(data_shape, dtype)) + data = R.nn.pad(orig_data, pad_width=p, pad_value=float("-inf")) + padding = (0, 0) + else: + data = R.arg("data", R.Tensor(data_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit( + R.nn.max_pool2d( + data, + pool_size=pool_size, + strides=stride, + dilation=dilation, + padding=padding, + layout="NCHW", + ) + ) + R.output(output) + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_maxpool_expected_codegen(input_shape, pool_size, stride, padding, pool_type, dtype): + import math + + adjusted_input_shape = [ + input_shape[0], + input_shape[1], + input_shape[2] + padding[0] + padding[1], + input_shape[3] + padding[2] + padding[3], + ] + + pool_height = math.floor(((adjusted_input_shape[2] - pool_size[0]) / stride[0]) + 1) + pool_width = math.floor(((adjusted_input_shape[3] - pool_size[1]) / stride[1]) + 1) + output_shape = [adjusted_input_shape[0], adjusted_input_shape[1], pool_height, pool_width] + + attrs = { + "ceil_mode": [["0"]], + "clml_version": [["3"]], + "dilation": [["1", "1"]], + "layout": [["NCHW"]], + "num_inputs": "1", + "num_outputs": "1", + "out_layout": [["NCHW"]], + "padding": [[str(0) for p in padding]], + "pool_size": [[str(p) for p in pool_size]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "strides": [[str(s) for s in stride]], + "count_include_pad": [["0"]], + } + if sum(padding): + attrs["count_include_pad"] = [["0"]] + + exp_codegen = [ + { + "op": "input", + "name": "", + "attrs": {"shape": [[list(adjusted_input_shape)]], "dtype": [[str(dtype)]]}, + }, + { + "op": "kernel", + "name": "", + "inputs": [[0, 0, 0]], + "attrs": attrs, + }, + ] + return exp_codegen + + +def get_relax_avgpool_mod(data_shape, dtype, pool_size, stride, dilation, padding, has_pad): + """ + Args: + data_shape (tuple): Input tensor shape + pool_size (tuple): Pooling window size (height, width) + stride (tuple, optional): Stride of pooling operation. Defaults to pool_size. + dilation (tuple, optional): Dilation rate. Defaults to (1, 1). + padding (tuple, optional): Padding for the input tensor. Defaults to (0, 0). + dtype (str, optional): Data type. Defaults to "float32". + has_pad (bool, optional): Whether to apply explicit padding. Defaults to False. + count_include_pad (bool, optional): Whether to include padding in averaging. Defaults to True. 
+ + Returns: + tvm.IRModule: Relax AvgPool module + """ + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + + if has_pad: + p = (0, 0, 0, 0, padding[0], padding[1], padding[0], padding[1]) + orig_data = R.arg("data", R.Tensor(data_shape, dtype)) + data = R.nn.pad(orig_data, pad_width=p, pad_value=0.0) + padding = (0, 0) + else: + data = R.arg("data", R.Tensor(data_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit( + R.nn.avg_pool2d( + data, + pool_size=pool_size, + strides=stride, + dilation=dilation, + padding=padding, + layout="NCHW", + ) + ) + R.output(output) + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_avgpool_expected_codegen(input_shape, pool_size, stride, padding, pool_type, dtype): + import math + + adjusted_input_shape = [ + input_shape[0], + input_shape[1], + input_shape[2] + padding[0] + padding[1], + input_shape[3] + padding[2] + padding[3], + ] + + pool_height = math.floor(((adjusted_input_shape[2] - pool_size[0]) / stride[0]) + 1) + pool_width = math.floor(((adjusted_input_shape[3] - pool_size[1]) / stride[1]) + 1) + output_shape = [adjusted_input_shape[0], adjusted_input_shape[1], pool_height, pool_width] + + attrs = { + "ceil_mode": [["0"]], + "clml_version": [["3"]], + "dilation": [["1", "1"]], + "layout": [["NCHW"]], + "num_inputs": "1", + "num_outputs": "1", + "out_layout": [["NCHW"]], + "padding": [[str(0) for p in padding]], + "pool_size": [[str(p) for p in pool_size]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "strides": [[str(s) for s in stride]], + "count_include_pad": [["0"]], + } + if sum(padding): + attrs["count_include_pad"] = [["0"]] + + exp_codegen = [ + { + "op": "input", + "name": "", + "attrs": {"shape": [[list(adjusted_input_shape)]], "dtype": [[str(dtype)]]}, + }, + { + "op": "kernel", + "name": "", + "inputs": [[0, 0, 0]], + "attrs": attrs, + }, + ] + return exp_codegen + + +def get_relax_reshape_mod(input_shape, output_shape, dtype): + """ + Args: + input_shape (tuple): Input tensor shape + output_shape (tuple): Desired output tensor shape + dtype (str, optional): Data type. Defaults to "float32". 
+ + Returns: + tvm.IRModule: Relax Reshape module + """ + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + data = R.arg("data", R.Tensor(input_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit(R.reshape(data, output_shape)) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_relax_reshape_codegen(input_shape, output_shape, dtype): + def compute_output_shape(input_shape, output_shape): + input_elements = np.prod(input_shape) + specified_elements = np.prod([dim for dim in output_shape if dim != -1]) + missing_dim = input_elements // specified_elements + return [int(dim) if dim != -1 else int(missing_dim) for dim in output_shape] + + expected_output_shape = compute_output_shape(input_shape, output_shape) + + expected_codegen_str = [ + { + "attrs": { + "dtype": [[dtype]], + "shape": [[list(input_shape)]], + }, + "name": "", + "op": "input", + }, + { + "attrs": { + "clml_version": [["3"]], + "dtype": [[dtype]], + "num_inputs": "1", + "num_outputs": "1", + "shape": [[expected_output_shape]], + }, + "inputs": [[0, 0, 0]], + "name": "", + "op": "kernel", + }, + ] + return expected_codegen_str + + +def get_relax_global_avgpool_mod(data_shape, keepdims, dtype): + """ + Create a Relax module for Global Average Pooling (GAP). + + Args: + data_shape (tuple): Input tensor shape (N, C, H, W) + dtype (str): Data type + + Returns: + tvm.IRModule: Relax GAP module + """ + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + data = R.arg("data", R.Tensor(data_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit(R.mean(data, axis=[2, 3], keepdims=keepdims)) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_global_avgpool_expected_codegen(input_shape, keep_dims, dtype): + """ + Generate expected codegen for Global Average Pooling. + + Args: + input_shape (tuple): Input shape (N, C, H, W) + dtype (str): Data type + + Returns: + dict: Expected codegen output + """ + output_shape = ( + [input_shape[0], input_shape[1]] + if not keep_dims + else [input_shape[0], input_shape[1], 1, 1] + ) + attrs = { + "num_inputs": "1", + "num_outputs": "1", + "clml_version": [["3"]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "axis": [["2", "3"]], + "keepdims": [["1" if keep_dims else "0"]], + } + + exp_codegen = [ + { + "op": "input", + "name": "", + "attrs": {"shape": [[list(input_shape)]], "dtype": [[str(dtype)]]}, + }, + {"op": "kernel", "name": "", "inputs": [[0, 0, 0]], "attrs": attrs}, + ] + return exp_codegen + + +def get_relax_global_maxpool_mod(data_shape, keepdims, dtype): + """ + Create a Relax module for Global Average Pooling (GAP). 
+ + Args: + data_shape (tuple): Input tensor shape (N, C, H, W) + dtype (str): Data type + + Returns: + tvm.IRModule: Relax GAP module + """ + N, C, H, W = data_shape + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + data = R.arg("data", R.Tensor(data_shape, dtype)) + + with R.dataflow() as frame: + output = R.emit( + R.nn.max_pool2d( + data, pool_size=(H, W), strides=(1, 1), padding=(0, 0), layout="NCHW" + ) + ) + R.output(output) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def get_global_maxpool_expected_codegen(input_shape, pool_size, stride, padding, pool_type, dtype): + import math + + adjusted_input_shape = [ + input_shape[0], + input_shape[1], + input_shape[2] + padding[0] + padding[1], + input_shape[3] + padding[2] + padding[3], + ] + + output_shape = [adjusted_input_shape[0], adjusted_input_shape[1], 1, 1] + + attrs = { + "ceil_mode": [["0"]], + "clml_version": [["3"]], + "dilation": [["1", "1"]], + "layout": [["NCHW"]], + "num_inputs": "1", + "num_outputs": "1", + "out_layout": [["NCHW"]], + "padding": [[str(0) for p in padding]], + "pool_size": [[str(p) for p in pool_size]], + "shape": [[list(output_shape)]], + "dtype": [[dtype]], + "strides": [[str(s) for s in stride]], + "count_include_pad": [["0"]], + } + if sum(padding): + attrs["count_include_pad"] = [["0"]] + + exp_codegen = [ + { + "op": "input", + "name": "", + "attrs": {"shape": [[list(adjusted_input_shape)]], "dtype": [[str(dtype)]]}, + }, + { + "op": "kernel", + "name": "", + "inputs": [[0, 0, 0]], + "attrs": attrs, + }, + ] + return exp_codegen diff --git a/tests/python/relax/backend/clml/test_clml_codegen.py b/tests/python/relax/backend/clml/test_clml_codegen.py new file mode 100644 index 000000000000..b03d6afa1c9b --- /dev/null +++ b/tests/python/relax/backend/clml/test_clml_codegen.py @@ -0,0 +1,505 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""CLML integration operator tests.""" +import pytest +import numpy as np +import tvm +import tvm.testing +import json + +from tvm import relax +from tvm.script import relax as R +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder +from tvm.relax.backend.adreno import clml +from tvm.relax.backend.adreno.clml import OpenCLMLOffLoad + +from mod_utils import ( + get_relax_conv2d_mod, + get_clml_conv2d_codegen, + get_relax_conv2d_transpose_mod, + get_conv2d_transpose_expected_codegen, + get_batchnorm_mod, + get_binary_op_mod, + get_unary_op_mod, + get_relax_maxpool_mod, + get_maxpool_expected_codegen, + get_relax_avgpool_mod, + get_avgpool_expected_codegen, + get_relax_reshape_mod, + get_relax_reshape_codegen, + get_relax_global_avgpool_mod, + get_global_avgpool_expected_codegen, + get_relax_global_maxpool_mod, + get_global_maxpool_expected_codegen, +) + + +def compare_codegen(clml_mod, clml_codegen): + source = clml_mod.attrs["external_mods"][0].get_source() + codegen = json.loads(source)["nodes"] + for node in range(len(codegen)): + if codegen[node]["op"] == "input" or codegen[node]["op"] == "const": + codegen[node]["name"] = "" + if codegen[node]["op"] == "kernel": + codegen[node]["name"] = "" + codegen_str = json.dumps(codegen, sort_keys=True, indent=2) + known_good_codegen_str = json.dumps(clml_codegen, sort_keys=True, indent=2) + assert codegen_str == known_good_codegen_str, ( + f"The JSON produced by codegen does not match the expected result. \n" + f"Actual={codegen_str} \n" + f"Expected={known_good_codegen_str}" + ) + + +def verify(mod, params_np, clml_codegen): + mod = tvm.relax.transform.BindParams("main", params_np)(mod) + clml_mod = OpenCLMLOffLoad()(mod) + compare_codegen(clml_mod, clml_codegen) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "kernel_h, kernel_w, padding, stride, dilation, out_channels, shape, has_bias, has_bn, has_activation, has_pad, is_depthwise", + [ + (3, 3, (1, 1), (1, 1), (1, 1), 64, (3, 224, 224), False, True, False, True, False), + (3, 3, (1, 1), (1, 1), (1, 1), 64, (3, 224, 224), False, True, False, False, False), + (5, 5, (2, 2), (1, 1), (1, 1), 16, (16, 64, 64), False, True, True, False, False), + (7, 7, (3, 3), (2, 2), (1, 1), 32, (3, 224, 224), True, False, True, True, False), + (3, 3, (0, 0), (1, 1), (1, 1), 512, (256, 14, 14), True, False, True, False, False), + (1, 1, (0, 0), (1, 1), (1, 1), 1024, (512, 7, 7), True, False, True, False, False), + (1, 3, (0, 0), (1, 1), (1, 1), 64, (64, 7, 7), True, False, True, False, False), + (3, 1, (0, 0), (1, 1), (1, 1), 64, (64, 7, 7), False, True, True, True, False), + ], +) +def test_conv2d_offload( + kernel_h, + kernel_w, + padding, + stride, + dilation, + out_channels, + shape, + has_bias, + has_bn, + has_activation, + has_pad, + is_depthwise, + dtype, +): + low, high = 0, 1 + data_shape = (1, *shape) + if is_depthwise: + groups = data_shape[1] // out_channels + else: + groups = 1 + padding = (padding[0], padding[1], padding[0], padding[1]) + + weight_format = "IOHW" if is_depthwise else "OIHW" + weight_shape = (out_channels, data_shape[1] // groups, kernel_h, kernel_w) + + weight = np.random.uniform(low, high, size=weight_shape).astype(dtype) + bias = np.random.uniform(low, high, size=(1, weight_shape[0], 1, 1)).astype(dtype) + + gamma = np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + beta = 
np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + mean = np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + variance = np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + + params_np = {"weight": weight} + if has_bias: + params_np["bias"] = bias + if has_bn: + params_np.update({"gamma": gamma, "beta": beta, "mean": mean, "variance": variance}) + + mod = get_relax_conv2d_mod( + data_shape, + weight_shape, + stride=stride, + dilation=dilation, + padding=padding, + weight_layout=weight_format, + groups=groups, + dtype=dtype, + has_bias=has_bias, + has_bn=has_bn, + has_activation=has_activation, + has_pad=has_pad, + is_depthwise=is_depthwise, + ) + + clml_codegen = get_clml_conv2d_codegen( + data_shape, + weight_shape, + stride=stride, + dilation=dilation, + padding=padding, + weight_layout=weight_format, + groups=groups, + dtype=dtype, + has_bias=has_bias, + has_bn=has_bn, + has_activation=has_activation, + has_pad=has_pad, + is_depthwise=is_depthwise, + ) + + verify(mod, params_np, clml_codegen) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "dshape, kshape, channels, kernel_size, strides, padding, out_shape", + [ + ((1, 256, 100, 100), (64, 256, 4, 4), 64, (4, 4), (2, 2), (0, 0, 0, 0), (1, 64, 202, 202)), + ((1, 64, 200, 200), (64, 64, 4, 4), 64, (4, 4), (2, 2), (1, 1, 1, 1), (1, 64, 400, 400)), + ((1, 64, 200, 200), (64, 64, 4, 4), 64, (4, 4), (2, 2), (1, 1, 1, 1), (1, 64, 400, 400)), + ((1, 64, 400, 400), (16, 64, 4, 4), 16, (4, 4), (2, 2), (1, 1, 1, 1), (1, 16, 800, 800)), + ], +) +def test_conv2d_transpose( + dshape, kshape, channels, kernel_size, strides, padding, dtype, out_shape +): + low, high = -1, 1 + weight = np.random.uniform(low, high, size=kshape).astype(dtype) + + params_np = {"weight": weight} + + mod = get_relax_conv2d_transpose_mod( + dshape, + kshape, + channels=channels, + stride=strides, + padding=padding, + dtype=dtype, + ) + + exp_codegen = get_conv2d_transpose_expected_codegen( + dshape=dshape, + kshape=kshape, + channels=channels, + kernel_size=kernel_size, + strides=strides, + padding=padding, + dilation=(1, 1), + dtype=dtype, + output_shape=out_shape, + ) + verify(mod, params_np, exp_codegen) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 14, 14), 1, 3e-4], + [(1, 14, 256, 256), 1, 3e-4], + [(1, 14, 256, 256), 1, 3e-4], + [(1, 256, 1, 1), 1, 3e-4], + ], +) +def test_batchnorm(dtype, trials): + low, high = 0, 1 + if clml.clml_sdk_version() < 3: + print("Skip due to unsupported CLML version:", clml.clml_sdk_version()) + return + + (input_shape, axis, epsilon) = trials + channels = input_shape[axis] + + def _get_axis_tuple(axis): + if axis == 0: + return (1, 2, 3) + elif axis == 1: + return (0, 2, 3) + elif axis == 2: + return (0, 1, 3) + else: + return (0, 1, 2) + + data = np.random.uniform(low, high, size=(input_shape)).astype(dtype) + gamma = np.random.uniform(low, high, size=(channels)).astype(dtype) + beta = np.random.uniform(low, high, size=(channels)).astype(dtype) + mean = np.mean(data, _get_axis_tuple(axis), keepdims=False) + variance = np.var(data, _get_axis_tuple(axis), keepdims=False) + + params_np = {"gamma": gamma, "beta": beta, "moving_mean": mean, "moving_var": variance} + mod = get_batchnorm_mod(input_shape, channels, axis, epsilon, dtype) + exp_codegen = [ + { + "attrs": {"dtype": [[dtype]], "shape": [[input_shape]]}, + "name": "", + "op": 
"input", + }, + {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", "op": "const"}, + {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", "op": "const"}, + {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", "op": "const"}, + {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", "op": "const"}, + { + "attrs": { + "axis": [[str(axis)]], + "center": [["1"]], + "dtype": [[dtype]], + "clml_version": [["3"]], + "momentum": [["0.10000000000000001"]], + "epsilon": [["0.00029999999999999997"]], + "num_inputs": "5", + "num_outputs": "1", + "scale": [["1"]], + "shape": [[input_shape]], + }, + "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0], [4, 0, 0]], + "name": "", + "op": "kernel", + }, + ] + verify(mod, params_np, exp_codegen) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "a_shape, b_shape, op", + [ + ((1, 64, 14, 14), (1, 64, 14, 14), R.add), + ((1, 256), (1, 256), R.add), + ((1, 64, 14, 14), (1, 64, 14, 14), R.subtract), + ((1, 256), (1, 256), R.subtract), + ((1, 64, 14, 14), (1, 64, 14, 14), R.multiply), + ((1, 256), (1, 256), R.multiply), + ((1, 64, 14, 14), (1, 64, 14, 14), R.divide), + ((1, 256), (1, 256), R.divide), + ((1, 64, 14, 14), (1, 64, 14, 14), R.minimum), + ((1, 256), (1, 256), R.minimum), + ((1, 64, 14, 14), (1, 64, 14, 14), R.maximum), + ((1, 256), (1, 256), R.maximum), + ], +) +@tvm.testing.requires_openclml +def test_binary_ops(a_shape, b_shape, op, dtype): + def _verify(mod): + expected_codegen_str = [ + { + "attrs": { + "dtype": [[dtype]], + "shape": [[a_shape]], + }, + "name": "", + "op": "input", + }, + { + "attrs": { + "dtype": [[dtype]], + "shape": [[b_shape]], + }, + "name": "", + "op": "input", + }, + { + "attrs": { + "clml_version": [["3"]], + "dtype": [[dtype]], + "num_inputs": "2", + "num_outputs": "1", + "shape": [[a_shape]], + }, + "inputs": [[0, 0, 0], [1, 0, 0]], + "name": "", + "op": "kernel", + }, + ] + verify(mod, {}, expected_codegen_str) + + (mod, _) = get_binary_op_mod(a_shape, b_shape, op, dtype) + + _verify(mod) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize( + "dtype", + [ + "float32", + ], +) +@pytest.mark.parametrize( + "a_shape, op", + [ + ((1, 64, 14, 14), R.nn.relu), + ((1, 256, 1, 1), R.nn.relu), + ((1, 14, 256, 256), R.nn.relu), + ((1, 14, 14, 256), R.nn.relu), + ], +) +@tvm.testing.requires_openclml +def test_unary_ops(a_shape, op, dtype): + def _verify(mod): + expected_codegen_str = [ + { + "attrs": { + "dtype": [[dtype]], + "shape": [[a_shape]], + }, + "name": "", + "op": "input", + }, + { + "attrs": { + "activation_type": [["relu"]], + "clml_version": [["3"]], + "dtype": [[dtype]], + "num_inputs": "1", + "num_outputs": "1", + "shape": [[a_shape]], + }, + "inputs": [[0, 0, 0]], + "name": "", + "op": "kernel", + }, + ] + verify(mod, {}, expected_codegen_str) + + (mod, _) = get_unary_op_mod(a_shape, op, dtype) + + _verify(mod) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), (3, 3), (2, 2), (1, 1), (0, 0, 0, 0), False], + [(1, 256, 17, 17), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 1024, 14, 14), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 32, 256, 256), (3, 3), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 32, 256, 256), (3, 3), (2, 2), (1, 1), (0, 1, 0, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), 
(1, 0, 1, 0), True], + ], +) +def test_max_pool(dtype, trials): + low, high = -1, 1 + (input_shape, pool_size, stride, dilation, padding, has_pad) = trials + mod = get_relax_maxpool_mod(input_shape, dtype, pool_size, stride, dilation, padding, has_pad) + params_np = {} + + expected_codegen_str = get_maxpool_expected_codegen( + input_shape, pool_size, stride, padding, "maxpool2d", dtype + ) + verify(mod, params_np, expected_codegen_str) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), (3, 3), (2, 2), (1, 1), (0, 0, 0, 0), False], + [(1, 256, 17, 17), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 1024, 14, 14), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 32, 256, 256), (3, 3), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 32, 256, 256), (3, 3), (2, 2), (1, 1), (0, 1, 0, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), (1, 0, 1, 0), True], + ], +) +def test_avg_pool(dtype, trials): + low, high = -1, 1 + (input_shape, pool_size, stride, dilation, padding, has_pad) = trials + mod = get_relax_avgpool_mod(input_shape, dtype, pool_size, stride, dilation, padding, has_pad) + params_np = {} + exp_codegen_str = get_avgpool_expected_codegen( + input_shape, pool_size, stride, padding, "avg_pool2d", dtype + ) + verify(mod, params_np, exp_codegen_str) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 3, 32, 32), (1, 4, -1, 32)], + [(1, 4, 8, 32), (1, 4, -1, 16)], + [(1, 64, 3, 3), (1, 32, 3, -1)], + ], +) +def test_reshape(dtype, trials): + low, high = -1, 1 + (input_shape, output_shape) = trials + mod = get_relax_reshape_mod(input_shape, output_shape, dtype) + params_np = {} + expected_codegen = get_relax_reshape_codegen(input_shape, output_shape, dtype) + verify(mod, params_np, expected_codegen) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), True], + [(1, 256, 17, 17), False], + [(1, 1024, 14, 14), True], + [(1, 32, 256, 256), False], + ], +) +def test_global_avg_pool(dtype, trials): + """Test function for global average pooling.""" + low, high = -1, 1 + (input_shape, keep_dims) = trials + mod = get_relax_global_avgpool_mod(input_shape, keep_dims, dtype) + params_np = {} + exp_codegen_str = get_global_avgpool_expected_codegen(input_shape, keep_dims, dtype) + verify(mod, params_np, exp_codegen_str) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), True], + [(1, 256, 17, 17), False], + [(1, 1024, 14, 14), True], + [(1, 32, 256, 256), False], + ], +) +def test_global_max_pool(dtype, trials): + """Test function for global average pooling.""" + low, high = -1, 1 + (input_shape, keep_dims) = trials + N, C, H, W = input_shape + pool_size = (H, W) + stride = (1, 1) + padding = (0, 0, 0, 0) + mod = get_relax_global_maxpool_mod(input_shape, keep_dims, dtype) + params_np = {} + exp_codegen_str = get_global_maxpool_expected_codegen( + input_shape, pool_size, stride, padding, "global_max", dtype + ) + verify(mod, params_np, exp_codegen_str) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/backend/clml/test_op_exec_clml_codegen.py b/tests/python/relax/backend/clml/test_op_exec_clml_codegen.py new file mode 100644 
index 000000000000..4e5b4b652b45 --- /dev/null +++ b/tests/python/relax/backend/clml/test_op_exec_clml_codegen.py @@ -0,0 +1,329 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CLML integration operator tests.""" +import pytest +import numpy as np +import tvm +import tvm.testing +import json + +from tvm import relax, rpc +from tvm.script import relax as R +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder +from tvm.relax.backend.adreno import clml +from utils import run_compare + +from mod_utils import ( + get_relax_conv2d_mod, + get_batchnorm_mod, + get_binary_op_mod, + get_unary_op_mod, + get_relax_maxpool_mod, + get_relax_avgpool_mod, + get_relax_reshape_mod, + get_relax_reshape_codegen, + get_relax_global_avgpool_mod, + get_relax_global_maxpool_mod, +) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "kernel_h, kernel_w, padding, stride, dilation, out_channels, shape, has_bias, has_bn, has_activation, has_pad, is_depthwise", + [ + (3, 3, (1, 1), (1, 1), (1, 1), 64, (3, 224, 224), False, True, False, True, False), + (3, 3, (1, 1), (1, 1), (1, 1), 64, (3, 224, 224), False, True, False, False, False), + (5, 5, (2, 2), (1, 1), (1, 1), 16, (16, 64, 64), False, True, True, False, False), + (7, 7, (3, 3), (2, 2), (1, 1), 32, (3, 224, 224), True, False, True, True, False), + (3, 3, (0, 0), (1, 1), (1, 1), 512, (256, 14, 14), True, False, True, False, False), + (1, 1, (0, 0), (1, 1), (1, 1), 1024, (512, 7, 7), True, False, True, False, False), + (1, 3, (0, 0), (1, 1), (1, 1), 64, (64, 7, 7), True, False, True, False, False), + (3, 1, (0, 0), (1, 1), (1, 1), 64, (64, 7, 7), False, True, True, True, False), + ], +) +def test_conv2d_offload( + kernel_h, + kernel_w, + padding, + stride, + dilation, + out_channels, + shape, + has_bias, + has_bn, + has_activation, + has_pad, + is_depthwise, + dtype, + rpc, +): + low, high = 0, 1 + data_shape = (1, *shape) + if is_depthwise: + groups = data_shape[1] // out_channels + else: + groups = 1 + padding = (padding[0], padding[1], padding[0], padding[1]) + + weight_format = "IOHW" if is_depthwise else "OIHW" + weight_shape = (out_channels, data_shape[1] // groups, kernel_h, kernel_w) + + data = np.random.uniform(low, high, size=data_shape).astype(dtype) + weight = np.random.uniform(low, high, size=weight_shape).astype(dtype) + bias = np.random.uniform(low, high, size=(1, weight_shape[0], 1, 1)).astype(dtype) + + gamma = np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + beta = np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + mean = np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + 
variance = np.random.uniform(low, high, size=(weight_shape[0],)).astype(dtype) + + inputs = [data] + params_np = {"weight": weight} + if has_bias: + params_np["bias"] = bias + if has_bn: + params_np.update({"gamma": gamma, "beta": beta, "mean": mean, "variance": variance}) + + mod = get_relax_conv2d_mod( + data_shape, + weight_shape, + stride=stride, + dilation=dilation, + padding=padding, + weight_layout=weight_format, + groups=groups, + dtype=dtype, + has_bias=has_bias, + has_bn=has_bn, + has_activation=has_activation, + has_pad=has_pad, + is_depthwise=is_depthwise, + ) + run_compare(mod, inputs, params_np, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 14, 14), 1, 3e-4], + [(1, 14, 256, 256), 1, 3e-4], + [(1, 14, 256, 256), 1, 3e-4], + [(1, 256, 1, 1), 1, 3e-4], + ], +) +def test_batchnorm(dtype, trials, rpc): + low, high = 0, 1 + if clml.clml_sdk_version() < 3: + print("Skip due to unsupported CLML version:", clml.clml_sdk_version()) + return + + (input_shape, axis, epsilon) = trials + channels = input_shape[axis] + + def _get_axis_tuple(axis): + if axis == 0: + return (1, 2, 3) + elif axis == 1: + return (0, 2, 3) + elif axis == 2: + return (0, 1, 3) + else: + return (0, 1, 2) + + data = np.random.uniform(low, high, size=(input_shape)).astype(dtype) + gamma = np.random.uniform(low, high, size=(channels)).astype(dtype) + beta = np.random.uniform(low, high, size=(channels)).astype(dtype) + mean = np.mean(data, _get_axis_tuple(axis), keepdims=False) + variance = np.var(data, _get_axis_tuple(axis), keepdims=False) + + inputs = [data] + params_np = {"gamma": gamma, "beta": beta, "moving_mean": mean, "moving_var": variance} + mod = get_batchnorm_mod(input_shape, channels, axis, epsilon, dtype) + run_compare(mod, inputs, params_np, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "a_shape, b_shape, op", + [ + ((1, 64, 14, 14), (1, 64, 14, 14), R.add), + ((1, 256), (1, 256), R.add), + ((1, 64, 14, 14), (1, 64, 14, 14), R.subtract), + ((1, 256), (1, 256), R.subtract), + ((1, 64, 14, 14), (1, 64, 14, 14), R.multiply), + ((1, 256), (1, 256), R.multiply), + ((1, 64, 14, 14), (1, 64, 14, 14), R.divide), + ((1, 256), (1, 256), R.divide), + ((1, 64, 14, 14), (1, 64, 14, 14), R.minimum), + ((1, 256), (1, 256), R.minimum), + ((1, 64, 14, 14), (1, 64, 14, 14), R.maximum), + ((1, 256), (1, 256), R.maximum), + ], +) +@tvm.testing.requires_openclml +def test_binary_ops(a_shape, b_shape, op, rpc, dtype): + (mod, inputs) = get_binary_op_mod(a_shape, b_shape, op, dtype) + run_compare(mod, inputs, {}, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize( + "dtype", + [ + "float32", + ], +) +@pytest.mark.parametrize( + "a_shape, op", + [ + ((1, 64, 14, 14), R.nn.relu), + ((1, 256, 1, 1), R.nn.relu), + ((1, 14, 256, 256), R.nn.relu), + ((1, 14, 14, 256), R.nn.relu), + ], +) +@tvm.testing.requires_openclml +def test_unary_ops(a_shape, op, rpc, dtype): + (mod, inputs) = get_unary_op_mod(a_shape, op, dtype) + run_compare(mod, inputs, {}, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), (3, 3), (2, 2), (1, 1), (0, 0, 0, 0), False], + [(1, 256, 17, 17), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 1024, 14, 14), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 32, 256, 256), (3, 3), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 
32, 256, 256), (3, 3), (2, 2), (1, 1), (0, 1, 0, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), (1, 0, 1, 0), True], + ], +) +def test_max_pool(dtype, trials, rpc): + low, high = -1, 1 + (input_shape, pool_size, stride, dilation, padding, has_pad) = trials + data = np.random.uniform(low, high, size=input_shape).astype(dtype) + inputs = [data] + mod = get_relax_maxpool_mod(input_shape, dtype, pool_size, stride, dilation, padding, has_pad) + params_np = {} + run_compare(mod, inputs, params_np, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), (3, 3), (2, 2), (1, 1), (0, 0, 0, 0), False], + [(1, 256, 17, 17), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 1024, 14, 14), (3, 3), (1, 1), (1, 1), (0, 0, 0, 0), False], + [(1, 32, 256, 256), (3, 3), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 32, 256, 256), (3, 3), (2, 2), (1, 1), (0, 1, 0, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), (1, 1, 1, 1), True], + [(1, 32, 256, 256), (2, 2), (2, 2), (1, 1), (1, 0, 1, 0), True], + ], +) +def test_avg_pool(dtype, trials, rpc): + low, high = -1, 1 + (input_shape, pool_size, stride, dilation, padding, has_pad) = trials + data = np.random.uniform(low, high, size=input_shape).astype(dtype) + inputs = [data] + mod = get_relax_avgpool_mod(input_shape, dtype, pool_size, stride, dilation, padding, has_pad) + params_np = {} + run_compare(mod, inputs, params_np, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 3, 32, 32), (1, 4, -1, 32)], + [(1, 4, 8, 32), (1, 4, -1, 16)], + [(1, 64, 3, 3), (1, 32, 3, -1)], + ], +) +def test_reshape(dtype, trials, rpc): + low, high = -1, 1 + (input_shape, output_shape) = trials + data = np.random.uniform(low, high, size=input_shape).astype(dtype) + inputs = [data] + mod = get_relax_reshape_mod(input_shape, output_shape, dtype) + params_np = {} + run_compare(mod, inputs, params_np, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), True], + [(1, 256, 17, 17), False], + [(1, 1024, 14, 14), True], + [(1, 32, 256, 256), False], + ], +) +def test_global_avg_pool(dtype, trials, rpc): + """Test function for global average pooling.""" + low, high = -1, 1 + (input_shape, keep_dims) = trials + data = np.random.uniform(low, high, size=input_shape).astype(dtype) + inputs = [data] + mod = get_relax_global_avgpool_mod(input_shape, keep_dims, dtype) + params_np = {} + run_compare(mod, inputs, params_np, rpc) + + +@tvm.testing.requires_openclml +@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize( + "trials", + [ + [(1, 64, 147, 147), True], + [(1, 256, 17, 17), False], + [(1, 1024, 14, 14), True], + [(1, 32, 256, 256), False], + ], +) +def test_global_max_pool(dtype, trials, rpc): + """Test function for global average pooling.""" + low, high = -1, 1 + (input_shape, keep_dims) = trials + N, C, H, W = input_shape + pool_size = (H, W) + stride = (1, 1) + padding = (0, 0, 0, 0) + data = np.random.uniform(low, high, size=input_shape).astype(dtype) + inputs = [data] + mod = get_relax_global_maxpool_mod(input_shape, keep_dims, dtype) + params_np = {} + run_compare(mod, inputs, params_np, rpc) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/backend/clml/utils.py 
b/tests/python/relax/backend/clml/utils.py new file mode 100644 index 000000000000..22b587c964ab --- /dev/null +++ b/tests/python/relax/backend/clml/utils.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Run utils for CLML integration operator tests""" +import pytest +import numpy as np +import json +import tvm +import tvm.testing +import copy + +from tvm import relax, rpc +from tvm.relax import transform +from tvm import dlight as dl +from tvm.contrib import utils, ndk +from tvm.relax.backend.adreno.clml import OpenCLMLOffLoad + + +def build_and_run( + mod, + inputs_np, + target, + rpc=None, + load_path="vm_library.so", + clml_enable=False, +): + + tgt = tvm.target.Target(target, host="llvm -mtriple=aarch64-linux-gnu") + pipeline = relax.pipeline.get_default_pipeline(tgt) + mod = pipeline(mod) + if rpc: + ex = relax.build(mod, tgt) + temp = utils.tempdir() + path = temp.relpath(load_path) + path = "./" + load_path + ex.export_library(path, fcompile=ndk.create_shared, options=["-shared", "-fPIC", "-lm"]) + rpc.upload(path) + rexec = rpc.load_module(load_path) + dev = rpc.cl(0) + vm = relax.VirtualMachine(rexec, dev) + else: + ex = relax.build(mod, target) + dev = tvm.device(target, 0) + vm = relax.VirtualMachine(ex, dev) + + f = vm["main"] + inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] + vm.set_input("main", *inputs) + vm.invoke_stateful("main") + tvm_output = vm.get_outputs("main") + return tvm_output.numpy() + + +def run_compare(mod, inputs, params_np, rpc=None): + clml_mod = copy.deepcopy(mod) + mod = tvm.relax.transform.BindParams("main", params_np)(mod) + clml_mod = tvm.relax.transform.BindParams("main", params_np)(clml_mod) + + if not rpc: + return + + ref = build_and_run( + mod, + inputs, + tvm.target.adreno(), + rpc=rpc, + load_path="vm_library_opencl.so", + ) + out = build_and_run( + clml_mod, + inputs, + tvm.target.adreno(clml=True), + rpc=rpc, + load_path="vm_library_clml.so", + clml_enable=True, + ) + np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) diff --git a/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py b/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py new file mode 100644 index 000000000000..fc68f51b9f6b --- /dev/null +++ b/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
diff --git a/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py b/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py
new file mode 100644
index 000000000000..fc68f51b9f6b
--- /dev/null
+++ b/tests/python/relax/test_transform_fold_batch_norm_to_conv2d.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+from tvm.script import ir as I
+from tvm.script.ir_builder import IRBuilder
+from tvm.ir.module import IRModule
+from tvm.script.ir_builder import relax as relax_builder
+from tvm.relax.expr_functor import PyExprVisitor, visitor
+
+
+def get_conv2d_batchnorm_sample():
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            data = R.arg("data", R.Tensor((1, 3, 224, 224), "float32"))
+            weight = R.arg("weight", R.Tensor((32, 3, 3, 3), "float32"))
+            with R.dataflow() as frame:
+                output = R.emit(
+                    R.nn.conv2d(
+                        data,
+                        weight,
+                        out_dtype="float32",
+                        strides=(1, 1),
+                        dilation=(1, 1),
+                        padding=(1, 1),
+                        data_layout="NCHW",
+                        kernel_layout="OIHW",
+                        groups=1,
+                    )
+                )
+                gamma = R.arg("gamma", R.Tensor((32,), "float32"))
+                beta = R.arg("beta", R.Tensor((32,), "float32"))
+                mean = R.arg("mean", R.Tensor((32,), "float32"))
+                variance = R.arg("variance", R.Tensor((32,), "float32"))
+                output = R.emit(
+                    R.nn.batch_norm(output, gamma, beta, mean, variance, axis=1, epsilon=1e-5)[0]
+                )
+                R.output(output)
+
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+
+    return tvm.IRModule({"main": func})
+
+
+def test_fold_batchnorm_info_conv2d():
+    mod = get_conv2d_batchnorm_sample()
+    mod_fold = get_conv2d_batchnorm_sample()
+
+    target = tvm.target.Target("llvm", host="llvm")
+    data_in = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype(np.float32))
+
+    weight_data = tvm.nd.array(np.random.rand(32, 3, 3, 3).astype(np.float32))
+    gamma_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    beta_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    mean_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    variance_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    params_np = {
+        "weight": weight_data,
+        "gamma": gamma_data,
+        "beta": beta_data,
+        "mean": mean_data,
+        "variance": variance_data,
+    }
+
+    mod = tvm.relax.transform.BindParams("main", params_np)(mod)
+    mod_fold = tvm.relax.transform.BindParams("main", params_np)(mod_fold)
+
+    # Normal build
+    mod = tvm.relax.transform.DecomposeOpsForInference()(mod)
+    ex = relax.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    out = vm["main"](data_in)
+
+    # Fold BN to Conv2D
+    mod_fold = relax.transform.FoldBatchnormToConv2D()(mod_fold)
+    mod_fold = relax.transform.FoldConstant()(mod_fold)
+    ex_fold = relax.build(mod_fold, target)
+    vm_fold = relax.VirtualMachine(ex_fold, tvm.cpu())
+    out_fold = vm_fold["main"](data_in)
+
+    tvm.testing.assert_allclose(out.numpy(), out_fold.numpy(), rtol=1e-5, atol=1e-5)
+
+
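The equivalence checked above rests on the usual batch-norm folding identity: with per-output-channel scale s = gamma / sqrt(variance + epsilon), the folded convolution uses weights W * s and a bias beta - mean * s. A small NumPy sketch of that algebra (reference math only, not the FoldBatchnormToConv2D implementation):

import numpy as np


def fold_bn_into_conv2d_weights(weight, gamma, beta, mean, variance, epsilon=1e-5):
    # weight is OIHW; gamma/beta/mean/variance are per output channel, shape (O,).
    scale = gamma / np.sqrt(variance + epsilon)
    folded_weight = weight * scale[:, None, None, None]  # scale each output filter
    folded_bias = beta - mean * scale                     # shift absorbed into a bias
    return folded_weight, folded_bias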
+@visitor
+class VerifyFolding(PyExprVisitor):  # pylint: disable=abstract-method
+    def visit(self, mod: IRModule) -> None:
+        """Entry point"""
+        for _, func in mod.functions_items():
+            if isinstance(func, relax.Function):
+                self.visit_expr(func)
+
+    def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-renamed
+        assert (
+            call.op.name != "relax.nn.batch_norm"
+        ), "Batchnorm op shouldn't be present after folding into the previous conv2d"
+
+
+def test_fold_batchnorm_info_conv2d_transform():
+    mod = get_conv2d_batchnorm_sample()
+    mod = relax.transform.FoldBatchnormToConv2D()(mod)
+    weight_data = tvm.nd.array(np.random.rand(32, 3, 3, 3).astype(np.float32))
+    gamma_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    beta_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    mean_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    variance_data = tvm.nd.array(np.random.rand(32).astype(np.float32))
+    params_np = {
+        "weight": weight_data,
+        "gamma": gamma_data,
+        "beta": beta_data,
+        "mean": mean_data,
+        "variance": variance_data,
+    }
+    mod = tvm.relax.transform.BindParams("main", params_np)(mod)
+    mod = relax.transform.FoldBatchnormToConv2D()(mod)
+    mod = relax.transform.FoldConstant()(mod)
+
+    VerifyFolding().visit(mod)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh
index 244ce4b8a504..476886782620 100755
--- a/tests/scripts/task_config_build_gpu.sh
+++ b/tests/scripts/task_config_build_gpu.sh
@@ -46,3 +46,4 @@ echo set\(USE_CUTLASS ON\) >> config.cmake
 # Temporary disable MSC
 # echo set\(USE_MSC ON\) >> config.cmake
 echo set\(CMAKE_CUDA_ARCHITECTURES 75\) >> config.cmake
+echo set\(USE_CLML ON\) >> config.cmake
diff --git a/tests/scripts/task_python_adreno.sh b/tests/scripts/task_python_adreno.sh
index 684a63e77fae..f019cd1eccb1 100755
--- a/tests/scripts/task_python_adreno.sh
+++ b/tests/scripts/task_python_adreno.sh
@@ -75,5 +75,14 @@ for node_id in $CLML_TESTS; do
   i=$((i+1))
 done
 
+# Relax test
+RELAX_TESTS=$(./ci/scripts/jenkins/pytest_ids.py --folder tests/python/relax/backend/clml 2> /dev/null | grep -v dlerror)
+i=0
+for node_id in $RELAX_TESTS; do
+  echo "$node_id"
+  CXX=${TVM_NDK_CC} run_pytest ctypes "$TVM_INTEGRATION_TESTSUITE_NAME-openclml-relax-$i" "$node_id" --reruns=0
+  i=$((i+1))
+done
+
 kill ${TRACKER_PID}
 kill ${DEVICE_PID}
diff --git a/tests/scripts/unity/task_python_relax.sh b/tests/scripts/unity/task_python_relax.sh
index 688812b35d32..5eb2a9e4201e 100755
--- a/tests/scripts/unity/task_python_relax.sh
+++ b/tests/scripts/unity/task_python_relax.sh
@@ -39,3 +39,6 @@ TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm}" pytest tests/python/dlight
 
 # Test for MSC
 # pytest tests/python/contrib/test_msc
+
+# Test for OpenCLML
+pytest tests/python/relax/backend/clml/
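End to end, a user would bind parameters, partition CLML-supported regions, and build for an Adreno target with the remote flow shown in utils.py. A rough sketch under the assumption that OpenCLMLOffLoad is applied explicitly as an IRModule pass (this patch does not show whether the Adreno pipeline also applies it automatically when clml=True is set on the target); `mod` and `params_np` are supplied by the caller:

import tvm
from tvm import relax
from tvm.relax.backend.adreno.clml import OpenCLMLOffLoad


def build_for_adreno_clml(mod, params_np):
    # Sketch only: bind constants, offload CLML regions, then run the default
    # pipeline and build for an aarch64 Android host.
    target = tvm.target.Target(
        tvm.target.adreno(clml=True), host="llvm -mtriple=aarch64-linux-gnu"
    )
    mod = relax.transform.BindParams("main", params_np)(mod)
    mod = OpenCLMLOffLoad()(mod)
    mod = relax.pipeline.get_default_pipeline(target)(mod)
    return relax.build(mod, target)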