Refactor and introduce on-chip memory and memory planner
Introduced thread context with CLMLWorkspace.
Organized the code into runtime, utils, and memory planners.
Introduced recording queue support and on-chip memory support.
Added an on-chip memory allocation planner to accommodate multiple tensors at a time.
Introduced a DDR memory planner to reuse the underlying memory across
multiple tensor descriptors.
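
The planning policy can be pictured with a small sketch (illustrative only,
not the actual runtime code; the class name, the fixed on-chip budget, and
the liveness inputs are assumptions): tensors are placed into a fixed-size
on-chip pool first, and anything that does not fit falls back to a DDR pool
whose buffers are reused across descriptors with non-overlapping lifetimes.

# Illustrative sketch of the two planners described above; names and the
# fixed on-chip budget are assumptions, not the TVM CLML runtime code.
class MemoryPlanner:
    def __init__(self, on_chip_budget):
        self.on_chip_budget = on_chip_budget  # bytes of on-chip memory
        self.on_chip_used = 0
        self.on_chip_plan = {}  # tensor id -> (offset, size) in on-chip memory
        self.ddr_pool = []      # list of (size, last_use) reusable DDR buffers
        self.ddr_plan = {}      # tensor id -> DDR buffer index

    def request(self, tid, size, first_use, last_use):
        # Try the on-chip pool first: simple bump allocation within the budget.
        if self.on_chip_used + size <= self.on_chip_budget:
            self.on_chip_plan[tid] = (self.on_chip_used, size)
            self.on_chip_used += size
            return ("on_chip", self.on_chip_plan[tid])
        # Otherwise reuse a DDR buffer whose previous user is already dead.
        for idx, (buf_size, busy_until) in enumerate(self.ddr_pool):
            if busy_until < first_use and buf_size >= size:
                self.ddr_pool[idx] = (buf_size, last_use)
                self.ddr_plan[tid] = idx
                return ("ddr", idx)
        # No reusable buffer: grow the DDR pool with a new backing allocation.
        self.ddr_pool.append((size, last_use))
        self.ddr_plan[tid] = len(self.ddr_pool) - 1
        return ("ddr", self.ddr_plan[tid])

With this policy, two tensor descriptors that are live at different points of
the graph end up sharing the same underlying DDR buffer, which is the reuse
behaviour described above.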

Refactored dense layer support to use GEMM.
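
For reference, nn.dense computes data[M, K] x weight[N, K]^T, which maps to a
GEMM with m = M, n = N, k = K and the B operand transposed; the descriptor in
the clml_runner.cc diff below encodes exactly these fields. The NumPy check
here is only an illustration of that mapping, using made-up example shapes.

import numpy as np

# nn.dense: out[M, N] = data[M, K] @ weight[N, K].T
# Mirrors the GEMM descriptor fields used in MakeDense:
#   m = in_shape[0], n = wt_shape[0], k = wt_shape[1],
#   B transform = transpose when in_shape[1] == wt_shape[1].
in_shape, wt_shape = (1, 2048), (1000, 2048)  # example shapes (assumed)
m, n, k = in_shape[0], wt_shape[0], wt_shape[1]
b_transpose = in_shape[1] == wt_shape[1]

data = np.random.rand(*in_shape).astype("float32")
weight = np.random.rand(*wt_shape).astype("float32")
gemm_b = weight.T if b_transpose else weight
out = 1.0 * (data @ gemm_b) + 0.0  # alpha = 1.0, beta = 0.0 as in the descriptor
assert out.shape == (m, n)
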
CLML binary operators do not support broadcasting, hence an explicit
broadcast op is introduced as a workaround.
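
As a rough illustration of what the new BinaryOpBroadcaster pass (see the
clml.py diff below) produces, the operand whose shape differs from the result
shape gets an explicit broadcast_to inserted; the shapes here are made up for
the example.

import tvm
from tvm import relay

# Before the pass: add(x, y) relies on implicit broadcasting of y.
x = relay.var("x", shape=(1, 64, 56, 56), dtype="float32")
y = relay.var("y", shape=(1, 64, 1, 1), dtype="float32")

# After the pass: the smaller operand carries an explicit broadcast_to, so
# CLML only ever sees element-wise ops on matching shapes.
out = relay.add(x, relay.broadcast_to(y, (1, 64, 56, 56)))
mod = tvm.IRModule.from_expr(relay.Function([x, y], out))
print(mod)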

The CLML SDK codegen is enhanced accordingly.
Siva Rama Krishna Reddy B authored and srkreddy1238 committed May 24, 2023
1 parent 4f99750 commit 3bf7e63
Showing 12 changed files with 1,610 additions and 549 deletions.
36 changes: 20 additions & 16 deletions apps/cpp_clml/clml_runner.cc
@@ -50,8 +50,8 @@ CLMLRunner::CLMLRunner(std::string name, ToolArgs& args, cl_platform_id arg_plat
context(arg_context),
device_id(arg_device_id),
queue(arg_queue) {
LOG(INFO) << "CLMLRunner Constructor: Input:" << r_args.input << " Output:" << r_args.output
<< " Params:" << r_args.params;
LOG(INFO) << "CLMLRunner Constructor:" << name << " Input:" << r_args.input
<< " Output:" << r_args.output << " Params:" << r_args.params;
cl_int result;

// Query and Get CLML Interface
@@ -648,25 +648,29 @@ void CLMLRunner::MakeConcatenate(
void CLMLRunner::MakeDense(std::shared_ptr<cl_ml_tensor_memory_desc_qcom> input_desc,
std::shared_ptr<cl_ml_tensor_memory_desc_qcom> weight_desc,
std::shared_ptr<cl_ml_tensor_memory_desc_qcom> output_desc,
std::shared_ptr<cl_ml_tensor_memory_desc_qcom> bias_desc,
std::vector<cl_uint> in_shape, std::vector<cl_uint> wt_shape,
std::string dtype) {
cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(MakeCLDataType(dtype));
cl_ml_op_qcom op = nullptr;
cl_int result;
cl_gemm_transform_qcom b_transform = CL_GEMM_TRANSFORM_NONE_QCOM;

cl_ml_op_convolution_desc_qcom conv_desc = {CL_CONVOLUTION_MODE_CONVOLUTION_QCOM,
1,
4,
{0, 0},
{0, 0},
{1, 1},
{1, 1},
0,
cl_arithmetic_mode};

result = h_ClmlIntf->clCreateMLOpConvolutionForwardQCOM(
this->context, 0, &conv_desc, input_desc->tensor, weight_desc->tensor, bias_desc->tensor,
output_desc->tensor, &op, tuning_cache);
if (in_shape[1] == wt_shape[1]) {
b_transform = CL_GEMM_TRANSFORM_TRANSPOSE_QCOM;
}

cl_ml_op_gemm_desc_qcom gemmDesc = {in_shape[0], // m
wt_shape[0], // n
wt_shape[1], // k
CL_GEMM_TRANSFORM_NONE_QCOM, // A transform
b_transform, // B transform
{{1.0}, CL_FLOAT}, // alpha
{{0.0}, CL_FLOAT}, // beta
cl_arithmetic_mode};

result =
h_ClmlIntf->clCreateMLOpGemmQCOM(this->context, 0, &gemmDesc, input_desc->tensor,
weight_desc->tensor, output_desc->tensor, &op, tuning_cache);

CLML_SDK_TEST_AND_EXIT(op && result == CL_SUCCESS);
this->function.push_back(op);
2 changes: 1 addition & 1 deletion apps/cpp_clml/clml_runner.h
@@ -178,7 +178,7 @@ class CLMLRunner {
void MakeDense(std::shared_ptr<cl_ml_tensor_memory_desc_qcom> input_desc,
std::shared_ptr<cl_ml_tensor_memory_desc_qcom> weight_desc,
std::shared_ptr<cl_ml_tensor_memory_desc_qcom> output_desc,
std::shared_ptr<cl_ml_tensor_memory_desc_qcom> bias_desc, std::string dtype);
std::vector<cl_uint> in_shape, std::vector<cl_uint> wt_shape, std::string dtype);

/*! \brief SoftMax layer implementation */
void MakeSoftMax(std::shared_ptr<cl_ml_tensor_memory_desc_qcom> input_desc,
2 changes: 1 addition & 1 deletion apps/cpp_clml/scripts/clml_codegen.py
@@ -45,7 +45,7 @@ def main():
clml_mod = clml.partition_for_clml(mod, params)
libm = relay.build(
clml_mod,
target="opencl -device=adreno",
target="opencl",
target_host="llvm -mtriple=aarch64-linux-gnu",
params=params,
)
161 changes: 106 additions & 55 deletions python/tvm/relay/op/contrib/clml.py
@@ -81,6 +81,36 @@ def transform_function(
return RemoveDropout().visit(func)


class BroadcastInputs(ExprMutator):
"""
Binary operators need broadcasting for CLML.
"""

def visit_call(self, call):
if call.op.name in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
new_fn = self.visit(call.op)
call_shape = call.checked_type.shape
lhs = call.args[0]
rhs = call.args[1]
lhs_shape = lhs.checked_type.shape
rhs_shape = rhs.checked_type.shape
if list(call_shape) != list(lhs_shape):
lhs = relay.broadcast_to(self.visit(lhs), call_shape)
if list(call_shape) != list(rhs_shape):
rhs = relay.broadcast_to(self.visit(rhs), call_shape)
args = [lhs, rhs]
return Call(new_fn, args, call.attrs)
return super().visit_call(call)


@transform.function_pass(opt_level=0)
class BinaryOpBroadcaster:
def transform_function(
self, func: relay.function.Function, mod: tvm.IRModule, _: tvm.transform.PassContext
) -> relay.function.Function:
return BroadcastInputs().visit(func)


def partition_for_clml(mod, params=None, **opts):
"""Partition the graph greedily offloading supported
operators to CLML Library.
@@ -104,6 +134,7 @@ def partition_for_clml(mod, params=None, **opts):
[
transform.InferType(),
RemoveDropoutPass(),
BinaryOpBroadcaster(),
transform.FoldConstant(),
transform.MergeComposite(clml_pattern_table()),
transform.AnnotateTarget("clml", False),
@@ -261,8 +292,6 @@ def concat_pattern():
def dense_pattern():
"""Create a dense pattern."""
pattern = is_op("nn.dense")(wildcard(), is_constant())
pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
return pattern

def pad_pattern():
@@ -344,9 +373,19 @@ def check_conv_transpose(extract):

def check_binary_op(extract):
call = extract
if len(call.args[1].checked_type.shape) > 0:
return True
return False
# Scalars are not supported
if len(call.args[1].checked_type.shape) == 0:
return False

for arg in call.args:
# Avoid any operators with dtype Int64
if arg.checked_type.dtype == "int64":
return False
# No support for batch > 1
if arg.checked_type.shape[0] > 1:
return False

return True

def check_pad_op(extract):
call = extract
@@ -377,6 +416,20 @@ def check_concat_op(extract):
return True

def check_default_op(extract):
call = extract
# Avoid any operators with dtype Int64
for arg in call.args:
if arg.checked_type.dtype == "int64":
return False
return True

def check_batch_matmul_op(extract):
call = extract
# Only support single Matmul
if call.args[0].checked_type.shape[0] > 1:
return False
if call.args[1].checked_type.shape[0] > 1:
return False
return True

return [
@@ -394,7 +447,7 @@ def check_default_op(extract):
("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op),
("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op),
("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op),
("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
# ("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op),
("clml.max_pool2d", is_op("nn.max_pool2d")(wildcard()), check_default_op),
("clml.global_avg_pool2d", is_op("nn.global_avg_pool2d")(wildcard()), check_default_op),
@@ -404,6 +457,11 @@
("clml.batch_flatten", is_op("nn.batch_flatten")(wildcard()), check_default_op),
("clml.depth_to_space", is_op("nn.depth_to_space")(wildcard()), check_default_op),
("clml.upsampling", is_op("nn.upsampling")(wildcard()), check_upsampling_op),
(
"clml.batch_matmul",
is_op("nn.batch_matmul")(wildcard(), wildcard()),
check_batch_matmul_op,
),
]


@@ -570,7 +628,9 @@ def __init__(self, cmod):
runner.MakeDense($input_tensor,
$weight_tensor,
$output_tensor,
$bias_tensor, "$dtype");"""
std::vector<cl_uint> ({$in_shape}),
std::vector<cl_uint> ({$wt_shape}),
"$dtype");"""
)
self.MakeSoftMax = Template(
"""
@@ -641,13 +701,12 @@ def __init__(self, cmod):
" Output Count : $output_count\\n"
' Input MetaInfo\\n$input_meta\\n Output MetaInfo\\n$output_meta");'
)

self.MakeInputMetaInfo = Template(
" Input: $in_name\\n Dtype : $dtype\\n Shape : [$shape]"
" Input: $in_name\\n Dtype : $dtype\\n Shape : [$shape]\\n"
)

self.MakeOutputMetaInfo = Template(
" Output: $out_name\\n Dtype : $dtype\\n Shape : [$shape]"
" Output: $out_name\\n Dtype : $dtype\\n Shape : [$shape]\\n"
)

def get_src(self):
Expand All @@ -666,23 +725,40 @@ def get_tensor_from_map(
else:
node = self.nodes[node_seq]
dtype = str(node["attrs"]["dtype"][0][0])
if node["op"] == "input":
self.clml_code.append("// Input Node")
node_out_name = self.sub_module_name + "_" + "input_" + str(node_seq)
else:
node_out_name = node["name"]
if shape is None:
shape = str(tuple(node["attrs"]["shape"][0][0]))[1:-1]

self.clml_code.append(
self.MakeCLMLTensor.substitute(
name=node["name"], shape=shape, dtype=dtype, layout=layout
name=node_out_name, shape=shape, dtype=dtype, layout=layout
)
)
self.clml_code.append(
self.MapInsert.substitute(nid=node["name"], tensor_desc=node["name"])
self.MapInsert.substitute(nid=node_out_name, tensor_desc=node_out_name)
)
if node["op"] == "input":
self.clml_code.append(
Template("runner.inputs.push_back($clml_input);").substitute(
clml_input=node_out_name
)
)
self.input_meta.append(
self.MakeInputMetaInfo.substitute(
in_name=node_out_name, dtype=dtype, shape=shape
)
)

if self.nodes[node_seq]["op"] == "const":
self.clml_code.append(
Template('runner.consts.push_back("$nid");').substitute(nid=node["name"])
)
self.node_map[node_seq] = node["name"]
return node["name"]
self.node_map[node_seq] = node_out_name
return node_out_name

def make_output_tensor(
node, node_seq, shape=None, layout="CL_TENSOR_LAYOUT_OPTIMAL_QCOM", dtype="float32"
@@ -697,40 +773,13 @@ def make_output_tensor(
name=node_out_name,
shape=shape,
dtype=dtype,
layout="CL_TENSOR_LAYOUT_OPTIMAL_QCOM",
layout=layout,
)
)
return node_out_name

for node_seq, node in enumerate(self.nodes):
if node["op"] == "input":
self.clml_code.append("// Input Node")
dtype = str(node["attrs"]["dtype"][0][0])
shape = str(tuple(node["attrs"]["shape"][0][0]))[1:-1]
node_out_name = self.sub_module_name + "_" + "input_" + str(node_seq)
self.clml_code.append(
self.MakeCLMLTensor.substitute(
name=node_out_name,
shape=shape,
dtype=dtype,
layout="CL_TENSOR_LAYOUT_OPTIMAL_QCOM",
)
)
self.clml_code.append(
self.MapInsert.substitute(nid=node_out_name, tensor_desc=node_out_name)
)
self.clml_code.append(
Template("runner.inputs.push_back($clml_input);").substitute(
clml_input=node_out_name
)
)
self.node_map[node_seq] = node_out_name
self.input_meta.append(
self.MakeInputMetaInfo.substitute(
in_name=node_out_name, dtype=dtype, shape=shape
)
)
elif node["op"] == "kernel":
if node["op"] == "kernel":
self.clml_code.append("// Kernel Node : " + node["name"])
if node["name"] == "nn.conv2d" or node["name"] == "nn.depthwise_conv2d":
if "padding" in node["attrs"]:
@@ -791,6 +840,7 @@ def make_output_tensor(
bn_shape = [1, 1, 1, 1]
bn_node = self.nodes[node["inputs"][bn_index][0]]
bn_shape[axis] = bn_node["attrs"]["shape"][0][0]
dtype = bn_node["attrs"]["dtype"][0][0]

bn_scale_tensor = get_tensor_from_map(
node["inputs"][bn_index][0],
@@ -858,6 +908,7 @@ def make_output_tensor(
bn_shape = [1, 1, 1, 1]
bn_node = self.nodes[node["inputs"][0][0]]
bn_shape[axis] = bn_node["attrs"]["shape"][0][0]
dtype = bn_node["attrs"]["dtype"][0][0]
bn_scale_tensor = get_tensor_from_map(
node["inputs"][0][0], shape=str(tuple(bn_shape))[1:-1], dtype=dtype
)
@@ -947,26 +998,26 @@ def make_output_tensor(
in_shape = tuple(in_node["attrs"]["shape"][0][0])
wt_shape = tuple(in_node["attrs"]["shape"][0][0])
input_tensor = get_tensor_from_map(
node["inputs"][0][0], shape=str(tuple([1, in_shape[1], 1, 1]))[1:-1]
node["inputs"][0][0], layout="CL_TENSOR_LAYOUT_NCHW_QCOM"
)
weight_tensor = get_tensor_from_map(
node["inputs"][1][0],
shape=str(tuple([wt_shape[0], wt_shape[1], 1, 1]))[1:-1],
shape=str(tuple([1, 1, wt_shape[0], wt_shape[1]]))[1:-1],
layout="CL_TENSOR_LAYOUT_NCHW_QCOM",
)
if len(node["inputs"]) == 3:
bias_tensor = "runner.unusedTensor"
else:
bias_tensor = get_tensor_from_map(node["inputs"][2][0])

node_out_name = make_output_tensor(
node, node_seq, shape=str(tuple([1, wt_shape[0], 1, 1]))[1:-1]
node,
node_seq,
shape=str(tuple([in_shape[0], wt_shape[0], 1, 1]))[1:-1],
layout="CL_TENSOR_LAYOUT_NCHW_QCOM",
)
self.clml_code.append(
self.MakeDense.substitute(
input_tensor=input_tensor,
weight_tensor=weight_tensor,
output_tensor=node_out_name,
bias_tensor=bias_tensor,
in_shape=str(in_shape)[1:-1],
wt_shape=str(wt_shape)[1:-1],
dtype=node["attrs"]["dtype"][0][0],
)
)
@@ -1045,7 +1096,7 @@ def make_output_tensor(
)
self.node_map[node_seq] = node_out_name

elif node["op"] != "const":
elif node["op"] not in ["const", "input"]:
print("Unknown Node type:", node["op"])

# Populate outputs
@@ -1086,8 +1137,8 @@ def make_output_tensor(
name=self.sub_module_name,
input_count=len(self.input_meta),
output_count=len(self.output_meta),
input_meta="\n".join(self.input_meta),
output_meta="\n".join(self.output_meta),
input_meta="\\\n".join(self.input_meta),
output_meta="\\\n".join(self.output_meta),
)
)

