diff --git a/CMakeLists.txt b/CMakeLists.txt index b5931d624ced2..14028f4e40e6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -507,6 +507,7 @@ include(cmake/modules/VTA.cmake) include(cmake/modules/StandaloneCrt.cmake) include(cmake/modules/CUDA.cmake) include(cmake/modules/Hexagon.cmake) # This must come before logging.cmake +include(cmake/modules/contrib/CLML.cmake) # Must be before OpenCL.cmake include(cmake/modules/OpenCL.cmake) include(cmake/modules/OpenMP.cmake) include(cmake/modules/Vulkan.cmake) @@ -540,7 +541,6 @@ include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/VitisAI.cmake) include(cmake/modules/contrib/Verilator.cmake) -include(cmake/modules/contrib/CLML.cmake) include(cmake/modules/contrib/UMA.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) diff --git a/apps/cpp_clml/clml_runner.cc b/apps/cpp_clml/clml_runner.cc index d733922da4996..0a5508635e0a5 100644 --- a/apps/cpp_clml/clml_runner.cc +++ b/apps/cpp_clml/clml_runner.cc @@ -50,8 +50,8 @@ CLMLRunner::CLMLRunner(std::string name, ToolArgs& args, cl_platform_id arg_plat context(arg_context), device_id(arg_device_id), queue(arg_queue) { - LOG(INFO) << "CLMLRunner Constructor: Input:" << r_args.input << " Output:" << r_args.output - << " Params:" << r_args.params; + LOG(INFO) << "CLMLRunner Constructor:" << name << " Input:" << r_args.input + << " Output:" << r_args.output << " Params:" << r_args.params; cl_int result; // Query and Get CLML Interface @@ -648,25 +648,29 @@ void CLMLRunner::MakeConcatenate( void CLMLRunner::MakeDense(std::shared_ptr input_desc, std::shared_ptr weight_desc, std::shared_ptr output_desc, - std::shared_ptr bias_desc, + std::vector in_shape, std::vector wt_shape, std::string dtype) { cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(MakeCLDataType(dtype)); cl_ml_op_qcom op = nullptr; cl_int result; + cl_gemm_transform_qcom b_transform = CL_GEMM_TRANSFORM_NONE_QCOM; - cl_ml_op_convolution_desc_qcom conv_desc = {CL_CONVOLUTION_MODE_CONVOLUTION_QCOM, - 1, - 4, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 0, - cl_arithmetic_mode}; - - result = h_ClmlIntf->clCreateMLOpConvolutionForwardQCOM( - this->context, 0, &conv_desc, input_desc->tensor, weight_desc->tensor, bias_desc->tensor, - output_desc->tensor, &op, tuning_cache); + if (in_shape[1] == wt_shape[1]) { + b_transform = CL_GEMM_TRANSFORM_TRANSPOSE_QCOM; + } + + cl_ml_op_gemm_desc_qcom gemmDesc = {in_shape[0], // m + wt_shape[0], // n + wt_shape[1], // k + CL_GEMM_TRANSFORM_NONE_QCOM, // A transform + b_transform, // B transform + {{1.0}, CL_FLOAT}, // alpha + {{0.0}, CL_FLOAT}, // beta + cl_arithmetic_mode}; + + result = + h_ClmlIntf->clCreateMLOpGemmQCOM(this->context, 0, &gemmDesc, input_desc->tensor, + weight_desc->tensor, output_desc->tensor, &op, tuning_cache); CLML_SDK_TEST_AND_EXIT(op && result == CL_SUCCESS); this->function.push_back(op); diff --git a/apps/cpp_clml/clml_runner.h b/apps/cpp_clml/clml_runner.h index 4e73674d72ae2..a1e78fcb66bef 100644 --- a/apps/cpp_clml/clml_runner.h +++ b/apps/cpp_clml/clml_runner.h @@ -178,7 +178,7 @@ class CLMLRunner { void MakeDense(std::shared_ptr input_desc, std::shared_ptr weight_desc, std::shared_ptr output_desc, - std::shared_ptr bias_desc, std::string dtype); + std::vector in_shape, std::vector wt_shape, std::string dtype); /*! 
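/*
 * For reference, a minimal sketch of the GEMM lowering that the new MakeDense
 * (declared above) performs, assuming a [M, K] input and a Relay-style [N, K]
 * dense weight. M, N, K here are symbolic stand-ins for the in_shape/wt_shape
 * entries passed to MakeDense:
 *
 *   // y[M, N] = x[M, K] * transpose(w[N, K])
 *   cl_ml_op_gemm_desc_qcom desc = {M,                                  // m
 *                                   N,                                  // n
 *                                   K,                                  // k
 *                                   CL_GEMM_TRANSFORM_NONE_QCOM,        // A used as-is
 *                                   CL_GEMM_TRANSFORM_TRANSPOSE_QCOM,   // B is [N, K], transpose to [K, N]
 *                                   {{1.0}, CL_FLOAT},                  // alpha
 *                                   {{0.0}, CL_FLOAT},                  // beta
 *                                   cl_arithmetic_mode};
 *   clCreateMLOpGemmQCOM(context, 0, &desc, input_desc->tensor, weight_desc->tensor,
 *                        output_desc->tensor, &op, tuning_cache);
 *
 * The in_shape[1] == wt_shape[1] check in the implementation selects the B
 * transpose only when the reduction axes of input and weight already match.
 */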
\brief SoftMax layer implementattion */ void MakeSoftMax(std::shared_ptr input_desc, diff --git a/apps/cpp_clml/scripts/clml_codegen.py b/apps/cpp_clml/scripts/clml_codegen.py index 32e5782db3852..bf19c0e4b9b60 100644 --- a/apps/cpp_clml/scripts/clml_codegen.py +++ b/apps/cpp_clml/scripts/clml_codegen.py @@ -45,7 +45,7 @@ def main(): clml_mod = clml.partition_for_clml(mod, params) libm = relay.build( clml_mod, - target="opencl -device=adreno", + target="opencl", target_host="llvm -mtriple=aarch64-linux-gnu", params=params, ) diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index f380ad75d14c5..2dc1fc18f36cd 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -59,20 +59,35 @@ if(USE_OPENCL) list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES}) endif() - if(DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) - include(FetchContent) - FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}") - set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) - FetchContent_MakeAvailable(googletest) - install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) + if(DEFINED USE_OPENCL_GTEST) + if(EXISTS ${USE_OPENCL_GTEST}) + include(FetchContent) + FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}") + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(googletest) + install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) - message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}") + message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}") + set(Build_OpenCL_GTests ON) + elseif (ANDROID_ABI AND DEFINED ENV{ANDROID_NDK_HOME}) + set(GOOGLETEST_ROOT $ENV{ANDROID_NDK_HOME}/sources/third_party/googletest) + add_library(gtest_main STATIC ${GOOGLETEST_ROOT}/src/gtest_main.cc ${GOOGLETEST_ROOT}/src/gtest-all.cc) + target_include_directories(gtest_main PRIVATE ${GOOGLETEST_ROOT}) + target_include_directories(gtest_main PUBLIC ${GOOGLETEST_ROOT}/include) + message(STATUS "Using gtest from Android NDK") + set(Build_OpenCL_GTests ON) + endif() - tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS - "${CMAKE_SOURCE_DIR}/tests/cpp-runtime/opencl/*.cc" - ) - add_executable(opencl-cpptest ${OPENCL_TEST_SRCS}) - target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime) + if(Build_OpenCL_GTests) + message(STATUS "Building OpenCL-Gtests") + tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS + "${CMAKE_SOURCE_DIR}/tests/cpp-runtime/opencl/*.cc" + ) + add_executable(opencl-cpptest ${OPENCL_TEST_SRCS}) + target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime) + else() + message(STATUS "Couldn't build OpenCL-Gtests") + endif() endif() list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS}) if(USE_OPENCL_ENABLE_HOST_PTR) diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py index 608c8a2a1b737..82926bb31d524 100644 --- a/python/tvm/relay/op/contrib/clml.py +++ b/python/tvm/relay/op/contrib/clml.py @@ -81,6 +81,36 @@ def transform_function( return RemoveDropout().visit(func) +class BroadcastInputs(ExprMutator): + """ + Binary operators need broadcasting for CLML. 
+ """ + + def visit_call(self, call): + if call.op.name in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]: + new_fn = self.visit(call.op) + call_shape = call.checked_type.shape + lhs = call.args[0] + rhs = call.args[1] + lhs_shape = lhs.checked_type.shape + rhs_shape = rhs.checked_type.shape + if list(call_shape) != list(lhs_shape): + lhs = relay.broadcast_to(self.visit(lhs), call_shape) + if list(call_shape) != list(rhs_shape): + rhs = relay.broadcast_to(self.visit(rhs), call_shape) + args = [lhs, rhs] + return Call(new_fn, args, call.attrs) + return super().visit_call(call) + + +@transform.function_pass(opt_level=0) +class BinaryOpBroadcaster: + def transform_function( + self, func: relay.function.Function, mod: tvm.IRModule, _: tvm.transform.PassContext + ) -> relay.function.Function: + return BroadcastInputs().visit(func) + + def partition_for_clml(mod, params=None, **opts): """Partition the graph greedily offloading supported operators to CLML Library. @@ -104,6 +134,7 @@ def partition_for_clml(mod, params=None, **opts): [ transform.InferType(), RemoveDropoutPass(), + BinaryOpBroadcaster(), transform.FoldConstant(), transform.MergeComposite(clml_pattern_table()), transform.AnnotateTarget("clml", False), @@ -261,8 +292,6 @@ def concat_pattern(): def dense_pattern(): """Create a dense pattern.""" pattern = is_op("nn.dense")(wildcard(), is_constant()) - pattern = pattern.optional(lambda x: is_op("add")(x, is_constant())) - pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant())) return pattern def pad_pattern(): @@ -344,9 +373,19 @@ def check_conv_transpose(extract): def check_binary_op(extract): call = extract - if len(call.args[1].checked_type.shape) > 0: - return True - return False + # Scalars are not supported + if len(call.args[1].checked_type.shape) == 0: + return False + + for arg in call.args: + # Avoid any operators with dtype Int64 + if arg.checked_type.dtype == "int64": + return False + # No support for batch> 1 + if arg.checked_type.shape[0] > 1: + return False + + return True def check_pad_op(extract): call = extract @@ -377,6 +416,24 @@ def check_concat_op(extract): return True def check_default_op(extract): + call = extract + + if isinstance(call, tvm.relay.expr.TupleGetItem): + call = call.tuple_value + + # Avoid any operators with dtype Int64 + for arg in call.args: + if arg.checked_type.dtype == "int64": + return False + return True + + def check_batch_matmul_op(extract): + call = extract + # Only support single Matmul + if call.args[0].checked_type.shape[0] > 1: + return False + if call.args[1].checked_type.shape[0] > 1: + return False return True return [ @@ -394,7 +451,7 @@ def check_default_op(extract): ("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op), ("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op), ("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op), - ("clml.reshape", is_op("reshape")(wildcard()), check_default_op), + # ("clml.reshape", is_op("reshape")(wildcard()), check_default_op), ("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op), ("clml.max_pool2d", is_op("nn.max_pool2d")(wildcard()), check_default_op), ("clml.global_avg_pool2d", is_op("nn.global_avg_pool2d")(wildcard()), check_default_op), @@ -404,6 +461,11 @@ def check_default_op(extract): ("clml.batch_flatten", is_op("nn.batch_flatten")(wildcard()), check_default_op), ("clml.depth_to_space", is_op("nn.depth_to_space")(wildcard()), check_default_op), ("clml.upsampling", 
is_op("nn.upsampling")(wildcard()), check_upsampling_op), + ( + "clml.batch_matmul", + is_op("nn.batch_matmul")(wildcard(), wildcard()), + check_batch_matmul_op, + ), ] @@ -570,7 +632,9 @@ def __init__(self, cmod): runner.MakeDense($input_tensor, $weight_tensor, $output_tensor, - $bias_tensor, "$dtype");""" + std::vector ({$in_shape}), + std::vector ({$wt_shape}), + "$dtype");""" ) self.MakeSoftMax = Template( """ @@ -641,13 +705,12 @@ def __init__(self, cmod): " Output Count : $output_count\\n" ' Input MetaInfo\\n$input_meta\\n Output MetaInfo\\n$output_meta");' ) - self.MakeInputMetaInfo = Template( - " Input: $in_name\\n Dtype : $dtype\\n Shape : [$shape]" + " Input: $in_name\\n Dtype : $dtype\\n Shape : [$shape]\\n" ) self.MakeOutputMetaInfo = Template( - " Output: $out_name\\n Dtype : $dtype\\n Shape : [$shape]" + " Output: $out_name\\n Dtype : $dtype\\n Shape : [$shape]\\n" ) def get_src(self): @@ -666,23 +729,40 @@ def get_tensor_from_map( else: node = self.nodes[node_seq] dtype = str(node["attrs"]["dtype"][0][0]) + if node["op"] == "input": + self.clml_code.append("// Input Node") + node_out_name = self.sub_module_name + "_" + "input_" + str(node_seq) + else: + node_out_name = node["name"] if shape is None: shape = str(tuple(node["attrs"]["shape"][0][0]))[1:-1] self.clml_code.append( self.MakeCLMLTensor.substitute( - name=node["name"], shape=shape, dtype=dtype, layout=layout + name=node_out_name, shape=shape, dtype=dtype, layout=layout ) ) self.clml_code.append( - self.MapInsert.substitute(nid=node["name"], tensor_desc=node["name"]) + self.MapInsert.substitute(nid=node_out_name, tensor_desc=node_out_name) ) + if node["op"] == "input": + self.clml_code.append( + Template("runner.inputs.push_back($clml_input);").substitute( + clml_input=node_out_name + ) + ) + self.input_meta.append( + self.MakeInputMetaInfo.substitute( + in_name=node_out_name, dtype=dtype, shape=shape + ) + ) + if self.nodes[node_seq]["op"] == "const": self.clml_code.append( Template('runner.consts.push_back("$nid");').substitute(nid=node["name"]) ) - self.node_map[node_seq] = node["name"] - return node["name"] + self.node_map[node_seq] = node_out_name + return node_out_name def make_output_tensor( node, node_seq, shape=None, layout="CL_TENSOR_LAYOUT_OPTIMAL_QCOM", dtype="float32" @@ -697,40 +777,13 @@ def make_output_tensor( name=node_out_name, shape=shape, dtype=dtype, - layout="CL_TENSOR_LAYOUT_OPTIMAL_QCOM", + layout=layout, ) ) return node_out_name for node_seq, node in enumerate(self.nodes): - if node["op"] == "input": - self.clml_code.append("// Input Node") - dtype = str(node["attrs"]["dtype"][0][0]) - shape = str(tuple(node["attrs"]["shape"][0][0]))[1:-1] - node_out_name = self.sub_module_name + "_" + "input_" + str(node_seq) - self.clml_code.append( - self.MakeCLMLTensor.substitute( - name=node_out_name, - shape=shape, - dtype=dtype, - layout="CL_TENSOR_LAYOUT_OPTIMAL_QCOM", - ) - ) - self.clml_code.append( - self.MapInsert.substitute(nid=node_out_name, tensor_desc=node_out_name) - ) - self.clml_code.append( - Template("runner.inputs.push_back($clml_input);").substitute( - clml_input=node_out_name - ) - ) - self.node_map[node_seq] = node_out_name - self.input_meta.append( - self.MakeInputMetaInfo.substitute( - in_name=node_out_name, dtype=dtype, shape=shape - ) - ) - elif node["op"] == "kernel": + if node["op"] == "kernel": self.clml_code.append("// Kernel Node : " + node["name"]) if node["name"] == "nn.conv2d" or node["name"] == "nn.depthwise_conv2d": if "padding" in node["attrs"]: @@ -791,6 +844,7 @@ def 
make_output_tensor( bn_shape = [1, 1, 1, 1] bn_node = self.nodes[node["inputs"][bn_index][0]] bn_shape[axis] = bn_node["attrs"]["shape"][0][0] + dtype = bn_node["attrs"]["dtype"][0][0] bn_scale_tensor = get_tensor_from_map( node["inputs"][bn_index][0], @@ -858,6 +912,7 @@ def make_output_tensor( bn_shape = [1, 1, 1, 1] bn_node = self.nodes[node["inputs"][0][0]] bn_shape[axis] = bn_node["attrs"]["shape"][0][0] + dtype = bn_node["attrs"]["dtype"][0][0] bn_scale_tensor = get_tensor_from_map( node["inputs"][0][0], shape=str(tuple(bn_shape))[1:-1], dtype=dtype ) @@ -947,26 +1002,26 @@ def make_output_tensor( in_shape = tuple(in_node["attrs"]["shape"][0][0]) wt_shape = tuple(in_node["attrs"]["shape"][0][0]) input_tensor = get_tensor_from_map( - node["inputs"][0][0], shape=str(tuple([1, in_shape[1], 1, 1]))[1:-1] + node["inputs"][0][0], layout="CL_TENSOR_LAYOUT_NCHW_QCOM" ) weight_tensor = get_tensor_from_map( node["inputs"][1][0], - shape=str(tuple([wt_shape[0], wt_shape[1], 1, 1]))[1:-1], + shape=str(tuple([1, 1, wt_shape[0], wt_shape[1]]))[1:-1], + layout="CL_TENSOR_LAYOUT_NCHW_QCOM", ) - if len(node["inputs"]) == 3: - bias_tensor = "runner.unusedTensor" - else: - bias_tensor = get_tensor_from_map(node["inputs"][2][0]) - node_out_name = make_output_tensor( - node, node_seq, shape=str(tuple([1, wt_shape[0], 1, 1]))[1:-1] + node, + node_seq, + shape=str(tuple([in_shape[0], wt_shape[0], 1, 1]))[1:-1], + layout="CL_TENSOR_LAYOUT_NCHW_QCOM", ) self.clml_code.append( self.MakeDense.substitute( input_tensor=input_tensor, weight_tensor=weight_tensor, output_tensor=node_out_name, - bias_tensor=bias_tensor, + in_shape=str(in_shape)[1:-1], + wt_shape=str(wt_shape)[1:-1], dtype=node["attrs"]["dtype"][0][0], ) ) @@ -1045,7 +1100,7 @@ def make_output_tensor( ) self.node_map[node_seq] = node_out_name - elif node["op"] != "const": + elif node["op"] not in ["const", "input"]: print("Unknown Node type:", node["op"]) # Populate outputs @@ -1086,8 +1141,8 @@ def make_output_tensor( name=self.sub_module_name, input_count=len(self.input_meta), output_count=len(self.output_meta), - input_meta="\n".join(self.input_meta), - output_meta="\n".join(self.output_meta), + input_meta="\\\n".join(self.input_meta), + output_meta="\\\n".join(self.output_meta), ) ) diff --git a/src/runtime/contrib/clml/clml_memory_planner.cc b/src/runtime/contrib/clml/clml_memory_planner.cc new file mode 100644 index 0000000000000..9e61f557f48e9 --- /dev/null +++ b/src/runtime/contrib/clml/clml_memory_planner.cc @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/clml/clml_memory_planner.cc + * \brief Various memory planning methods. 
+ */ +#ifdef TVM_GRAPH_EXECUTOR_CLML +#include "clml_memory_planner.h" + +#include +#include + +#include "clml_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +/*! + * Release memory after use. + * + */ +void FreeMemory(CachedLayer* layer, int nid) { + LOG_MEM << "FreeMemory:" << nid; + if (layer->storage_ref_map.find(nid) != layer->storage_ref_map.end()) { + LOG_MEM << "Ref Cnt:" << layer->storage_ref_map[nid]; + layer->storage_ref_map[nid]--; + if (0 == layer->storage_ref_map[nid]) { + LOG_MEM << "Ref Cnt Nill"; + // Look into on-chip allocation + for (auto it = layer->on_chip_pool_alloc_info.begin(); + it != layer->on_chip_pool_alloc_info.end(); it++) { + if (it->second == nid) { + LOG_MEM << "Free Segment:" << it->first << " Nid:" << nid; + layer->in_chip_total_free += layer->on_chip_pool_size[it->first]; + layer->in_chip_total_alloc -= layer->on_chip_pool_size[it->first]; + layer->on_chip_pool_alloc_info.erase(it->first); + return; + } + } + // Look into DDR allocation + if (layer->ddr_alloc_plan.find(nid) != layer->ddr_alloc_plan.end()) { + LOG_MEM << "Free DDR segment from local pool"; + layer->ddr_storage_ref_map[layer->ddr_alloc_plan[nid]].second = false; + return; + } + LOG_MEM << "*** Not a managed memory buffer"; + } + } else { + LOG_MEM << "Not in storage ref map :" << nid; + } +} + +/*! + * \brief Partition and allocate + * + */ +size_t PartitionAndAllocate(CachedLayer* layer, size_t segment_start, size_t size, bool is_left) { + LOG_MEM << "PartitionAndAllocate:" << segment_start << " Size:" << size + << " Is Begin:" << is_left; + size_t segment_size = layer->on_chip_pool_size[segment_start]; + size_t left_space = segment_size - size; + + layer->in_chip_total_free -= size; + layer->in_chip_total_alloc += size; + + if (is_left) { + // Start allocation + layer->on_chip_pool_size[segment_start] = size; + if (left_space) { + layer->on_chip_pool_size.insert({segment_start + size, left_space}); + } + return segment_start; + } else { + // End allocation + if (left_space) { + layer->on_chip_pool_size[segment_start] = left_space; + } + layer->on_chip_pool_size.insert({segment_start + left_space, size}); + return segment_start + left_space; + } +} + +/*! + * \brief Ping-Pong allocation with in best fit + * + */ +size_t PingPongAllocate(CachedLayer* layer, const std::map& segments, size_t size) { + /* + * segments contains all free segments details (start, size) that can fit the requirement + * PingPong Allocation Strategy: + * Here we find the smallest segment among all. + * We allocate at begining or end of this segment based on the ping-pong flag. + * Ping-pong allocation helps to have largest possible free segment at center + * for most of the graphs. + * + */ + ssize_t free_start; + ssize_t free_size; + ssize_t last_found_size = CLMLWorkspace::Global()->onchip_mem_size + 1; + + for (auto it = segments.begin(); it != segments.end(); it++) { + if (it->second < last_found_size) { + free_start = it->first; + free_size = it->second; + last_found_size = it->second; + LOG_MEM << "Mem Found:" << free_start << " Size:" << free_size; + } + } + + LOG_MEM << "Alloc On-chip Mem:" << free_start << " Size:" << free_size + << " PingPong:" << layer->alloc_ping_pong; + + // Allocate on-chip memory + layer->alloc_ping_pong ^= 1; + return PartitionAndAllocate(layer, free_start, size, layer->alloc_ping_pong); +} + +/*! + * \brief Allocate on-chip memory. 
+ * + */ +size_t RequestOnChipMemory(CachedLayer* layer, size_t size) { + LOG_MEM << "Request On-Chip Mem:" << size; + // Optimize for any fragmented parts + bool any_merge = true; + while (any_merge) { + any_merge = false; + for (auto it = layer->on_chip_pool_size.begin(); it != layer->on_chip_pool_size.end(); it++) { + if ((layer->on_chip_pool_alloc_info.find(it->first) == + layer->on_chip_pool_alloc_info.end()) && + (layer->on_chip_pool_alloc_info.find(it->first + it->second) == + layer->on_chip_pool_alloc_info.end()) && + (it->first + it->second < CLMLWorkspace::Global()->onchip_mem_size)) { + size_t left_begin = it->first; + size_t left_size = it->second; + size_t right_size = layer->on_chip_pool_size[it->first + it->second]; + LOG_MEM << "Merge:" << left_begin << " Size:" << left_size << " with :" << right_size; + layer->on_chip_pool_size[left_begin] = left_size + right_size; + layer->on_chip_pool_size.erase(left_begin + left_size); + any_merge = true; + break; + } + } + } + + // Look for any best fit free fragment + std::map feasible_segments; + for (auto it = layer->on_chip_pool_size.begin(); it != layer->on_chip_pool_size.end(); it++) { + if (layer->on_chip_pool_alloc_info.find(it->first) == layer->on_chip_pool_alloc_info.end()) { + if (it->second >= size) { + LOG_MEM << "Mem Pool:" << it->first << " - " << it->first + it->second << ":" << it->second + << " - Free"; + feasible_segments.insert({it->first, it->second}); + } else { + LOG_MEM << "Mem Pool:" << it->first << " - " << it->first + it->second << ":" << it->second + << " - Doesn't fit"; + } + } else { + LOG_MEM << "Mem Pool:" << it->first << " - " << it->first + it->second << ":" << it->second + << " - Busy"; + } + } + if (0 == feasible_segments.size()) { + LOG_MEM << "No Suitable Mem Found:" << size << " Free Size:" << layer->in_chip_total_free; + if (size <= layer->in_chip_total_free) { + LOG_STATS << "*** ALERT ***: Couldn't allocate due to fragmentation:" << size + << " Total Free:" << layer->in_chip_total_free; + layer->on_chip_alert_fail += size; + } + return -1; + } + + return PingPongAllocate(layer, feasible_segments, size); +} + +/*! + * \brief Allocate DDR memory for requested size. + * + */ +cl_mem RequestDDRMemory(CachedLayer* layer, size_t size) { + // Look for local storage map for a best fit + auto cws = CLMLWorkspace::Global(); + cl_mem memptr = nullptr; + size_t best_fit = INT_MAX; + for (auto it = layer->ddr_storage_ref_map.begin(); it != layer->ddr_storage_ref_map.end(); it++) { + if ((it->second.first >= size) && (false == it->second.second)) { + if (best_fit > it->second.first) { + memptr = it->first; + best_fit = it->second.first; + } + } + } + + if (memptr) { + LOG_MEM << "Reuse from local pool"; + layer->ddr_storage_ref_map[memptr].second = true; + return memptr; + } + // No available buffer in local pool, look for global pool + for (auto it = cws->ddr_global_pool.begin(); it != cws->ddr_global_pool.end(); it++) { + if ((it->second.first >= size) && + (layer->ddr_storage_ref_map.find(it->first) == layer->ddr_storage_ref_map.end())) { + // Found a buffer in global pool. Insert in local pool and then use. + if (best_fit > it->second.first) { + memptr = it->first; + best_fit = it->second.first; + } + } + } + + if (memptr) { + LOG_MEM << "Reuse from global pool"; + cws->ddr_global_pool[memptr].second += 1; + layer->ddr_storage_ref_map.insert( + {memptr, std::make_pair(cws->ddr_global_pool[memptr].first, true)}); + return memptr; + } + + // Allocate a fresh buffer in global then use in local pool. 
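  // DDR bookkeeping behind the three-step lookup above:
  // CLMLWorkspace::ddr_global_pool maps cl_mem -> (size, reference count) and is
  // shared by every CLML subgraph in the process, while
  // CachedLayer::ddr_storage_ref_map maps cl_mem -> (size, busy flag) and is
  // private to one layer. A request is served best-fit from the local map first,
  // then from the global pool (bumping its reference count and mirroring the
  // buffer into the local map), and only then by the fresh allocation below.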
+ LOG_MEM << "Allocating fresh buffer in global pool"; + memptr = AllocateDDRTensorMemory(size); + cws->ddr_global_pool.insert({memptr, std::make_pair(size, 1)}); + layer->ddr_storage_ref_map.insert({memptr, std::make_pair(size, true)}); + + return memptr; +} + +/*! + * \brief Release memory from global pool. + * + */ +void ReleaseDDRMemory(cl_mem memptr) { + cl_int result; + auto cws = CLMLWorkspace::Global(); + cws->ddr_global_pool[memptr].second -= 1; + if (0 == cws->ddr_global_pool[memptr].second) { + LOG_MEM << "Release DDR mem from global pool"; + result = clReleaseMemObject(memptr); + ICHECK(result == CL_SUCCESS) << "clReleaseMemObject:" << result; + cws->ddr_global_pool.erase(memptr); + } +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif diff --git a/src/runtime/contrib/clml/clml_memory_planner.h b/src/runtime/contrib/clml/clml_memory_planner.h new file mode 100644 index 0000000000000..b4e34e4f32d47 --- /dev/null +++ b/src/runtime/contrib/clml/clml_memory_planner.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/clml/clml_memory_planner.h + * \brief CLML memory planner header + */ +#ifndef TVM_RUNTIME_CONTRIB_CLML_CLML_MEMORY_PLANNER_H_ +#define TVM_RUNTIME_CONTRIB_CLML_CLML_MEMORY_PLANNER_H_ + +#include "clml_runtime.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +void FreeMemory(CachedLayer* layer, int nid); + +void ReleaseDDRMemory(cl_mem memptr); + +size_t RequestOnChipMemory(CachedLayer* layer, size_t size); + +cl_mem RequestDDRMemory(CachedLayer* layer, size_t size); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_CLML_CLML_MEMORY_PLANNER_H_ diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc index 7c716e68763bf..1146ff7249a41 100644 --- a/src/runtime/contrib/clml/clml_runtime.cc +++ b/src/runtime/contrib/clml/clml_runtime.cc @@ -21,33 +21,12 @@ * \file src/runtime/contrib/clml/clml_runtime.cc * \brief A simple JSON runtime for CLML. */ +#include "clml_runtime.h" -#include -#include #ifdef TVM_GRAPH_EXECUTOR_CLML -#include +#include "clml_memory_planner.h" +#include "clml_utils.h" #endif -#include -#include -#include -#include - -#include -#include -#include - -#include "../../file_utils.h" -#include "../../opencl/opencl_common.h" -#include "../json/json_node.h" -#include "../json/json_runtime.h" - -#define CAT_I(a, b) a##b -#define CAT(a, b) CAT_I(a, b) -#define GET_ML_INTERFACE CAT(CAT(clGetMLInterfaceV, CL_QCOM_ML_OPS_H_MAJOR_VERSION), QCOM) -#define GET_ML_API_INTERFACE CAT(CAT(CLMLInterfaceV, CL_QCOM_ML_OPS_H_MAJOR_VERSION), QCOM) - -/*! 
\brief Magic number for CLML Tuning cache entry */ -static const uint64_t kTVMCLMLTuningCacheMagic = 0x434C4D4C54554E45; namespace tvm { namespace runtime { @@ -56,6 +35,88 @@ namespace contrib { using namespace tvm::runtime::json; using JSONGraphNode = tvm::runtime::json::JSONGraphNode; +#ifdef TVM_GRAPH_EXECUTOR_CLML +CLMLThreadEntry* CLMLWorkspace::GetThreadEntry() { return CLMLThreadEntry::ThreadLocal(); } + +CLMLWorkspace* CLMLWorkspace::Global() { + static CLMLWorkspace* inst = new CLMLWorkspace(); + return inst; +} + +CLMLWorkspace::CLMLWorkspace() { + cl_int result = 0; + workspace = cl::OpenCLWorkspace::Global(); + workspace->Init(); + tentry = workspace->GetThreadEntry(); + + device_id = workspace->GetCLDeviceID(tentry->device.device_id); + platform_id = workspace->device_to_platform[device_id]; + + // Print extensions + size_t reqd_size = 0; + result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, 0, nullptr, &reqd_size); + ICHECK(reqd_size > 0u && result == CL_SUCCESS) << "clGetDeviceInfo:" << result; + std::vector extn_buf(reqd_size); + result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, reqd_size, extn_buf.data(), nullptr); + ICHECK(result == CL_SUCCESS) << "clGetDeviceInfo:" << result; + std::string extensions(extn_buf.data()); + LOG(WARNING) << "OpenCL Extensions:" << extensions; + + if (extensions.find("cl_qcom_ml_ops") == std::string::npos) { + LOG(FATAL) << "CLML Runtime Init: Qualcomm extn not present.\n"; + return; + } + is_recordable_queue = (extensions.find("cl_qcom_recordable_queues") != std::string::npos); + is_on_chip_memory = (extensions.find("cl_qcom_onchip_global_memory") != std::string::npos); + LOG(WARNING) << "Recordable Queues Support :" << is_recordable_queue; + LOG(WARNING) << "On chip Memory Support :" << is_on_chip_memory; + + if (is_on_chip_memory) { + result = clGetDeviceInfo(device_id, CL_DEVICE_ONCHIP_GLOBAL_MEM_SIZE_QCOM, + sizeof(onchip_mem_size), &onchip_mem_size, NULL); + ICHECK(result == CL_SUCCESS) << "clGetDeviceInfo(CL_DEVICE_ONCHIP_GLOBAL_MEM_SIZE_QCOM):" + << result; + LOG(WARNING) << "On chip memory size:" << onchip_mem_size; + } + + // Query and Get CLML Interface + static const cl_uint MAX_VERSIONS = 256; + cl_int majorVersions[MAX_VERSIONS]; + cl_int minorVersions[MAX_VERSIONS]; + cl_uint numVersions = 0; + result = clQueryMLInterfaceVersionsQCOM(nullptr, nullptr, 0, &numVersions); + ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; + ICHECK(numVersions > 0u); + ICHECK(numVersions <= MAX_VERSIONS); + + result = clQueryMLInterfaceVersionsQCOM(majorVersions, minorVersions, numVersions, nullptr); + ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; + + for (cl_uint i = 0; i < numVersions; ++i) { + if (majorVersions[i] == CL_QCOM_ML_OPS_H_MAJOR_VERSION) { + h_ClmlIntf = GET_ML_INTERFACE(0); + LOG(WARNING) << "CLML Target version:" << majorVersions[i]; + break; + } + } + ICHECK(h_ClmlIntf != nullptr) + << "clGetMLInterfaceVxQCOM:" << result + << " Perhaps there is mispatch between CLML SDK version to target supported version:" + << majorVersions[numVersions - 1]; + char* tune_flag; + if ((tune_flag = getenv("CLML_IS_TUNING_RUN"))) + is_tuning_run = std::stoi(tune_flag); + else + is_tuning_run = 0; + + if (!(tuning_file = getenv("CLML_TUNING_CACHE"))) this->is_tuning_run = 0; +} + +typedef dmlc::ThreadLocalStore CLMLThreadStore; + +CLMLThreadEntry* CLMLThreadEntry::ThreadLocal() { return CLMLThreadStore::Get(); } +#endif + class CLMLRuntime : public JSONRuntimeBase { public: /*! 
@@ -73,33 +134,42 @@ class CLMLRuntime : public JSONRuntimeBase { ~CLMLRuntime() { #ifdef TVM_GRAPH_EXECUTOR_CLML cl_int result = 0; - if (this->is_tuning_run) { - result = h_ClmlIntf->clReleaseMLTuningCacheQCOM(this->tuning_cache); + if (this->layer_.tuning_cache) { + result = CLML_INTF->clReleaseMLTuningCacheQCOM(this->layer_.tuning_cache); ICHECK(result == CL_SUCCESS) << "clReleaseMLTuningCacheQCOM:" << result; } for (auto it = this->layer_.storage_map.begin(); it != this->layer_.storage_map.end(); it++) { auto tensor_desc = it->second.first; - result = h_ClmlIntf->clReleaseMLTensorQCOM(tensor_desc->tensor); + result = CLML_INTF->clReleaseMLTensorQCOM(tensor_desc->tensor); ICHECK(result == CL_SUCCESS) << "clReleaseMLTensorQCOM:" << result; - result = clReleaseMemObject(tensor_desc->memory); - ICHECK(result == CL_SUCCESS) << "clReleaseMemObject:" << result; + if (this->layer_.ddr_storage_ref_map.find(tensor_desc->memory) != + this->layer_.ddr_storage_ref_map.end()) { + ReleaseDDRMemory(tensor_desc->memory); + } else { + result = clReleaseMemObject(tensor_desc->memory); + ICHECK(result == CL_SUCCESS) << "clReleaseMemObject:" << result; + } } for (size_t i = 0; i < this->layer_.function.size(); ++i) { - result = h_ClmlIntf->clReleaseMLOpQCOM(this->layer_.function[i]); + result = CLML_INTF->clReleaseMLOpQCOM(this->layer_.function[i]); ICHECK(result == CL_SUCCESS) << "clReleaseMLOpQCOM:" << result; } for (auto it = this->layer_.in_placeholder.begin(); it != this->layer_.in_placeholder.end(); it++) { - result = h_ClmlIntf->clReleaseMLTensorQCOM((*it)->tensor); + result = CLML_INTF->clReleaseMLTensorQCOM(it->second->tensor); ICHECK(result == CL_SUCCESS) << "clReleaseMLTensorQCOM:" << result; } for (auto it = this->layer_.out_placeholder.begin(); it != this->layer_.out_placeholder.end(); it++) { - result = h_ClmlIntf->clReleaseMLTensorQCOM((*it)->tensor); + result = CLML_INTF->clReleaseMLTensorQCOM((*it)->tensor); ICHECK(result == CL_SUCCESS) << "clReleaseMLTensorQCOM:" << result; } - result = h_ClmlIntf->clReleaseMLTensorMemoryDescriptorSetQCOM(layer_.descriptorSet); + result = CLML_INTF->clReleaseMLTensorMemoryDescriptorSetQCOM(layer_.descriptorSet); ICHECK(result == CL_SUCCESS) << "clReleaseMLTensorMemoryDescriptorSetQCOM:" << result; + + if (this->layer_.recordable_queue) { + clReleaseCommandQueue(this->layer_.recordable_queue); + } #endif } @@ -129,66 +199,27 @@ class CLMLRuntime : public JSONRuntimeBase { } #ifdef TVM_GRAPH_EXECUTOR_CLML - std::vector GetVectorValues(const std::vector& val) { - std::vector array; - for (auto i : val) { - array.push_back((cl_uint)stoi(i)); - } - return array; - } - void InitCLML() { // Setup CLML Context cl_int result = 0; - workspace = cl::OpenCLWorkspace::Global(); - workspace->Init(); - tentry = workspace->GetThreadEntry(); + cws = CLMLWorkspace::Global(); - if (!ExtensionStringPresent()) { - LOG(FATAL) << "CLML Runtime Init: Qualcomm extn not present.\n"; - return; - } - device_id = workspace->GetCLDeviceID(tentry->device.device_id); - platform_id = workspace->device_to_platform[device_id]; - - // Query and Get CLML Interface - static const cl_uint MAX_VERSIONS = 256; - cl_int majorVersions[MAX_VERSIONS]; - cl_int minorVersions[MAX_VERSIONS]; - cl_uint numVersions = 0; - result = clQueryMLInterfaceVersionsQCOM(nullptr, nullptr, 0, &numVersions); - ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; - ICHECK(numVersions > 0u); - ICHECK(numVersions <= MAX_VERSIONS); - - result = clQueryMLInterfaceVersionsQCOM(majorVersions, 
minorVersions, numVersions, nullptr); - ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result; - - for (cl_uint i = 0; i < numVersions; ++i) { - if (majorVersions[i] == CL_QCOM_ML_OPS_H_MAJOR_VERSION) { - h_ClmlIntf = GET_ML_INTERFACE(0); - LOG(WARNING) << "CLML Target version:" << majorVersions[i]; - break; - } + if (cws->is_recordable_queue) { + this->layer_.recordable_queue = + clCreateCommandQueue(CLML_CTX, cws->device_id, CL_QUEUE_RECORDABLE_QCOM, &result); + ICHECK(result == CL_SUCCESS) << "clCreateCommandQueue - Recordable:" << result; + + this->layer_.recording = clNewRecordingQCOM(this->layer_.recordable_queue, &result); + ICHECK(result == CL_SUCCESS) << "clNewRecordingQCOM:" << result; } - ICHECK(h_ClmlIntf != nullptr) - << "clGetMLInterfaceVxQCOM:" << result - << " Perhaps there is mispatch between CLML SDK version to target supported version:" - << majorVersions[numVersions - 1]; - char* tune_flag; - if ((tune_flag = getenv("CLML_IS_TUNING_RUN"))) - this->is_tuning_run = std::stoi(tune_flag); - else - this->is_tuning_run = 0; - if (!(tuning_file = getenv("CLML_TUNING_CACHE"))) this->is_tuning_run = 0; // A Tuning run, so create the cache from scratch - result = h_ClmlIntf->clCreateMLTuningCacheQCOM(&tuning_cache); + result = CLML_INTF->clCreateMLTuningCacheQCOM(&layer_.tuning_cache); ICHECK(result == CL_SUCCESS) << "clCreateMLTuningCacheQCOM:" << result; - if (!this->is_tuning_run && this->tuning_file) { + if (!cws->is_tuning_run && cws->tuning_file) { std::vector tune_buffer; std::string tune_blob; - LoadBinaryFromFile(this->tuning_file, &tune_blob); + LoadBinaryFromFile(cws->tuning_file, &tune_blob); dmlc::MemoryStringStream mstrm(const_cast(&tune_blob)); dmlc::Stream* strm = &mstrm; @@ -198,7 +229,6 @@ class CLMLRuntime : public JSONRuntimeBase { if (header != kTVMCLMLTuningCacheMagic) break; if (!strm->Read(&reserve)) break; if (!strm->Read(&tune_symbol)) break; - LOG(INFO) << "Tuning Cache Symbol:" << tune_symbol; if (tune_symbol == clml_symbol) { strm->Read(&tune_buffer); break; @@ -211,59 +241,16 @@ class CLMLRuntime : public JSONRuntimeBase { if (tune_buffer.size()) { LOG(INFO) << "Loading tuning cache for symbol:" << clml_symbol << " size:" << tune_buffer.size(); - result = h_ClmlIntf->clLoadMLTuningCacheQCOM(tuning_cache, tune_buffer.size(), - tune_buffer.data()); + result = CLML_INTF->clLoadMLTuningCacheQCOM(layer_.tuning_cache, tune_buffer.size(), + tune_buffer.data()); ICHECK(result == CL_SUCCESS) << "clLoadMLTuningCacheQCOM:" << result; } else { LOG(WARNING) << "Tuning cache not cound for symbol :" << clml_symbol << " in file " - << this->tuning_file; + << cws->tuning_file; } } } - std::vector readBinFile(const std::string& filename) { - std::ifstream fin(filename, std::ios::binary | std::ios::ate); - if (!fin.good()) { - LOG(FATAL) << "ERROR: Could not load tuning cache file: " + filename; - } - ICHECK(fin.good()); - int64_t size = fin.tellg(); - fin.seekg(0, std::ios::beg); - std::vector buffer(static_cast(size)); - char* ptr = reinterpret_cast(buffer.data()); - fin.read(ptr, size); - ICHECK(fin.good()); - return buffer; - } - - void CopyDataToCLMLTensor(std::shared_ptr tensor, void* data, - cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM) { - cl_int result = 0; - cl_event evt = nullptr; - result = h_ClmlIntf->clEnqueueWriteMLTensorDataQCOM(workspace->GetQueue(tentry->device), data, - layout, tensor->tensor, tensor->memory, - 0, // n waitlist - nullptr, // waitlist - &evt); // event - ICHECK((evt != nullptr) && result == 
CL_SUCCESS) << "clEnqueueWriteMLTensorDataQCOM:" << result; - } - - void CopyDataFromCLMLTensor(std::shared_ptr tensor, void* data, - cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM) { - cl_int result = 0; - cl_event readEvent = nullptr; - // Read the output tensor - result = h_ClmlIntf->clEnqueueReadMLTensorDataQCOM(workspace->GetQueue(tentry->device), - tensor->tensor, tensor->memory, data, layout, - 0, // n waitlist - nullptr, // waitlist - &readEvent); // event - ICHECK(result == CL_SUCCESS) << "clEnqueueReadMLTensorDataQCOM:" << result; - - result = clWaitForEvents(1, &readEvent); - ICHECK(result == CL_SUCCESS) << "clWaitForEvents:" << result; - } - /*! * \brief Unpack inputs and outputs and run inference on a given layer. * @@ -273,8 +260,8 @@ class CLMLRuntime : public JSONRuntimeBase { */ void Run() override { cl_int result = 0; - cl_command_queue queue = workspace->GetQueue(tentry->device); - std::vector& evts = workspace->GetEventQueue(tentry->device); + cl_command_queue queue = CLML_QUEUE; + std::vector& evts = cws->workspace->GetEventQueue(cws->tentry->device); for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; uint32_t eid = EntryID(nid, 0); @@ -285,19 +272,19 @@ class CLMLRuntime : public JSONRuntimeBase { isize *= data_entry_[eid]->shape[j]; } if (kDLCPU == data_entry_[eid]->device.device_type) { - CopyDataToCLMLTensor(layer_.inputs[i], data); + CopyDataToCLMLTensor(layer_.inputs[nid], data); } else if (kDLOpenCL == data_entry_[eid]->device.device_type) { - layer_.in_placeholder[i]->memory = static_cast( + layer_.in_placeholder[nid]->memory = static_cast( ((cl::BufferDescriptor*)const_cast(data_entry_[eid])->data)->buffer); cl_event cpy_evt = nullptr; cl_event* evt = &cpy_evt; - if (workspace->IsProfiling(tentry->device)) { + if (cws->workspace->IsProfiling(cws->tentry->device)) { evts.resize(evts.size() + 1); evt = &(evts.back()); } - result = h_ClmlIntf->clEnqueueCopyMLTensorDataQCOM( - queue, layer_.in_placeholder[i]->tensor, layer_.in_placeholder[i]->memory, - layer_.inputs[i]->tensor, layer_.inputs[i]->memory, 0, nullptr, evt); + result = CLML_INTF->clEnqueueCopyMLTensorDataQCOM( + queue, layer_.in_placeholder[nid]->tensor, layer_.in_placeholder[nid]->memory, + layer_.inputs[nid]->tensor, layer_.inputs[nid]->memory, 0, NULL, evt); ICHECK(result == CL_SUCCESS) << "clEnqueueCopyMLTensorDataQCOM:" << result; } else { DLDataType tvm_dtype = const_cast(data_entry_[eid])->dtype; @@ -306,36 +293,57 @@ class CLMLRuntime : public JSONRuntimeBase { void* tmpptr = reinterpret_cast(malloc(isize * dtype_size)); TVMArrayCopyToBytes(const_cast(data_entry_[eid]), const_cast(tmpptr), isize * dtype_size); - CopyDataToCLMLTensor(layer_.inputs[i], tmpptr); + CopyDataToCLMLTensor(layer_.inputs[nid], tmpptr); free(tmpptr); } } } int64_t duration = 0; - for (size_t i = 0; i < this->layer_.function.size(); ++i) { - // Make CLML subgraphs accounted by OpenCLTimerNode. 
- + if (cws->is_recordable_queue) { if (getenv("CLML_PROFILING")) { Timer t; auto f = Registry::Get(std::string("profiling.timer.opencl")); - t = f->operator()(tentry->device); + t = f->operator()(cws->tentry->device); t->Start(); - queue = workspace->GetQueue(tentry->device); + queue = CLML_QUEUE; evts.resize(evts.size() + 1); cl_event* evt = &(evts.back()); - - result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i], - this->layer_.descriptorSet, 0, nullptr, evt); + result = CLML_INTF->clEnqueueRecordingMLOpQCOM(queue, this->layer_.recording, 0, nullptr, 0, + nullptr, 0, nullptr, 0, nullptr, 0, nullptr, + 0, nullptr, 0, nullptr, 0, nullptr, evt); + ICHECK(result == CL_SUCCESS) << "clEnqueueRecordingMLOpQCOM:" << result; t->Stop(); duration += t->SyncAndGetElapsedNanos(); - LOG(WARNING) << "Layer:" << this->layer_.layer_names[i] - << " Duration:" << t->SyncAndGetElapsedNanos(); } else { - result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i], - this->layer_.descriptorSet, 0, nullptr, nullptr); + result = CLML_INTF->clEnqueueRecordingMLOpQCOM(queue, this->layer_.recording, 0, nullptr, 0, + nullptr, 0, nullptr, 0, nullptr, 0, nullptr, + 0, nullptr, 0, nullptr, 0, nullptr, nullptr); + ICHECK(result == CL_SUCCESS) << "clEnqueueRecordingMLOpQCOM:" << result; + } + } else { + for (size_t i = 0; i < this->layer_.function.size(); ++i) { + // Make CLML subgraphs accounted by OpenCLTimerNode. + if (getenv("CLML_PROFILING")) { + Timer t; + auto f = Registry::Get(std::string("profiling.timer.opencl")); + t = f->operator()(cws->tentry->device); + t->Start(); + queue = CLML_QUEUE; + evts.resize(evts.size() + 1); + cl_event* evt = &(evts.back()); + result = CLML_INTF->clEnqueueMLOpQCOM(queue, this->layer_.function[i], + this->layer_.descriptorSet, 0, nullptr, evt); + t->Stop(); + duration += t->SyncAndGetElapsedNanos(); + LOG(WARNING) << "Layer:" << this->layer_.layer_names[i] + << " Duration:" << t->SyncAndGetElapsedNanos(); + } else { + result = CLML_INTF->clEnqueueMLOpQCOM(queue, this->layer_.function[i], + this->layer_.descriptorSet, 0, nullptr, nullptr); + } + ICHECK(result == CL_SUCCESS) << "clEnqueueMLOpQCOM:" << result; } - ICHECK(result == CL_SUCCESS) << "clEnqueueMLOpQCOM:" << result; } if (getenv("CLML_PROFILING")) { LOG(WARNING) << "Total Duration for " << clml_symbol << " is:" << duration; @@ -356,11 +364,11 @@ class CLMLRuntime : public JSONRuntimeBase { ((cl::BufferDescriptor*)const_cast(data_entry_[eid])->data)->buffer); cl_event cpy_evt = nullptr; cl_event* evt = &cpy_evt; - if (workspace->IsProfiling(tentry->device)) { + if (cws->workspace->IsProfiling(cws->tentry->device)) { evts.resize(evts.size() + 1); evt = &(evts.back()); } - result = h_ClmlIntf->clEnqueueCopyMLTensorDataQCOM( + result = CLML_INTF->clEnqueueCopyMLTensorDataQCOM( queue, layer_.outputs[i]->tensor, layer_.outputs[i]->memory, layer_.out_placeholder[i]->tensor, layer_.out_placeholder[i]->memory, 0, nullptr, evt); ICHECK(result == CL_SUCCESS) << "clEnqueueCopyMLTensorDataQCOM:" << result; @@ -379,6 +387,164 @@ class CLMLRuntime : public JSONRuntimeBase { } private: + /*! + * \brief check if the nid is graph output tensor or not. + * + */ + bool IsOutputTensor(int nid) { + for (size_t i = 0; i < outputs_.size(); ++i) { + if (nid == outputs_[i].id_) return true; + } + return false; + } + + /*! + * \brief Initialize memory pool. 
+ * + */ + void InitMemoryPool(void) { + layer_.on_chip_pool_size.clear(); + layer_.on_chip_pool_size.insert({0, cws->onchip_mem_size}); + layer_.on_chip_pool_alloc_info.clear(); + layer_.alloc_ping_pong = true; + layer_.in_chip_total_free = cws->onchip_mem_size; + layer_.in_chip_total_alloc = 0; + layer_.on_chip_alert_fail = 0; + } + + /*! + * \brief Plan Memory for activations to allocate on on-chip global memory where ever possible. + * + */ + void PlanMemory() { + InitMemoryPool(); + // Build the ref count table for all activation tensors. + LOG_MEM << "Build Ref Map"; + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + if (node.GetOpType() == "kernel") { + std::vector inputs = node.GetInputs(); + for (auto& input_node : inputs) { + if (nodes_[input_node.id_].GetOpType() != "const") { + if (layer_.storage_ref_map.find(input_node.id_) == layer_.storage_ref_map.end()) { + layer_.storage_ref_map.insert({input_node.id_, 1}); + layer_.life_span.insert({input_node.id_, nid}); + } else { + layer_.storage_ref_map[input_node.id_]++; + layer_.life_span[input_node.id_] = nid; + } + } + } + } + } + LOG_MEM << "Print Ref Map"; + + for (auto it = layer_.storage_ref_map.begin(); it != layer_.storage_ref_map.end(); it++) { + LOG_MEM << "RefMap:" << it->first << " Count:" << it->second + << "Life Span:" << layer_.life_span[it->first]; + } + + for (size_t nid = 0; nid < nodes_.size(); ++nid) { + const auto& node = nodes_[nid]; + uint32_t size = 0; + cl_int result = CL_OUT_OF_HOST_MEMORY; + result = CLML_INTF->clGetMLTensorMemorySizeQCOM(CLML_CTX, + layer_.storage_map[nid].first->tensor, &size); + ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result; + + if ((node.GetOpType() == "kernel") || (node.GetOpType() == "input")) { + std::vector inputs = node.GetInputs(); + LOG_MEM << "Request :" << size << " Nid:" << nid; + size_t offset = -1; + // On-chip memory only for intermediate tensors with in recording scope. + if ((cws->is_on_chip_memory) && (!IsOutputTensor(nid)) && (node.GetOpType() != "input")) { + offset = RequestOnChipMemory(&this->layer_, size); + } + if (-1 != offset) { + LOG_MEM << "Got On-Chip Mem:" << offset << "Nid:" << nid; + layer_.on_chip_pool_alloc_info.insert({offset, nid}); + layer_.on_chip_alloc_plan.insert({nid, std::make_pair(size, offset)}); + } else { + layer_.on_chip_reject.insert({nid, size}); + // DDR Allocation + auto ddr_mem = RequestDDRMemory(&this->layer_, size); + LOG_MEM << "Alloc DDR from global pool for nid:" << nid << " Type:" << node.GetOpType(); + layer_.ddr_alloc_plan.insert({nid, ddr_mem}); + } + + // Now free up the input tensors on-chip memory for reuse. 
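        // FreeMemory decrements storage_ref_map for each producer; when the count
        // reaches zero it either marks the producer's on-chip segment free again
        // (removing it from on_chip_pool_alloc_info) or clears the busy flag on its
        // DDR buffer in ddr_storage_ref_map so later nodes can reuse it.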
+ for (auto& input_node : inputs) { + if (nodes_[input_node.id_].GetOpType() != "const") { + LOG_MEM << "Free Input Mem:" << input_node.id_; + FreeMemory(&this->layer_, input_node.id_); + } + } + } + } + + // Stats dump + size_t in_chip_total_alloc = 0; + size_t total_reject = 0; + for (auto it = layer_.on_chip_alloc_plan.begin(); it != layer_.on_chip_alloc_plan.end(); it++) { + LOG_STATS << " On-chip Alloc:" << it->first << " Size:" << it->second.first + << " Offset:" << it->second.second; + in_chip_total_alloc += it->second.first; + } + + for (auto it = layer_.on_chip_reject.begin(); it != layer_.on_chip_reject.end(); it++) { + LOG_STATS << "Reject:" << it->first << " Size:" << it->second; + total_reject += it->second; + } + LOG_STATS << "Total On-chip Alloc:" << in_chip_total_alloc + total_reject + << " On-Chip:" << in_chip_total_alloc << " Reject:" << total_reject + << " Alert Fail:" << layer_.on_chip_alert_fail; + + auto cws = CLMLWorkspace::Global(); + for (auto it = cws->ddr_global_pool.begin(); it != cws->ddr_global_pool.end(); it++) { + LOG_STATS << "DDR Global pool - size:" << it->second.first << " Ref:" << it->second.second; + } + for (auto it = this->layer_.ddr_storage_ref_map.begin(); + it != this->layer_.ddr_storage_ref_map.end(); it++) { + LOG_STATS << "DDR Local pool - size:" << it->second.first << " Ref cnt:" << it->second.second; + } + } + + /*! + * \brief Create an CLML tensor from JSON node entry. Lookup storage map before creation. + * + * \param tensor The tensor as Node Entry . + * \param shape shape information of tensor + * \param layout the tensor layout to be used + * \param dtype tensor data type + * \return CLML Tensor descriptor. + */ + std::shared_ptr MakeCLMLTensorFromJSONEntry( + const JSONGraphNodeEntry& tensor, std::vector shape, cl_ml_tensor_layout_qcom layout, + cl_uint dtype) { + JSONGraphNode node = nodes_[tensor.id_]; + + if (this->layer_.storage_map.find(tensor.id_) == this->layer_.storage_map.end()) { + void* node_data = nullptr; + if (node.GetOpType() == "const") { + node_data = data_entry_[EntryID(tensor)]->data; + } + auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape); + this->layer_.storage_map.insert({tensor.id_, std::make_pair(clml_tensor, node)}); + + if ("input" == node.GetOpType()) { + this->layer_.inputs.insert({tensor.id_, clml_tensor}); + // Input copy placeholder Tensor + this->layer_.in_placeholder.insert( + {tensor.id_, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, dtype, + node_data, shape)}); + } + + return clml_tensor; + } else { + return this->layer_.storage_map[tensor.id_].first; + } + } + /*! * \brief Build CLML layer from JSON representation and cache. * @@ -392,88 +558,68 @@ class CLMLRuntime : public JSONRuntimeBase { DLDataType tvm_dtype = node.GetOpDataType()[0]; cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); if (node.GetOpType() == "input") { - auto clml_input = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - this->layer_.storage_map.insert({nid, std::make_pair(clml_input, node)}); - this->layer_.inputs.push_back(clml_input); - // Input copy placeholder Tensor - this->layer_.in_placeholder.push_back( - MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype)); + // Layers may request for different layout. Differ the input allocation. 
} else if (node.GetOpType() == "kernel") { auto op_name = node.GetOpName(); if ("nn.conv2d" == op_name) { auto out = CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_CONVOLUTION_QCOM); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.depthwise_conv2d" == op_name) { auto out = CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_DEPTHWISE_QCOM); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.conv2d_transpose" == op_name) { auto out = CreateConvolution2DLayer(&layer_, node, CL_CONVOLUTION_MODE_TRANSPOSE_QCOM); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.relu6" == op_name) { auto out = CreateReLULayer(&layer_, node, CL_ACTIVATION_RELU6); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.relu" == op_name) { auto out = CreateReLULayer(&layer_, node, CL_ACTIVATION_RELU); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.batch_norm" == op_name) { auto out = CreateBatchNormLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name || "nn.l2_pool2d" == op_name) { auto out = CreatePoolingLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" == op_name) { auto out = CreateGlobalPoolingLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("reshape" == op_name) { auto out = CreateReshapeLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("concatenate" == op_name) { auto out = CreateConcatLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.dense" == op_name) { auto out = CreateDenseLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.softmax" == op_name) { auto out = CreateSoftMaxLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.pad" == op_name) { auto out = CreatePadLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.batch_flatten" == op_name) { auto out = CreateBatchFlattenLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("clip" == op_name) { auto out = CreateClipLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("add" == op_name || "subtract" == op_name || "multiply" == op_name || "minimum" == op_name || "maximum" == op_name || "divide" == op_name) { auto out = CreateBinaryLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if 
("nn.depth_to_space" == op_name) { auto out = CreateDepthToSpaceLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); } else if ("nn.upsampling" == op_name) { auto out = CreateResizeLayer(&layer_, node); this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); - this->layer_.func_outs.push_back(out); + } else if ("nn.batch_matmul" == op_name) { + auto out = CreateBatchMatmulLayer(&layer_, node, nid); + this->layer_.storage_map.insert({nid, std::make_pair(out, node)}); } else { LOG(FATAL) << "Unsupported op: " << op_name; } @@ -488,17 +634,56 @@ class CLMLRuntime : public JSONRuntimeBase { DLDataType tvm_dtype = nodes_[nid].GetOpDataType()[0]; cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); this->layer_.outputs.push_back(this->layer_.storage_map[nid].first); - this->layer_.out_placeholder.push_back( - MakeCLMLTensorFromJSONNode(nodes_[nid], CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype)); + if (this->layer_.out_shapes.find(nid) != this->layer_.out_shapes.end()) { + // Handle customized shapes here + this->layer_.out_placeholder.push_back( + MakeCLMLTensorFromJSONNode(nodes_[nid], CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype, nullptr, + this->layer_.out_shapes[nid])); + } else { + this->layer_.out_placeholder.push_back( + MakeCLMLTensorFromJSONNode(nodes_[nid], CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype)); + } } + + // Plan memory utilization + PlanMemory(); + // ALlocate device memories and initialize the params if any cl_int result = 0; + size_t alloc_on_chip = 0; + size_t alloc_ddr = 0; + size_t alloc_ddr_reuse = 0; for (auto it = this->layer_.storage_map.begin(); it != this->layer_.storage_map.end(); it++) { auto tensor_desc = it->second.first; + uint32_t mem_size = 0; + result = CL_OUT_OF_HOST_MEMORY; + result = CLML_INTF->clGetMLTensorMemorySizeQCOM(CLML_CTX, tensor_desc->tensor, &mem_size); + ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result; + JSONGraphNode node = it->second.second; void* node_data = nullptr; - - allocateTensorMemory(h_ClmlIntf, workspace->contexts[platform_id], tensor_desc); + size_t on_chip_mem_offset = -1; + if (layer_.on_chip_alloc_plan.find(it->first) != layer_.on_chip_alloc_plan.end()) { + LOG_MEM << "Found GMEM Alloc:" << it->first + << " Size:" << layer_.on_chip_alloc_plan[it->first].first + << " Offset:" << layer_.on_chip_alloc_plan[it->first].second; + on_chip_mem_offset = layer_.on_chip_alloc_plan[it->first].second; + alloc_on_chip += mem_size; + tensor_desc->memory = AllocateOnChipTensorMemory(mem_size, on_chip_mem_offset); + } else if (layer_.ddr_alloc_plan.find(it->first) != layer_.ddr_alloc_plan.end()) { + LOG_MEM << "DDR Alloc for nid:" << it->first << " Type:" << node.GetOpType(); + tensor_desc->memory = layer_.ddr_alloc_plan[it->first]; + alloc_ddr_reuse += mem_size; + //} else if ((node.GetOpType() == "input") || IsOutputTensor(it->first) || (node.GetOpType() + //== "const")) { + } else if (node.GetOpType() == "const") { + LOG_MEM << "DDR Alloc for Const/Input/Output"; + tensor_desc->memory = AllocateDDRTensorMemory(mem_size); + alloc_ddr += mem_size; + } else { + LOG(FATAL) << "Mem allocation not found on DDR as well as On-Chip nid: " << it->first + << " Type:" << node.GetOpType(); + } if (node.GetOpType() == "const") { node_data = data_entry_[EntryID(it->first, 0)]->data; @@ -508,37 +693,55 @@ class CLMLRuntime : public JSONRuntimeBase { } this->layer_.tensorMemDescs.push_back(*tensor_desc); } + LOG_STATS << "Total On-Chip Allocation :" << alloc_on_chip; + 
LOG_STATS << "Total DDR Reuse Allocation:" << alloc_ddr_reuse; + LOG_STATS << "Total DDR fixed allocation:" << alloc_ddr; + size_t ddr_global_pool = 0; + size_t ddr_local_pool = 0; + auto cws = CLMLWorkspace::Global(); + for (auto it = cws->ddr_global_pool.begin(); it != cws->ddr_global_pool.end(); it++) { + LOG_STATS << "DDR Global pool - size:" << it->second.first << " Ref:" << it->second.second; + ddr_global_pool += it->second.first; + } + LOG_STATS << "Total Global Pool:" << ddr_global_pool; + for (auto it = this->layer_.ddr_storage_ref_map.begin(); + it != this->layer_.ddr_storage_ref_map.end(); it++) { + LOG_STATS << "DDR Local pool - size:" << it->second.first << " Ref cnt:" << it->second.second; + ddr_local_pool += it->second.first; + } + LOG_STATS << "Total Local Pool:" << ddr_local_pool; // Setup descriptor set - result = h_ClmlIntf->clCreateMLTensorMemoryDescriptorSetQCOM(&this->layer_.descriptorSet); + result = CLML_INTF->clCreateMLTensorMemoryDescriptorSetQCOM(&this->layer_.descriptorSet); ICHECK(result == CL_SUCCESS) << "clCreateMLTensorMemoryDescriptorSetQCOM:" << result; - result = h_ClmlIntf->clUpdateMLTensorMemoryDescriptorSetQCOM( + result = CLML_INTF->clUpdateMLTensorMemoryDescriptorSetQCOM( this->layer_.descriptorSet, static_cast(this->layer_.tensorMemDescs.size()), this->layer_.tensorMemDescs.data()); ICHECK(result == CL_SUCCESS) << "clUpdateMLTensorMemoryDescriptorSetQCOM:" << result; - if (this->is_tuning_run) { + if (cws->is_tuning_run) { LOG(WARNING) << "CLML Tunning In Progress:"; // Let the command queue recreated in profiling mode. - cl::OpenCLWorkspace::Global()->EnableQueueProfiling(tentry->device, true); + cl::OpenCLWorkspace::Global()->EnableQueueProfiling(cws->tentry->device, true); for (size_t i = 0; i < this->layer_.function.size(); ++i) { LOG(WARNING) << "CLML Tunning:" << this->layer_.layer_names[i]; - result = h_ClmlIntf->clTuneMLOpQCOM(workspace->GetQueue(tentry->device), - this->layer_.function[i], this->layer_.descriptorSet, - this->tuning_cache, nullptr); + result = CLML_INTF->clTuneMLOpQCOM(CLML_QUEUE, this->layer_.function[i], + this->layer_.descriptorSet, this->layer_.tuning_cache, + nullptr); ICHECK(result == CL_SUCCESS) << "clTuneMLOpQCOM:" << result; } - cl::OpenCLWorkspace::Global()->EnableQueueProfiling(tentry->device, false); + cl::OpenCLWorkspace::Global()->EnableQueueProfiling(cws->tentry->device, false); size_t cache_len_bytes = 0; size_t len_ret = 0; - result = h_ClmlIntf->clSaveMLTuningCacheQCOM(tuning_cache, 0, nullptr, &cache_len_bytes); + result = + CLML_INTF->clSaveMLTuningCacheQCOM(layer_.tuning_cache, 0, nullptr, &cache_len_bytes); ICHECK(result == CL_SUCCESS) << "clSaveMLTuningCacheQCOM:" << result; std::vector saved_cache(cache_len_bytes, 0); - result = h_ClmlIntf->clSaveMLTuningCacheQCOM(tuning_cache, saved_cache.size(), - saved_cache.data(), &len_ret); + result = CLML_INTF->clSaveMLTuningCacheQCOM(layer_.tuning_cache, saved_cache.size(), + saved_cache.data(), &len_ret); ICHECK(result == CL_SUCCESS) << "clSaveMLTuningCacheQCOM" << result; std::string tune_str; @@ -551,189 +754,25 @@ class CLMLRuntime : public JSONRuntimeBase { strm->Write(clml_symbol); strm->Write(saved_cache); - std::ofstream fs(tuning_file, std::ios::app | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << tuning_file; + std::ofstream fs(cws->tuning_file, std::ios::app | std::ios::binary); + ICHECK(!fs.fail()) << "Cannot open " << cws->tuning_file; fs.write(&tune_str[0], tune_str.length()); - LOG(WARNING) << "CLML: Tuning cache dumped to:" << 
tuning_file << " size" << tune_str.length() - << " with tuning blob len " << saved_cache.size(); + LOG(WARNING) << "CLML: Tuning cache dumped to:" << cws->tuning_file << " size" + << tune_str.length() << " with tuning blob len " << saved_cache.size(); } - } - - /*! - * \brief CLML objects we cache in order to avoid needing to construct - * a new layer each time. - */ - struct CachedLayer { - std::vector function; - std::vector> inputs; - std::vector> in_placeholder; - std::vector> outputs; - std::vector> out_placeholder; - std::vector> func_outs; - std::vector> func_ins; - std::map, JSONGraphNode>> - storage_map; - std::vector tensorMemDescs; - std::vector in_tensorMemDescs; - std::vector out_tensorMemDescs; - cl_ml_tensor_mem_desc_set_qcom descriptorSet; - std::vector layer_names; - cl_ml_tensor_qcom unusedTensor = nullptr; - }; - - struct tensor_dims_t { - uint32_t n, c, h, w; - }; - - bool ExtensionStringPresent(void) { - cl_int result = 0; - size_t reqd_size = 0; - cl_device_id device_id = - workspace->GetCLDeviceID(workspace->GetThreadEntry()->device.device_id); - result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, 0, nullptr, &reqd_size); - ICHECK(reqd_size > 0u && result == CL_SUCCESS) << "clGetDeviceInfo:" << result; - - std::vector buf(reqd_size); - result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, reqd_size, buf.data(), nullptr); - ICHECK(result == CL_SUCCESS) << "clGetDeviceInfo:" << result; - - std::string extensions(buf.data()); - LOG(WARNING) << "OpenCL Extensions:" << extensions; - return (extensions.find("cl_qcom_ml_ops") != std::string::npos); - } - - cl_ml_tensor_qcom DeviceMakeCLMLTensor( - cl_context context, tensor_dims_t dims, - cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM, - cl_channel_type dtype = CL_FLOAT) { - cl_ml_tensor_qcom tensor; - cl_int result = CL_OUT_OF_RESOURCES; - - cl_ml_tensor_desc_qcom desc = { - dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, { 0 }}; - result = - h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc, &tensor); - ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result; - (void)result; - return tensor; - } - - cl_int allocateTensorMemory(void* pClmlIntf, cl_context context, - std::shared_ptr pTensorMemDesc) { - uint32_t size = 0; - cl_int result = CL_OUT_OF_HOST_MEMORY; - cl_mem buffer = nullptr; - - result = h_ClmlIntf->clGetMLTensorMemorySizeQCOM(workspace->contexts[platform_id], - pTensorMemDesc->tensor, &size); - ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result; - - buffer = - clCreateBuffer(workspace->contexts[platform_id], CL_MEM_READ_WRITE, size, nullptr, &result); - ICHECK(result == CL_SUCCESS) << "clCreateBuffer:" << result; - - pTensorMemDesc->memory = buffer; - - return result; - } - - tensor_dims_t get_tensor_dims(const JSONGraphNode& node) { - std::vector shape = node.GetOpShape()[0]; - tensor_dims_t dims; - dims.n = shape[0]; - dims.c = shape[1]; - dims.h = shape[2]; - dims.w = shape[3]; - return dims; - } - - cl_channel_type MakeCLDataType(const DLDataType& data_type) { - if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 32) { - return CL_FLOAT; - } else if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16) { - return CL_HALF_FLOAT; - } else { - LOG(FATAL) << "Datatype " << data_type << " unsupported by CLML runtime"; - } - } - - cl_arithmetic_mode_qcom MakeCLArithMode(const cl_channel_type& data_type, - const cl_channel_type& acc_type = CL_FLOAT) { 
- if (data_type == CL_FLOAT && acc_type == CL_FLOAT) { - return CL_ARITHMETIC_MODE_FP32_QCOM; - } else if (data_type == CL_HALF_FLOAT && acc_type == CL_FLOAT) { - return CL_ARITHMETIC_MODE_FP16_ACC32_QCOM; - } else if (data_type == CL_HALF_FLOAT && acc_type == CL_HALF_FLOAT) { - return CL_ARITHMETIC_MODE_FP16_QCOM; - } else { - LOG(FATAL) << "Datatype " << data_type << " unsupported by CLML runtime"; - } - } + if (cws->is_recordable_queue) { + for (size_t i = 0; i < this->layer_.function.size(); ++i) { + result = + CLML_INTF->clEnqueueMLOpQCOM(this->layer_.recordable_queue, this->layer_.function[i], + this->layer_.descriptorSet, 0, nullptr, nullptr); + ICHECK(result == CL_SUCCESS) << "clEnqueueMLOpQCOM - Recordable Queue:" << result; + } - std::shared_ptr MakeCLMLTensor( - const JSONGraphNode& tensor_rep, void* data, std::vector c_shape, - cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_uint dtype = CL_FLOAT) { - std::vector shape = tensor_rep.GetOpShape()[0]; - std::vector clml_shape(shape.begin(), shape.end()); - if (c_shape.size() > 0) { - clml_shape = c_shape; + result = clEndRecordingQCOM(this->layer_.recording); + ICHECK(result == CL_SUCCESS) << "clEndRecordingQCOM:" << result; } - // Make sure the tensors with dimensions less than 4 are padded with 1. - clml_shape.push_back(1); - clml_shape.push_back(1); - clml_shape.push_back(1); - - tensor_dims_t dims; - dims.n = clml_shape[0]; - dims.c = clml_shape[1]; - dims.h = clml_shape[2]; - dims.w = clml_shape[3]; - DLDataType tvm_dtype = tensor_rep.GetOpDataType()[0]; - cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); - - auto tensor_dsc = std::make_shared(); - tensor_dsc->tensor = - DeviceMakeCLMLTensor(workspace->contexts[platform_id], dims, layout, cl_dtype); - return tensor_dsc; } - /*! - * \brief Create an CLML tensor given the JSON representation. If scale - * and offset are given, then create a quantized CLML tensor. - * - * \param tensor The tensor to represent. - * \return CLML Tensor. - */ - - std::shared_ptr MakeCLMLTensorFromJSONEntry( - const JSONGraphNodeEntry& tensor, std::vector shape = {}, - cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_uint dtype = CL_FLOAT) { - JSONGraphNode node = nodes_[tensor.id_]; - if (this->layer_.storage_map.find(tensor.id_) == this->layer_.storage_map.end()) { - void* node_data = nullptr; - if (node.GetOpType() == "const") { - node_data = data_entry_[EntryID(tensor)]->data; - } - auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape); - this->layer_.storage_map.insert({tensor.id_, std::make_pair(clml_tensor, node)}); - return clml_tensor; - } else { - return this->layer_.storage_map[tensor.id_].first; - } - } - /*! - * \brief Create an CLML tensor given the JSON representation. If scale - * and offset are given, then create a quantized CLML tensor. - * - * \param node The tensor to represent. - * \param data (optional) Constant data of input node. - * \return CLML Tensor. - */ - std::shared_ptr MakeCLMLTensorFromJSONNode( - const JSONGraphNode& node, cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM, - cl_uint dtype = CL_FLOAT, void* data = nullptr, std::vector shape = {}) { - return MakeCLMLTensor(node, data, shape, layout, dtype); - } /*! * \brief Create a 2D convolution layer. 
* @@ -807,8 +846,7 @@ class CLMLRuntime : public JSONRuntimeBase { } else { cl_ml_tensor_desc_qcom desc = {}; desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; - result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc, - &layer_.unusedTensor); + result = CLML_INTF->clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, &layer_.unusedTensor); ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result; bias->tensor = layer_.unusedTensor; } @@ -827,22 +865,21 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_qcom op = nullptr; if (!has_bn) { if (!has_act) { - result = h_ClmlIntf->clCreateMLOpConvolutionForwardQCOM( - workspace->contexts[platform_id], nullptr, &conv_desc, input->tensor, weight->tensor, - bias->tensor, output->tensor, &op, nullptr); + result = CLML_INTF->clCreateMLOpConvolutionForwardQCOM( + CLML_CTX, nullptr, &conv_desc, input->tensor, weight->tensor, bias->tensor, + output->tensor, &op, nullptr); ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result; } else { - result = h_ClmlIntf->clCreateMLOpFusedConvolutionActivationForwardQCOM( - workspace->contexts[platform_id], nullptr, &conv_desc, &act_desc, input->tensor, - weight->tensor, bias->tensor, nullptr, output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpFusedConvolutionActivationForwardQCOM( + CLML_CTX, nullptr, &conv_desc, &act_desc, input->tensor, weight->tensor, bias->tensor, + nullptr, output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result; } - layer_.func_ins.push_back(input); layer->function.push_back(op); } else { int bn_index = has_bias ? 3 : 2; int axis = std::stoi(node.GetAttr>("batchnorm")[0]); - auto bn_dims = get_tensor_dims(nodes_[inputs[bn_index].id_]); + auto bn_dims = GetTensorDims(nodes_[inputs[bn_index].id_]); std::vector bn_shape = {1, 1, 1, 1}; bn_shape[axis] = bn_dims.n; auto bn_mean = std::make_shared(); @@ -860,20 +897,19 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, cl_arithmetic_mode}; if (!has_act) { - result = h_ClmlIntf->clCreateMLOpFusedConvolutionBatchNormForwardQCOM( - workspace->contexts[platform_id], nullptr, &conv_desc, &bn_desc, input->tensor, - weight->tensor, bias->tensor, output->tensor, bn_mean->tensor, bn_var->tensor, - bn_scale->tensor, bn_bias->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpFusedConvolutionBatchNormForwardQCOM( + CLML_CTX, nullptr, &conv_desc, &bn_desc, input->tensor, weight->tensor, bias->tensor, + output->tensor, bn_mean->tensor, bn_var->tensor, bn_scale->tensor, bn_bias->tensor, &op, + layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result; } else { - result = h_ClmlIntf->clCreateMLOpFusedConvolutionBatchNormActivationForwardQCOM( - workspace->contexts[platform_id], nullptr, &conv_desc, &bn_desc, &act_desc, - input->tensor, weight->tensor, bias->tensor, output->tensor, nullptr, bn_mean->tensor, - bn_var->tensor, bn_scale->tensor, bn_bias->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpFusedConvolutionBatchNormActivationForwardQCOM( + CLML_CTX, nullptr, &conv_desc, &bn_desc, &act_desc, input->tensor, weight->tensor, + bias->tensor, output->tensor, nullptr, bn_mean->tensor, bn_var->tensor, + bn_scale->tensor, bn_bias->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result; } - layer_.func_ins.push_back(input); 
layer->function.push_back(op); } return output; @@ -902,16 +938,14 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_tensor_desc_qcom desc = {}; desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; - result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc, - &layer_.unusedTensor); + result = CLML_INTF->clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, &layer_.unusedTensor); ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << ":" << result; - result = h_ClmlIntf->clCreateMLOpActivationForwardQCOM( - workspace->contexts[platform_id], nullptr, &act_desc, input->tensor, layer_.unusedTensor, - output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpActivationForwardQCOM(CLML_CTX, nullptr, &act_desc, + input->tensor, layer_.unusedTensor, + output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Activation Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -940,7 +974,7 @@ class CLMLRuntime : public JSONRuntimeBase { opProperties.push_back(*reinterpret_cast(&epsilon)); opProperties.push_back(CL_ML_OP_PROPERTY_LIST_END_QCOM); - auto bn_dims = get_tensor_dims(nodes_[node.GetInputs()[1].id_]); + auto bn_dims = GetTensorDims(nodes_[node.GetInputs()[1].id_]); std::vector bn_shape = {1, 1, 1, 1}; bn_shape[axis] = bn_dims.n; auto bn_mean = std::make_shared(); @@ -960,14 +994,12 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, cl_arithmetic_mode}; - result = h_ClmlIntf->clCreateMLOpBatchNormForwardQCOM( - workspace->contexts[platform_id], opProperties.data(), &bn_desc, input->tensor, - bn_mean->tensor, bn_var->tensor, bn_scale->tensor, bn_bias->tensor, output->tensor, &op, - tuning_cache); + result = CLML_INTF->clCreateMLOpBatchNormForwardQCOM( + CLML_CTX, opProperties.data(), &bn_desc, input->tensor, bn_mean->tensor, bn_var->tensor, + bn_scale->tensor, bn_bias->tensor, output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Batchnorm Error:" << result; layer->function.push_back(op); - layer_.func_ins.push_back(input); return output; } @@ -1012,16 +1044,14 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_tensor_desc_qcom desc = {}; cl_ml_tensor_qcom unusedTensor = nullptr; desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; - result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc, - &unusedTensor); + result = CLML_INTF->clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, &unusedTensor); ICHECK(unusedTensor && result == CL_SUCCESS) << ":" << result; - result = h_ClmlIntf->clCreateMLOpPoolingForwardQCOM(workspace->contexts[platform_id], nullptr, - &pool_desc, input->tensor, unusedTensor, - output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpPoolingForwardQCOM(CLML_CTX, nullptr, &pool_desc, input->tensor, + unusedTensor, output->tensor, &op, + layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1044,7 +1074,7 @@ class CLMLRuntime : public JSONRuntimeBase { auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - auto in_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]); + auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]); cl_ml_op_pooling_desc_qcom pool_desc = { 
node.GetOpName() == "nn.global_max_pool2d" ? CL_POOLING_MODE_MAX_QCOM : CL_POOLING_MODE_AVERAGE_EXCLUDE_PADDING_QCOM, @@ -1059,16 +1089,14 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_tensor_desc_qcom desc = {}; desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; - result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc, - &layer_.unusedTensor); + result = CLML_INTF->clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, &layer_.unusedTensor); ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << ":" << result; - result = h_ClmlIntf->clCreateMLOpPoolingForwardQCOM( - workspace->contexts[platform_id], nullptr, &pool_desc, input->tensor, layer_.unusedTensor, - output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpPoolingForwardQCOM(CLML_CTX, nullptr, &pool_desc, input->tensor, + layer_.unusedTensor, output->tensor, &op, + layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1088,19 +1116,17 @@ class CLMLRuntime : public JSONRuntimeBase { cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype); auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - auto out_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]); + auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]); auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype, nullptr, {out_dims.n, out_dims.c, 1, 1}); cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM, CL_SOFTMAX_MODE_INSTANCE_QCOM, cl_arithmetic_mode}; - result = h_ClmlIntf->clCreateMLOpSoftmaxQCOM(workspace->contexts[platform_id], nullptr, - &softmax_desc, input->tensor, output->tensor, &op, - tuning_cache); + result = CLML_INTF->clCreateMLOpSoftmaxQCOM(CLML_CTX, nullptr, &softmax_desc, input->tensor, + output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "SoftMax Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1142,11 +1168,10 @@ class CLMLRuntime : public JSONRuntimeBase { {clml_padding[0], clml_padding[1], clml_padding[2], clml_padding[3], 0, 0, 0, 0}, cl_arithmetic_mode}; - result = h_ClmlIntf->clCreateMLOpPadQCOM(workspace->contexts[platform_id], nullptr, &pad_desc, - input->tensor, output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpPadQCOM(CLML_CTX, nullptr, &pad_desc, input->tensor, + output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Pad Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1167,11 +1192,10 @@ class CLMLRuntime : public JSONRuntimeBase { cl_dtype); auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->contexts[platform_id], nullptr, - input->tensor, output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpReshapeQCOM(CLML_CTX, nullptr, input->tensor, output->tensor, + &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1192,11 +1216,10 @@ class CLMLRuntime : public JSONRuntimeBase { cl_dtype); auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - result = 
h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->contexts[platform_id], nullptr, - input->tensor, output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpReshapeQCOM(CLML_CTX, nullptr, input->tensor, output->tensor, + &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1227,9 +1250,8 @@ class CLMLRuntime : public JSONRuntimeBase { } cl_ml_op_concat_desc_qcom concatDesc = {axis, (cl_uint)inputSize, cl_arithmetic_mode}; - result = - h_ClmlIntf->clCreateMLOpConcatQCOM(workspace->contexts[platform_id], nullptr, &concatDesc, - concatInputs, output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpConcatQCOM(CLML_CTX, nullptr, &concatDesc, concatInputs, + output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Concat Error:" << result; layer->function.push_back(op); @@ -1252,47 +1274,85 @@ class CLMLRuntime : public JSONRuntimeBase { DLDataType tvm_dtype = node.GetOpDataType()[0]; cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype); - auto inp_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]); - auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {1, inp_dims.c, 1, 1}, - CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); - auto wt_dims = get_tensor_dims(nodes_[node.GetInputs()[1].id_]); - bool has_bias = node.GetInputs().size() == 3 ? true : false; - auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {wt_dims.n, wt_dims.c, 1, 1}, - CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype); + auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]); + auto input = + MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype); + auto wt_dims = GetTensorDims(nodes_[node.GetInputs()[1].id_]); + auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {1, 1, wt_dims.n, wt_dims.c}, + CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype); + auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype); + cl_gemm_transform_qcom b_transform = CL_GEMM_TRANSFORM_NONE_QCOM; + if (in_dims.c == wt_dims.c) { + b_transform = CL_GEMM_TRANSFORM_TRANSPOSE_QCOM; + } + cl_ml_op_gemm_desc_qcom gemmDesc = {in_dims.n, // m + wt_dims.n, // n + wt_dims.c, // k + CL_GEMM_TRANSFORM_NONE_QCOM, // A transform + b_transform, // B transform + {{1.0}, CL_FLOAT}, // alpha + {{0.0}, CL_FLOAT}, // beta + cl_arithmetic_mode}; + + result = CLML_INTF->clCreateMLOpGemmQCOM(CLML_CTX, 0, &gemmDesc, input->tensor, weight->tensor, + output->tensor, &op, layer_.tuning_cache); + ICHECK(op && result == CL_SUCCESS) << "Dense Error:" << result; - auto bias = std::make_shared(); - if (has_bias) { - auto bias_dims = get_tensor_dims(nodes_[node.GetInputs()[2].id_]); - bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, - cl_dtype); - } else { - cl_ml_tensor_desc_qcom desc = {}; - desc.num_dimensions = CL_TENSOR_UNUSED_QCOM; - result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->contexts[platform_id], nullptr, &desc, - &layer_.unusedTensor); - ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result; - bias->tensor = layer_.unusedTensor; + layer->function.push_back(op); + return output; + } + + /*! + * \brief Create a batch_matmul layer. + * + * + * \param layer The CLML layer to build. Containing inputs, outputs and the CLML function. 
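+   * \param nid The node id of this operator in the JSON graph; used to record its reshaped output shape in out_shapes.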
+ * \param node The JSON representation of the operator. + */ + std::shared_ptr CreateBatchMatmulLayer(CachedLayer* layer, + const JSONGraphNode& node, + int nid) { + cl_int result = 0; + cl_ml_op_qcom op = nullptr; + DLDataType tvm_dtype = node.GetOpDataType()[0]; + cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype); + cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype); + auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]); + auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {in_dims.c, in_dims.h}, + CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype); + auto wt_dims = GetTensorDims(nodes_[node.GetInputs()[1].id_]); + auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {1, 1, wt_dims.c, wt_dims.h}, + CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype); + + std::vector out_shape = node.GetOpShape()[0]; + std::vector clml_out_shape; + clml_out_shape.push_back(out_shape[1]); + clml_out_shape.push_back(out_shape[2]); + clml_out_shape.push_back(1); + clml_out_shape.push_back(1); + auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype, nullptr, + clml_out_shape); + layer->out_shapes.insert({nid, clml_out_shape}); + + cl_bool b_transpose = std::stoi(node.GetAttr>("transpose_b")[0]); + cl_gemm_transform_qcom b_transform = CL_GEMM_TRANSFORM_NONE_QCOM; + if (b_transpose) { + b_transform = CL_GEMM_TRANSFORM_TRANSPOSE_QCOM; } - // Output - auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype, nullptr, - {1, wt_dims.n, 1, 1}); - cl_ml_op_convolution_desc_qcom conv_desc = {CL_CONVOLUTION_MODE_CONVOLUTION_QCOM, - 1, - 4, - {0, 0}, - {0, 0}, - {1, 1}, - {1, 1}, - 0, - cl_arithmetic_mode}; - - result = h_ClmlIntf->clCreateMLOpConvolutionForwardQCOM( - workspace->contexts[platform_id], nullptr, &conv_desc, input->tensor, weight->tensor, - bias->tensor, output->tensor, &op, nullptr); - ICHECK(op && result == CL_SUCCESS) << "Fully Connected Error:" << result; + cl_ml_op_gemm_desc_qcom gemmDesc = {in_dims.c, // m + wt_dims.c, // n + wt_dims.h, // k + CL_GEMM_TRANSFORM_NONE_QCOM, // A transform + b_transform, // B transform + {{1.0}, CL_FLOAT}, // alpha + {{0.0}, CL_FLOAT}, // beta + cl_arithmetic_mode}; + + result = CLML_INTF->clCreateMLOpGemmQCOM(CLML_CTX, 0, &gemmDesc, input->tensor, weight->tensor, + output->tensor, &op, layer_.tuning_cache); + ICHECK(op && result == CL_SUCCESS) << "BatchMatmul Error:" << result; layer->function.push_back(op); - layer_.func_ins.push_back(input); return output; } @@ -1318,11 +1378,10 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_clip_desc_qcom clip_desc = { CL_CLIP_BY_VALUE_QCOM, {{a_max}, CL_FLOAT}, {{a_min}, CL_FLOAT}, cl_arithmetic_mode}; - result = h_ClmlIntf->clCreateMLOpClipQCOM(workspace->contexts[platform_id], nullptr, &clip_desc, - input->tensor, output->tensor, &op, tuning_cache); + result = CLML_INTF->clCreateMLOpClipQCOM(CLML_CTX, nullptr, &clip_desc, input->tensor, + output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Clip Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1360,13 +1419,11 @@ class CLMLRuntime : public JSONRuntimeBase { cl_ml_op_binary_desc_qcom add_desc = { binary_op, {{1.0}, CL_FLOAT}, {{1.0}, CL_FLOAT}, {{0.0}, CL_FLOAT}, cl_arithmetic_mode}; - result = h_ClmlIntf->clCreateMLOpBinaryQCOM(workspace->contexts[platform_id], nullptr, - &add_desc, input_a->tensor, input_b->tensor, - output->tensor, &op, tuning_cache); + result = 
CLML_INTF->clCreateMLOpBinaryQCOM(CLML_CTX, nullptr, &add_desc, input_a->tensor, + input_b->tensor, output->tensor, &op, + layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << op_name << " Node Error:" << result; - layer_.func_ins.push_back(input_a); - layer_.func_ins.push_back(input_b); layer->function.push_back(op); return output; } @@ -1390,12 +1447,10 @@ class CLMLRuntime : public JSONRuntimeBase { cl_uint block_size = std::stoi(node.GetAttr>("block_size")[0]); cl_ml_op_depthtospace_desc_qcom dtos_desc = {block_size, cl_arithmetic_mode}; - result = h_ClmlIntf->clCreateMLOpDepthToSpaceQCOM(workspace->contexts[platform_id], nullptr, - &dtos_desc, input->tensor, output->tensor, - &op, tuning_cache); + result = CLML_INTF->clCreateMLOpDepthToSpaceQCOM(CLML_CTX, nullptr, &dtos_desc, input->tensor, + output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "DepthToSpace Layer Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1419,12 +1474,10 @@ class CLMLRuntime : public JSONRuntimeBase { cl_bool align_corners = std::stoi(node.GetAttr>("align_corners")[0]); cl_ml_op_resize_bilinear_desc_qcom resize_desc = {align_corners, false, cl_arithmetic_mode}; - result = h_ClmlIntf->clCreateMLOpResizeBilinearQCOM(workspace->contexts[platform_id], nullptr, - &resize_desc, input->tensor, output->tensor, - &op, tuning_cache); + result = CLML_INTF->clCreateMLOpResizeBilinearQCOM( + CLML_CTX, nullptr, &resize_desc, input->tensor, output->tensor, &op, layer_.tuning_cache); ICHECK(op && result == CL_SUCCESS) << "Resize Layer Error:" << result; - layer_.func_ins.push_back(input); layer->function.push_back(op); return output; } @@ -1434,16 +1487,12 @@ class CLMLRuntime : public JSONRuntimeBase { * \note Currently only supports a single layer. */ + // This layer instance CachedLayer layer_; - // CLML Context - GET_ML_API_INTERFACE* h_ClmlIntf = nullptr; - cl::OpenCLWorkspace* workspace = nullptr; - cl::OpenCLThreadEntry* tentry = nullptr; - cl_device_id device_id; - cl_platform_id platform_id; - cl_ml_tuningcache_qcom tuning_cache = nullptr; - bool is_tuning_run; - char* tuning_file; + + // CLML Workspace + CLMLWorkspace* cws; + #else void Run() override { LOG(FATAL) << "Cannot call run on CLML module without runtime enabled. " diff --git a/src/runtime/contrib/clml/clml_runtime.h b/src/runtime/contrib/clml/clml_runtime.h new file mode 100644 index 0000000000000..2a6ce02626d44 --- /dev/null +++ b/src/runtime/contrib/clml/clml_runtime.h @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/runtime/contrib/clml/clml_runtime.h + * \brief CLML header + */ +#ifndef TVM_RUNTIME_CONTRIB_CLML_CLML_RUNTIME_H_ +#define TVM_RUNTIME_CONTRIB_CLML_CLML_RUNTIME_H_ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../file_utils.h" +#include "../../opencl/opencl_common.h" +#include "../../thread_storage_scope.h" +#include "../json/json_node.h" +#include "../json/json_runtime.h" + +#ifdef TVM_GRAPH_EXECUTOR_CLML +#include + +#define CAT_I(a, b) a##b +#define CAT(a, b) CAT_I(a, b) +#define GET_ML_INTERFACE CAT(CAT(clGetMLInterfaceV, CL_QCOM_ML_OPS_H_MAJOR_VERSION), QCOM) +#define GET_ML_API_INTERFACE CAT(CAT(CLMLInterfaceV, CL_QCOM_ML_OPS_H_MAJOR_VERSION), QCOM) + +/*! \brief Magic number for CLML Tuning cache entry */ +static const uint64_t kTVMCLMLTuningCacheMagic = 0x434C4D4C54554E45; + +#define DEBUG_MEMORY_ALLOC false +#define DEBUG_STATS false +#define LOG_MEM LOG_IF(WARNING, DEBUG_MEMORY_ALLOC) +#define LOG_STATS LOG_IF(WARNING, DEBUG_STATS) + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +class CLMLThreadEntry; + +/*! + * \brief CLML workspace. + */ +class CLMLWorkspace { + public: + /* Constructor */ + CLMLWorkspace(); + /*! + * \brief Get the thread local ThreadEntry + */ + virtual CLMLThreadEntry* GetThreadEntry(); + + /* CLML Context */ + GET_ML_API_INTERFACE* h_ClmlIntf = nullptr; + cl::OpenCLWorkspace* workspace = nullptr; + cl::OpenCLThreadEntry* tentry = nullptr; + cl_device_id device_id; + cl_platform_id platform_id; + + /* Tuning Support */ + bool is_tuning_run; + char* tuning_file; + + /* Recordable Queues */ + bool is_recordable_queue = false; + + /* On chip memory support */ + bool is_on_chip_memory = false; + + /* On chip memory size */ + size_t onchip_mem_size = 0; + + /* get the global workspace */ + static CLMLWorkspace* Global(); + + bool ExtensionStringPresent(std::string extn); + + /* DDR memory management */ + std::map> ddr_global_pool; // buf, size and ref count +}; + +/*! \brief Thread local workspace */ +class CLMLThreadEntry { + public: + /* get the global workspace */ + static CLMLThreadEntry* ThreadLocal(); +}; + +/*! + * \brief CLML objects we cache in order to avoid needing to construct + * a new layer each time. + */ +struct CachedLayer { + /* List of all created CLML operation handles in graph */ + std::vector function; + /* The input tensor map */ + std::map> inputs; + /* A place holder Tensor representing TVM NDArray as CLML Tensor */ + std::map> in_placeholder; + /* The Output tensor map */ + std::vector> outputs; + /* A place holder Tensor representing TVM NDArray as CLML Tensor */ + std::vector> out_placeholder; + /* Tensor shape exception list while returning from CLML Subgraph */ + std::map> out_shapes; + /* Map of all tensors which need backing memory allocation */ + std::map, JSONGraphNode>> + storage_map; + /* Tensor memory descriptors list to set after backing memory allocation */ + std::vector tensorMemDescs; + cl_ml_tensor_mem_desc_set_qcom descriptorSet; + /* List of layer names in subgraph */ + std::vector layer_names; + /* A dummy CLML tensor used across various ops */ + cl_ml_tensor_qcom unusedTensor = nullptr; + + /* Graph level tuning cache */ + cl_ml_tuningcache_qcom tuning_cache = nullptr; + + /* Memory management */ + std::map storage_ref_map; // NodeId & ref. 
count + /* Activation node id & life span (the layer after which we can free) */ + std::map life_span; + std::map on_chip_pool_size; // Mem start & size + std::map on_chip_pool_alloc_info; // Mem start & node_id + std::map> on_chip_alloc_plan; // Final Alloc Plan + std::map on_chip_reject; // On-Chip reject info + bool alloc_ping_pong; // Allocation stratagy + int in_chip_total_free; // Total available + int in_chip_total_alloc; // Free memory + int on_chip_alert_fail; // Faliure due to fragmentation + + /* DDR memory planner */ + std::map> ddr_storage_ref_map; // local pool reference count + std::map ddr_alloc_plan; // allocation map + + cl_command_queue recordable_queue = nullptr; + cl_recording_qcom recording = nullptr; +}; + +struct tensor_dims_t { + uint32_t n, c, h, w; +}; + +#define CLML_INTF CLMLWorkspace::Global()->h_ClmlIntf +#define CLML_QUEUE \ + CLMLWorkspace::Global()->workspace->GetQueue(CLMLWorkspace::Global()->tentry->device) +#define CLML_CTX CLMLWorkspace::Global()->workspace->contexts[CLMLWorkspace::Global()->platform_id] + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_GRAPH_EXECUTOR_CLML +#endif // TVM_RUNTIME_CONTRIB_CLML_CLML_RUNTIME_H_ diff --git a/src/runtime/contrib/clml/clml_utils.cc b/src/runtime/contrib/clml/clml_utils.cc new file mode 100644 index 0000000000000..e1e6fc754231b --- /dev/null +++ b/src/runtime/contrib/clml/clml_utils.cc @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/clml/clml_utils.cc + * \brief Utilities. + */ +#ifdef TVM_GRAPH_EXECUTOR_CLML +#include "clml_utils.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +/*! + * \brief Copy utility to CLML Tensor. + * + * \param tensor CLML tensor descriptor + * \param data pointer to host data + * \param layout host data layout + */ +void CopyDataToCLMLTensor(std::shared_ptr tensor, void* data, + cl_ml_tensor_layout_qcom layout) { + cl_int result = 0; + cl_event evt = nullptr; + result = CLML_INTF->clEnqueueWriteMLTensorDataQCOM(CLML_QUEUE, data, layout, tensor->tensor, + tensor->memory, + 0, // n waitlist + nullptr, // waitlist + &evt); // event + ICHECK((evt != nullptr) && result == CL_SUCCESS) << "clEnqueueWriteMLTensorDataQCOM:" << result; +} + +/*! + * \brief Copy utility from CLML tensor. 
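+ * Waits on the read event before returning, so the host buffer is ready to use.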
+ * + * \param tensor CLML tensor descriptor + * \param data pointer to host data + * \param layout expectred host data layout + */ +void CopyDataFromCLMLTensor(std::shared_ptr tensor, void* data, + cl_ml_tensor_layout_qcom layout) { + cl_int result = 0; + cl_event readEvent = nullptr; + // Read the output tensor + result = CLML_INTF->clEnqueueReadMLTensorDataQCOM(CLML_QUEUE, tensor->tensor, tensor->memory, + data, layout, + 0, // n waitlist + nullptr, // waitlist + &readEvent); // event + ICHECK(result == CL_SUCCESS) << "clEnqueueReadMLTensorDataQCOM:" << result; + + result = clWaitForEvents(1, &readEvent); + ICHECK(result == CL_SUCCESS) << "clWaitForEvents:" << result; +} + +/*! + * \brief Make a CLML tensor given it's attributes + * + * \param context OpenCL context + * \param dims Tensor dimensions + * \param layout CLML tensor layout of tensor + * \param dtype Tensor data type + * \return CLML tensor + */ +cl_ml_tensor_qcom DeviceMakeCLMLTensor(cl_context context, tensor_dims_t dims, + cl_ml_tensor_layout_qcom layout, cl_channel_type dtype) { + cl_ml_tensor_qcom tensor; + cl_int result = CL_OUT_OF_RESOURCES; + + cl_ml_tensor_desc_qcom desc = { + dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, {0}}; + result = CLML_INTF->clCreateMLTensorQCOM(CLML_CTX, nullptr, &desc, &tensor); + ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result; + return tensor; +} + +/*! + * \brief utility that allocates DDR backed memory for the tensor. + * + * \param context OpenCL context + * \param buffer size + * \return allocated cl_mem object + */ +cl_mem AllocateDDRTensorMemory(size_t size) { + cl_int result = CL_OUT_OF_HOST_MEMORY; + cl_mem buffer = nullptr; + + buffer = clCreateBuffer(CLML_CTX, CL_MEM_READ_WRITE, size, nullptr, &result); + ICHECK(result == CL_SUCCESS) << "clCreateBuffer:" << result; + + return buffer; +} + +/*! + * \brief utility that allocates on chip backed memory for the tensor. + * + * \param context OpenCL context + * \param tensor_desc tensor descriptor + * \param on_chip_mem_offset on chip memory offset to be used for allocation + * \return result API status + */ +cl_mem AllocateOnChipTensorMemory(size_t size, cl_uint on_chip_mem_offset) { + cl_int result = CL_OUT_OF_HOST_MEMORY; + cl_mem buffer = nullptr; + + cl_mem_properties on_chip_buff_prop[] = {CL_MEM_ONCHIP_GLOBAL_QCOM, 1, + CL_MEM_ONCHIP_GLOBAL_OFFSET_QCOM, on_chip_mem_offset, 0}; + LOG_MEM << "On-Chip Alloc:" << size << " Offset:" << on_chip_mem_offset; + buffer = clCreateBufferWithProperties(CLML_CTX, on_chip_buff_prop, CL_MEM_READ_WRITE, size, + nullptr, &result); + ICHECK(result == CL_SUCCESS) << "clCreateBufferWithProperties:" << result; + + return buffer; +} + +/*! + * \brief Utility to extract tensor dimensions from JSON node. + * + * \param node JSON graph node + * \return The CLML tensor dimension + */ +tensor_dims_t GetTensorDims(const JSONGraphNode& node) { + std::vector shape = node.GetOpShape()[0]; + tensor_dims_t dims; + dims.n = shape[0]; + dims.c = shape[1]; + dims.h = shape[2]; + dims.w = shape[3]; + return dims; +} + +/*! + * \brief Utility to map TVM data type to OpenCL channel type. + * + * \param data_type TVM DType + * \return OpenCL channel type. 
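+ * Only 32-bit and 16-bit float are mapped; any other dtype terminates with a fatal error.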
+ */ +cl_channel_type MakeCLDataType(const DLDataType& data_type) { + if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 32) { + return CL_FLOAT; + } else if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 16) { + return CL_HALF_FLOAT; + } else { + LOG(FATAL) << "Datatype " << data_type << " unsupported by CLML runtime"; + } +} + +/*! + * \brief Utility to map OpenCL types to CLML operator arthematic mode. + * + * \param data_type cl data type + * \param acc_type accumulation type to be used + * \return the operator arthematic mode + */ +cl_arithmetic_mode_qcom MakeCLArithMode(const cl_channel_type& data_type, + const cl_channel_type& acc_type) { + if (data_type == CL_FLOAT && acc_type == CL_FLOAT) { + return CL_ARITHMETIC_MODE_FP32_QCOM; + } else if (data_type == CL_HALF_FLOAT && acc_type == CL_FLOAT) { + return CL_ARITHMETIC_MODE_FP16_ACC32_QCOM; + } else if (data_type == CL_HALF_FLOAT && acc_type == CL_HALF_FLOAT) { + return CL_ARITHMETIC_MODE_FP16_QCOM; + } else { + LOG(FATAL) << "Datatype " << data_type << " unsupported by CLML runtime"; + } +} + +/*! + * \brief Helper to sanity check before tensor creation. + * + * \param node The tensor to represent. + * \param data data pointer to prefill the tensor + * \param shape shape information of tensor + * \param layout the tensor layout to be used + * \param dtype tensor data type + * \return CLML Tensor descriptor. + */ +std::shared_ptr MakeCLMLTensor(const JSONGraphNode& tensor_rep, + void* data, + std::vector c_shape, + cl_ml_tensor_layout_qcom layout, + cl_uint dtype) { + std::vector shape = tensor_rep.GetOpShape()[0]; + std::vector clml_shape(shape.begin(), shape.end()); + if (c_shape.size() > 0) { + clml_shape = c_shape; + } + // Make sure the tensors with dimensions less than 4 are padded with 1. + clml_shape.push_back(1); + clml_shape.push_back(1); + clml_shape.push_back(1); + + tensor_dims_t dims; + dims.n = clml_shape[0]; + dims.c = clml_shape[1]; + dims.h = clml_shape[2]; + dims.w = clml_shape[3]; + + auto tensor_dsc = std::make_shared(); + tensor_dsc->tensor = DeviceMakeCLMLTensor(CLML_CTX, dims, layout, dtype); + return tensor_dsc; +} + +/*! + * \brief Create an CLML tensor given the JSON Node representation. + * + * \param node The tensor to represent. + * \param layout the tensor layout to be used + * \param dtype tensor data type + * \param data data pointer to prefill the tensor + * \param shape shape information of tensor + * \return CLML Tensor descriptor. + */ +std::shared_ptr MakeCLMLTensorFromJSONNode( + const JSONGraphNode& node, cl_ml_tensor_layout_qcom layout, cl_uint dtype, void* data, + std::vector shape) { + return MakeCLMLTensor(node, data, shape, layout, dtype); +} + +/*! + * \brief Utility function to extract vector values from string. + * + * \param val vector of strings + * \return vector of cl_uints. + */ +std::vector GetVectorValues(const std::vector& val) { + std::vector array; + for (auto i : val) { + array.push_back((cl_uint)stoi(i)); + } + return array; +} + +} // namespace contrib +} // namespace runtime +} // namespace tvm +#endif diff --git a/src/runtime/contrib/clml/clml_utils.h b/src/runtime/contrib/clml/clml_utils.h new file mode 100644 index 0000000000000..79a8312aeb5e9 --- /dev/null +++ b/src/runtime/contrib/clml/clml_utils.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/contrib/clml/clml_utils.h + * \brief CLML utilities header + */ +#ifndef TVM_RUNTIME_CONTRIB_CLML_CLML_UTILS_H_ +#define TVM_RUNTIME_CONTRIB_CLML_CLML_UTILS_H_ +#include +#include +#include + +#include "clml_runtime.h" + +namespace tvm { +namespace runtime { +namespace contrib { + +using namespace tvm::runtime::json; +using JSONGraphNode = tvm::runtime::json::JSONGraphNode; + +void CopyDataToCLMLTensor(std::shared_ptr tensor, void* data, + cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM); + +void CopyDataFromCLMLTensor(std::shared_ptr tensor, void* data, + cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM); + +cl_ml_tensor_qcom DeviceMakeCLMLTensor( + cl_context context, tensor_dims_t dims, + cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM, + cl_channel_type dtype = CL_FLOAT); + +cl_mem AllocateOnChipTensorMemory(size_t size, cl_uint on_chip_mem_offset); + +cl_mem AllocateDDRTensorMemory(size_t size); + +tensor_dims_t GetTensorDims(const JSONGraphNode& node); + +cl_channel_type MakeCLDataType(const DLDataType& data_type); + +cl_arithmetic_mode_qcom MakeCLArithMode(const cl_channel_type& data_type, + const cl_channel_type& acc_type = CL_FLOAT); + +std::shared_ptr MakeCLMLTensor(const JSONGraphNode& tensor_rep, + void* data, + std::vector c_shape, + cl_ml_tensor_layout_qcom layout, + cl_uint dtype); + +std::shared_ptr MakeCLMLTensorFromJSONNode( + const JSONGraphNode& node, cl_ml_tensor_layout_qcom layout, cl_uint dtype, void* data = nullptr, + std::vector shape = {}); + +std::vector GetVectorValues(const std::vector& val); + +} // namespace contrib +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_CLML_CLML_UTILS_H_ diff --git a/tests/cpp-runtime/opencl/clml_memory_planner.cc b/tests/cpp-runtime/opencl/clml_memory_planner.cc new file mode 100644 index 0000000000000..3d4d9c41f40e3 --- /dev/null +++ b/tests/cpp-runtime/opencl/clml_memory_planner.cc @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include + +#include + +#if defined(TVM_GRAPH_EXECUTOR_CLML) +#include "../src/runtime/contrib/clml/clml_memory_planner.h" +#include "../src/runtime/contrib/clml/clml_runtime.h" +#include "../src/runtime/opencl/opencl_common.h" + +using namespace tvm::runtime; +using namespace tvm::runtime::cl; + +class CLMLMemoryPlannerBin : public ::testing::Test { + protected: + virtual void SetUp() override { + layer.on_chip_pool_size.clear(); + layer.on_chip_pool_size.insert({0, cws->onchip_mem_size}); + layer.on_chip_pool_alloc_info.clear(); + layer.alloc_ping_pong = true; + layer.in_chip_total_free = cws->onchip_mem_size; + layer.in_chip_total_alloc = 0; + layer.on_chip_alert_fail = 0; + + /* clear global pool before each test */ + for (auto it = cws->ddr_global_pool.begin(); it != cws->ddr_global_pool.end(); it++) { + clReleaseMemObject(it->first); + } + cws->ddr_global_pool.clear(); + } + + void PlanMemory(int total_nodes, const std::map& tensor_sizes, + const std::map>& input_tensors) { + for (int nid = 0; nid < total_nodes; ++nid) { + uint32_t size = tensor_sizes.at(nid); + size_t offset = -1; + if (cws->is_on_chip_memory) { + os << "Requesting On-chip:" << nid << std::endl; + offset = RequestOnChipMemory(&layer, size); + } + if (-1 != offset) { + os << "On Chip not found:" << nid << std::endl; + layer.on_chip_pool_alloc_info.insert({offset, nid}); + layer.on_chip_alloc_plan.insert({nid, std::make_pair(size, offset)}); + } else { + os << "Requesting DDR memory:" << nid << std::endl; + layer.on_chip_reject.insert({nid, size}); + // DDR Allocation + auto ddr_mem = RequestDDRMemory(&layer, size); + layer.ddr_alloc_plan.insert({nid, ddr_mem}); + } + + // Now free up the input tensors on-chip memory for reuse. + for (auto& input_node : input_tensors.at(nid)) { + FreeMemory(&layer, input_node); + } + } + + // Stats dump + size_t in_chip_total_alloc = 0; + size_t total_reject = 0; + for (auto it = layer.on_chip_alloc_plan.begin(); it != layer.on_chip_alloc_plan.end(); it++) { + os << " On-chip Alloc:" << it->first << " Size:" << it->second.first + << " Offset:" << it->second.second << std::endl; + in_chip_total_alloc += it->second.first; + } + + for (auto it = layer.on_chip_reject.begin(); it != layer.on_chip_reject.end(); it++) { + os << "Reject:" << it->first << " Size:" << it->second << std::endl; + total_reject += it->second; + } + os << "Total On-chip Alloc:" << in_chip_total_alloc + total_reject + << " On-Chip:" << in_chip_total_alloc << " Reject:" << total_reject << std::endl; + + for (auto it = cws->ddr_global_pool.begin(); it != cws->ddr_global_pool.end(); it++) { + os << "DDR Global pool - size:" << it->second.first << " Ref:" << it->second.second + << std::endl; + } + for (auto it = layer.ddr_storage_ref_map.begin(); it != layer.ddr_storage_ref_map.end(); it++) { + os << "DDR Local pool - size:" << it->second.first << " Ref cnt:" << it->second.second + << std::endl; + } + } + + void CompareOnChipPlan(const std::vector& on_chip_plan) { + for (auto& nid : on_chip_plan) { + EXPECT_EQ(layer.on_chip_alloc_plan.find(nid) == layer.on_chip_alloc_plan.end(), false) + << os.str(); + } + } + + void CompareDDRPlan(const std::vector& ddr_plan) { + for (auto& nid : ddr_plan) { + EXPECT_EQ(layer.ddr_alloc_plan.find(nid) == layer.ddr_alloc_plan.end(), false) << os.str(); + } + } + + void RunTest(const std::map>& input_tensors, + const std::map& tensor_sizes, + const std::vector& on_chip_expected, const std::vector& ddr_expected, + const int ddr_global_pool_size) { + 
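+    // Plan memory for the synthetic graph, then verify each node's placement against the expected on-chip and DDR plans.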
PlanMemory(input_tensors.size(), tensor_sizes, input_tensors); + CompareOnChipPlan(on_chip_expected); + CompareDDRPlan(ddr_expected); + EXPECT_EQ(cws->ddr_global_pool.size(), ddr_global_pool_size) << os.str(); + } + + protected: + tvm::runtime::contrib::CLMLWorkspace* cws = tvm::runtime::contrib::CLMLWorkspace::Global(); + std::stringstream os; + + public: + tvm::runtime::contrib::CachedLayer layer; +}; + +TEST_F(CLMLMemoryPlannerBin, sequential_all_on_chip) { + layer.storage_ref_map.insert({0, 1}); + layer.storage_ref_map.insert({1, 1}); + layer.storage_ref_map.insert({2, 1}); + layer.storage_ref_map.insert({3, 1}); + layer.storage_ref_map.insert({4, 1}); + layer.storage_ref_map.insert({5, 1}); + layer.storage_ref_map.insert({6, 1}); + layer.storage_ref_map.insert({7, 1}); + layer.storage_ref_map.insert({8, 1}); + layer.storage_ref_map.insert({9, 1}); + + layer.life_span.insert({0, 1}); + layer.life_span.insert({1, 2}); + layer.life_span.insert({2, 3}); + layer.life_span.insert({3, 4}); + layer.life_span.insert({4, 5}); + layer.life_span.insert({5, 6}); + layer.life_span.insert({6, 7}); + layer.life_span.insert({7, 8}); + layer.life_span.insert({8, 9}); + layer.life_span.insert({9, 10}); + + std::map tensor_sizes; + tensor_sizes.insert({0, 1024000}); + tensor_sizes.insert({1, 1024000}); + tensor_sizes.insert({2, 1024000}); + tensor_sizes.insert({3, 1024000}); + tensor_sizes.insert({4, 1024000}); + tensor_sizes.insert({5, 1024000}); + tensor_sizes.insert({6, 1024000}); + tensor_sizes.insert({7, 1024000}); + tensor_sizes.insert({8, 1024000}); + tensor_sizes.insert({9, 1024000}); + + std::map> input_tensors{ + {0, {}}, {1, {0}}, {2, {1}}, {3, {2}}, {4, {3}}, + {5, {4}}, {6, {5}}, {7, {6}}, {8, {7}}, {9, {8}}, + }; + + RunTest(input_tensors, tensor_sizes, std::vector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + std::vector({}), 0); +} + +TEST_F(CLMLMemoryPlannerBin, sequential_mixed) { + layer.storage_ref_map.insert({0, 1}); + layer.storage_ref_map.insert({1, 1}); + layer.storage_ref_map.insert({2, 1}); + layer.storage_ref_map.insert({3, 1}); + layer.storage_ref_map.insert({4, 1}); + layer.storage_ref_map.insert({5, 1}); + + layer.life_span.insert({0, 1}); + layer.life_span.insert({1, 2}); + layer.life_span.insert({2, 3}); + layer.life_span.insert({3, 4}); + layer.life_span.insert({4, 5}); + layer.life_span.insert({5, 6}); + + std::map tensor_sizes; + tensor_sizes.insert({0, 1024000}); + tensor_sizes.insert({1, 1024000}); + tensor_sizes.insert({2, cws->onchip_mem_size + 1}); + tensor_sizes.insert({3, 1024000}); + tensor_sizes.insert({4, cws->onchip_mem_size + 1}); + tensor_sizes.insert({5, 1024000}); + + std::map> input_tensors{ + {0, {}}, {1, {0}}, {2, {1}}, {3, {2}}, {4, {3}}, {5, {4}}, + }; + + RunTest(input_tensors, tensor_sizes, std::vector({0, 1, 3, 5}), std::vector({2, 4}), 1); +} + +TEST_F(CLMLMemoryPlannerBin, sequential_all_ddr) { + layer.storage_ref_map.insert({0, 1}); + layer.storage_ref_map.insert({1, 1}); + layer.storage_ref_map.insert({2, 1}); + layer.storage_ref_map.insert({3, 1}); + layer.storage_ref_map.insert({4, 1}); + layer.storage_ref_map.insert({5, 1}); + + layer.life_span.insert({0, 1}); + layer.life_span.insert({1, 2}); + layer.life_span.insert({2, 3}); + layer.life_span.insert({3, 4}); + layer.life_span.insert({4, 5}); + layer.life_span.insert({5, 6}); + + std::map tensor_sizes; + tensor_sizes.insert({0, cws->onchip_mem_size + 1}); + tensor_sizes.insert({1, cws->onchip_mem_size + 1}); + tensor_sizes.insert({2, cws->onchip_mem_size + 1}); + tensor_sizes.insert({3, 
cws->onchip_mem_size + 1}); + tensor_sizes.insert({4, cws->onchip_mem_size + 1}); + tensor_sizes.insert({5, cws->onchip_mem_size + 1}); + + std::map> input_tensors{ + {0, {}}, {1, {0}}, {2, {1}}, {3, {2}}, {4, {3}}, {5, {4}}, + }; + + RunTest(input_tensors, tensor_sizes, std::vector({}), std::vector({0, 1, 2, 3, 4, 5}), + 2); +} + +TEST_F(CLMLMemoryPlannerBin, branched_all_on_alive_on_chip) { + layer.storage_ref_map.insert({0, 9}); + layer.storage_ref_map.insert({1, 8}); + layer.storage_ref_map.insert({2, 7}); + layer.storage_ref_map.insert({3, 6}); + layer.storage_ref_map.insert({4, 5}); + layer.storage_ref_map.insert({5, 4}); + layer.storage_ref_map.insert({6, 3}); + layer.storage_ref_map.insert({7, 2}); + layer.storage_ref_map.insert({8, 1}); + layer.storage_ref_map.insert({9, 1}); + + layer.life_span.insert({0, 9}); + layer.life_span.insert({1, 9}); + layer.life_span.insert({2, 9}); + layer.life_span.insert({3, 9}); + layer.life_span.insert({4, 9}); + layer.life_span.insert({5, 9}); + layer.life_span.insert({6, 9}); + layer.life_span.insert({7, 9}); + layer.life_span.insert({8, 9}); + layer.life_span.insert({9, 10}); + + std::map tensor_sizes; + tensor_sizes.insert({0, 102400}); + tensor_sizes.insert({1, 102400}); + tensor_sizes.insert({2, 102400}); + tensor_sizes.insert({3, 102400}); + tensor_sizes.insert({4, 102400}); + tensor_sizes.insert({5, 102400}); + tensor_sizes.insert({6, 102400}); + tensor_sizes.insert({7, 102400}); + tensor_sizes.insert({8, 102400}); + tensor_sizes.insert({9, 102400}); + + std::map> input_tensors{ + {0, {}}, + {1, {0}}, + {2, {0, 1}}, + {3, {0, 1, 2}}, + {4, {0, 1, 2, 3}}, + {5, {0, 1, 2, 3, 4}}, + {6, {0, 1, 2, 3, 4, 5}}, + {7, {0, 1, 2, 3, 4, 5, 6}}, + {8, {0, 1, 2, 3, 4, 5, 6, 7}}, + {9, {0, 1, 2, 3, 4, 5, 6, 7, 8}}, + }; + + RunTest(input_tensors, tensor_sizes, std::vector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), + std::vector({}), 0); +} + +TEST_F(CLMLMemoryPlannerBin, branched_all_on_alive_mixed) { + layer.storage_ref_map.insert({0, 9}); + layer.storage_ref_map.insert({1, 8}); + layer.storage_ref_map.insert({2, 7}); + layer.storage_ref_map.insert({3, 6}); + layer.storage_ref_map.insert({4, 5}); + layer.storage_ref_map.insert({5, 4}); + layer.storage_ref_map.insert({6, 3}); + layer.storage_ref_map.insert({7, 2}); + layer.storage_ref_map.insert({8, 1}); + layer.storage_ref_map.insert({9, 1}); + + layer.life_span.insert({0, 9}); + layer.life_span.insert({1, 9}); + layer.life_span.insert({2, 9}); + layer.life_span.insert({3, 9}); + layer.life_span.insert({4, 9}); + layer.life_span.insert({5, 9}); + layer.life_span.insert({6, 9}); + layer.life_span.insert({7, 9}); + layer.life_span.insert({8, 9}); + layer.life_span.insert({9, 10}); + + std::map tensor_sizes; + tensor_sizes.insert({0, 102400}); + tensor_sizes.insert({1, 102400}); + tensor_sizes.insert({2, cws->onchip_mem_size + 1}); + tensor_sizes.insert({3, 102400}); + tensor_sizes.insert({4, cws->onchip_mem_size + 1}); + tensor_sizes.insert({5, 102400}); + tensor_sizes.insert({6, cws->onchip_mem_size + 1}); + tensor_sizes.insert({7, 102400}); + tensor_sizes.insert({8, 102400}); + tensor_sizes.insert({9, 102400}); + + std::map> input_tensors{ + {0, {}}, + {1, {0}}, + {2, {0, 1}}, + {3, {0, 1, 2}}, + {4, {0, 1, 2, 3}}, + {5, {0, 1, 2, 3, 4}}, + {6, {0, 1, 2, 3, 4, 5}}, + {7, {0, 1, 2, 3, 4, 5, 6}}, + {8, {0, 1, 2, 3, 4, 5, 6, 7}}, + {9, {0, 1, 2, 3, 4, 5, 6, 7, 8}}, + }; + + RunTest(input_tensors, tensor_sizes, std::vector({0, 1, 3, 5, 7, 8, 9}), + std::vector({2, 4, 6}), 3); +} + +TEST_F(CLMLMemoryPlannerBin, 
+TEST_F(CLMLMemoryPlannerBin, branched_all_on_alive_all_ddr) {
+  layer.storage_ref_map.insert({0, 9});
+  layer.storage_ref_map.insert({1, 8});
+  layer.storage_ref_map.insert({2, 7});
+  layer.storage_ref_map.insert({3, 6});
+  layer.storage_ref_map.insert({4, 5});
+  layer.storage_ref_map.insert({5, 4});
+  layer.storage_ref_map.insert({6, 3});
+  layer.storage_ref_map.insert({7, 2});
+  layer.storage_ref_map.insert({8, 1});
+  layer.storage_ref_map.insert({9, 1});
+
+  layer.life_span.insert({0, 9});
+  layer.life_span.insert({1, 9});
+  layer.life_span.insert({2, 9});
+  layer.life_span.insert({3, 9});
+  layer.life_span.insert({4, 9});
+  layer.life_span.insert({5, 9});
+  layer.life_span.insert({6, 9});
+  layer.life_span.insert({7, 9});
+  layer.life_span.insert({8, 9});
+  layer.life_span.insert({9, 10});
+
+  std::map tensor_sizes;
+  tensor_sizes.insert({0, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({1, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({2, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({3, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({4, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({5, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({6, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({7, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({8, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({9, cws->onchip_mem_size + 1});
+
+  std::map> input_tensors{
+      {0, {}},
+      {1, {0}},
+      {2, {0, 1}},
+      {3, {0, 1, 2}},
+      {4, {0, 1, 2, 3}},
+      {5, {0, 1, 2, 3, 4}},
+      {6, {0, 1, 2, 3, 4, 5}},
+      {7, {0, 1, 2, 3, 4, 5, 6}},
+      {8, {0, 1, 2, 3, 4, 5, 6, 7}},
+      {9, {0, 1, 2, 3, 4, 5, 6, 7, 8}},
+  };
+  RunTest(input_tensors, tensor_sizes, std::vector({}),
+          std::vector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}), 10);
+}
+
+TEST_F(CLMLMemoryPlannerBin, skip_connections_mixed) {
+  layer.storage_ref_map.insert({0, 2});
+  layer.storage_ref_map.insert({1, 1});
+  layer.storage_ref_map.insert({2, 2});
+  layer.storage_ref_map.insert({3, 1});
+  layer.storage_ref_map.insert({4, 2});
+  layer.storage_ref_map.insert({5, 1});
+  layer.storage_ref_map.insert({6, 2});
+  layer.storage_ref_map.insert({7, 1});
+  layer.storage_ref_map.insert({8, 1});
+  layer.storage_ref_map.insert({9, 1});
+
+  layer.life_span.insert({0, 2});
+  layer.life_span.insert({1, 2});
+  layer.life_span.insert({2, 4});
+  layer.life_span.insert({3, 4});
+  layer.life_span.insert({4, 6});
+  layer.life_span.insert({5, 6});
+  layer.life_span.insert({6, 8});
+  layer.life_span.insert({7, 8});
+  layer.life_span.insert({8, 9});
+  layer.life_span.insert({9, 10});
+
+  std::map tensor_sizes;
+  tensor_sizes.insert({0, 1024000});
+  tensor_sizes.insert({1, 1024000});
+  tensor_sizes.insert({2, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({3, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({4, 1024000});
+  tensor_sizes.insert({5, 1024000});
+  tensor_sizes.insert({6, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({7, cws->onchip_mem_size + 1});
+  tensor_sizes.insert({8, 1024000});
+  tensor_sizes.insert({9, cws->onchip_mem_size + 1});
+
+  std::map> input_tensors{
+      {0, {}}, {1, {0}}, {2, {0, 1}}, {3, {2}}, {4, {2, 3}},
+      {5, {4}}, {6, {4, 5}}, {7, {6}}, {8, {6, 7}}, {9, {8}},
+  };
+
+  RunTest(input_tensors, tensor_sizes, std::vector({0, 1, 4, 5, 8}),
+          std::vector({2, 3, 6, 7, 9}), 2);
+}
+
+#endif  // TVM_GRAPH_EXECUTOR_CLML
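For orientation, the tests above encode a simple planning policy: a tensor stays on-chip unless it exceeds the on-chip budget, and tensors spilled to DDR are recycled through a global pool once their life span has expired. The snippet below is only an illustrative reference model of that policy; the helper name, the budget value, and the greedy recycling rule are assumptions for this sketch (it also ignores `storage_ref_map`, which the real CLMLWorkspace planner tracks), not the actual implementation.

```python
def plan_memory_sketch(tensor_sizes, life_span, onchip_mem_size):
    """Greedy sketch: oversize tensors go to DDR, DDR buffers are recycled."""
    on_chip, ddr = [], []
    free_ddr = []   # indices of DDR buffers that may be reused
    ddr_pool = 0    # number of distinct DDR buffers created so far
    live_ddr = {}   # tensor id -> DDR buffer index it currently holds
    for nid in sorted(tensor_sizes):
        # Recycle DDR buffers whose tensor is no longer referenced by later nodes.
        for tid in [t for t in live_ddr if life_span[t] < nid]:
            free_ddr.append(live_ddr.pop(tid))
        if tensor_sizes[nid] > onchip_mem_size:
            ddr.append(nid)
            buf = free_ddr.pop() if free_ddr else ddr_pool
            ddr_pool = max(ddr_pool, buf + 1)
            live_ddr[nid] = buf
        else:
            on_chip.append(nid)  # assumes the on-chip budget itself is never exhausted
    return on_chip, ddr, ddr_pool


ONCHIP = 2 * 1024 * 1024  # illustrative budget only, not the real device value
sizes = {0: 1024000, 1: 1024000, 2: ONCHIP + 1, 3: 1024000, 4: ONCHIP + 1, 5: 1024000}
life = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6}
print(plan_memory_sketch(sizes, life, ONCHIP))  # ([0, 1, 3, 5], [2, 4], 1)
```

On the `sequential_mixed` inputs this reproduces the asserted plan: tensors {0, 1, 3, 5} on-chip, {2, 4} in DDR, and a DDR global pool of size 1, because tensor 2's buffer is recycled for tensor 4.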
diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py
index 1b9cbdac63b55..42dcf083d02da 100644
--- a/tests/python/contrib/test_clml/infrastructure.py
+++ b/tests/python/contrib/test_clml/infrastructure.py
@@ -120,6 +120,25 @@ def visit_call(self, call):
     return c.count
 
 
+def get_non_cpu_op_count(mod):
+    """Traverse graph counting ops not offloaded to TVM."""
+
+    class Counter(tvm.relay.ExprVisitor):
+        def __init__(self):
+            super().__init__()
+            self.count = 0
+
+        def visit_call(self, call):
+            if not isinstance(call.op, tvm.ir.Op):
+                self.count += 1
+
+            super().visit_call(call)
+
+    c = Counter()
+    c.visit(mod["main"])
+    return c.count
+
+
 def skip_codegen_test():
     """Skip test if it requires the CLML codegen and it's not present."""
     if not tvm.get_global_func("relay.ext.clml", True):
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index 6cb90e7af00f5..e59a73a485ab4 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -562,34 +562,49 @@ def _get_model(x_shape, k_shape, has_bias=False):
                 "op": "const",
             },
         ]
-        if has_bias:
-            bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
-            out = relay.nn.bias_add(out, bias)
-            bias_node = {
-                "attrs": {
-                    "dtype": [[dtype]],
-                    "shape": [[list((1, k_shape[0]))]],
-                },
-                "name": "",
-                "op": "const",
-            }
-            exp_codegen.append(bias_node)
-            params["bias"] = tvm.nd.array(np.random.uniform(-1, 1, (k_shape[0],)).astype(dtype))
         dense_node = {
             "attrs": {
-                "num_inputs": "3" if has_bias else "2",
+                "num_inputs": "2",
                 "num_outputs": "1",
                 "dtype": [[dtype]],
                 "out_dtype": [[""]],
                 "shape": [[[x_shape[0], k_shape[0]]]],
                 "units": [[str(k_shape[0])]],
             },
-            "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0]] if has_bias else [[0, 0, 0], [1, 0, 0]],
+            "inputs": [[0, 0, 0], [1, 0, 0]],
             "name": "nn.dense",
             "op": "kernel",
         }
         exp_codegen.append(dense_node)
+
+        if has_bias:
+            bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
+            out = relay.nn.bias_add(out, bias)
+            bias_data_node = {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list((1, k_shape[0]))]],
+                },
+                "name": "",
+                "op": "const",
+            }
+            exp_codegen.append(bias_data_node)
+            bias_node = {
+                "attrs": {
+                    "num_inputs": "2",
+                    "num_outputs": "1",
+                    "dtype": [[dtype]],
+                    "shape": [[[x_shape[0], k_shape[0]]]],
+                },
+                "inputs": [[2, 0, 0], [3, 0, 0]],
+                "name": "add",
+                "op": "kernel",
+            }
+            exp_codegen.append(bias_node)
+
+            params["bias"] = tvm.nd.array(np.random.uniform(-1, 1, (k_shape[0],)).astype(dtype))
+
         return out, params, inputs, exp_codegen
@@ -597,11 +612,11 @@ def _verify(out, params, inputs, exp_codegen):
         opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
         clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
         tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-2, atol=1e-2
         )
         verify_codegen(out, exp_codegen, device, params)
 
-    _verify(*(_get_model((1, 16), (32, 16))))
+    _verify(*(_get_model((5, 16), (32, 16), False)))
     _verify(*(_get_model((1, 16), (32, 16), True)))
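The hunk that follows adds a `test_batch_matmul` case for CLML offload. As a plain-NumPy reference for the two (transpose_a=False, transpose_b=True) shape patterns it exercises, the following sketch shows what the op computes; it uses only NumPy and is independent of TVM.

```python
import numpy as np

# nn.batch_matmul with transpose_a=False, transpose_b=True computes, per batch:
#   out[i] = a[i] @ b[i].T
a1 = np.random.uniform(-1, 1, (1, 128, 32)).astype("float32")
b1 = np.random.uniform(-1, 1, (1, 128, 32)).astype("float32")
out1 = a1 @ b1.transpose(0, 2, 1)  # shape (1, 128, 128)

a2 = np.random.uniform(-1, 1, (1, 128, 128)).astype("float32")
b2 = np.random.uniform(-1, 1, (1, 32, 128)).astype("float32")
out2 = a2 @ b2.transpose(0, 2, 1)  # shape (1, 128, 32)

assert out1.shape == (1, 128, 128) and out2.shape == (1, 128, 32)
```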
@@ -775,5 +790,66 @@ def _verify(out, params, inputs):
     _verify(*(_get_model((1, 16, 7, 7), (2, 2), True)))
 
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_batch_matmul(device, dtype):
+    def _get_model(a_shape, b_shape, a_transpose, b_transpose):
+        a = relay.var("a", shape=(a_shape), dtype=dtype)
+        b = relay.var("b", shape=(b_shape), dtype=dtype)
+        out = relay.nn.batch_matmul(a, b, transpose_a=a_transpose, transpose_b=b_transpose)
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype)),
+            "b": tvm.nd.array(np.random.uniform(-1, 1, b_shape).astype(dtype)),
+        }
+        params = {}
+        return out, params, inputs
+
+    def _verify(out, params, inputs):
+        mod = IRModule.from_expr(out)
+        opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
+        clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+        # Check to make sure these ops are offloaded to CLML instead of TVM.
+        exp_codegen = [
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(inputs["a"].shape)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(inputs["b"].shape)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "transpose_a": [[str(int(out.attrs.transpose_a))]],
+                    "transpose_b": [[str(int(out.attrs.transpose_b))]],
+                    "out_dtype": [[""]],
+                    "dtype": [[dtype]],
+                    "num_inputs": "2",
+                    "num_outputs": "1",
+                    "shape": [[list(clml_out[0].shape)]],
+                },
+                "inputs": [[0, 0, 0], [1, 0, 0]],
+                "name": "nn.batch_matmul",
+                "op": "kernel",
+            },
+        ]
+        verify_codegen(out, exp_codegen, device, params)
+
+    _verify(*(_get_model((1, 128, 32), (1, 128, 32), False, True)))
+    _verify(*(_get_model((1, 128, 128), (1, 32, 128), False, True)))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh
index 87f50367440cd..80ac461c4e1b5 100755
--- a/tests/scripts/task_build_adreno_bins.sh
+++ b/tests/scripts/task_build_adreno_bins.sh
@@ -45,6 +45,8 @@ echo set\(ANDROID_ABI arm64-v8a\) >> config.cmake
 echo set\(ANDROID_PLATFORM android-28\) >> config.cmake
 echo set\(MACHINE_NAME aarch64-linux-gnu\) >> config.cmake
 
+echo set\(USE_OPENCL_GTEST ON\) >> config.cmake
+
 cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \
     -DANDROID_ABI=arm64-v8a \
     -DANDROID_PLATFORM=android-28 \
@@ -56,4 +58,4 @@ cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.
     -DCMAKE_C_COMPILER="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang" \
     -DMACHINE_NAME="aarch64-linux-gnu" ..
 
-make -j$(nproc) tvm_rpc rtvm
+make -j$(nproc) tvm_rpc rtvm opencl-cpptest
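The build-script change above also compiles the new `opencl-cpptest` gtest binary for the Android target. One possible way to run it on an attached device is sketched below; the build-output path, the device staging directory, and the assumption that `libtvm_runtime.so` must be shipped alongside the binary are illustrative guesses about a typical setup, not part of this patch.

```python
# Hypothetical runner: push the cross-compiled gtest binary (and the runtime it
# links against) to an Android device over adb and run the memory-planner suite.
import subprocess

BUILD_DIR = "build-adreno-target"        # assumed build directory
DEVICE_DIR = "/data/local/tmp/tvm-test"  # assumed staging directory on device


def adb(*args):
    subprocess.run(["adb", *args], check=True)


adb("shell", f"mkdir -p {DEVICE_DIR}")
adb("push", f"{BUILD_DIR}/opencl-cpptest", DEVICE_DIR)
adb("push", f"{BUILD_DIR}/libtvm_runtime.so", DEVICE_DIR)
adb(
    "shell",
    f"cd {DEVICE_DIR} && LD_LIBRARY_PATH=. ./opencl-cpptest "
    "--gtest_filter=CLMLMemoryPlannerBin.*",
)
```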