[OpenCLML] CLML Profiling fixes corresponding to OpenCL Timer recent changes (apache#12711)

* [OpenCLML] CLML Profiling fixes corresponding to OpenCL Timer recent changes.

* [OpenCLML] Review comments.

* Review comment.
srkreddy1238 authored and xinetzone committed Nov 25, 2022
1 parent 3dae16d commit 96f56e2
Showing 4 changed files with 80 additions and 93 deletions.
161 changes: 75 additions & 86 deletions src/runtime/contrib/clml/clml_runtime.cc
@@ -131,37 +131,14 @@ class CLMLRuntime : public JSONRuntimeBase {
     // Setup CLML Context
     cl_int result = 0;
 
-    // Initialize Context and Command Queue
-    result = clGetPlatformIDs(1, &platform, NULL);
-    ICHECK(result == CL_SUCCESS) << "clGetPlatformIDs:" << result;
-
-    uint32_t num_devices = 0;
-    result = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 0, NULL, &num_devices);
-    ICHECK(result == CL_SUCCESS && num_devices == 1) << "clGetDeviceIDs:" << result;
-
-    result = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device_id, NULL);
-    ICHECK(device_id && result == CL_SUCCESS) << "clGetDeviceIDs:" << result;
+    workspace = cl::OpenCLWorkspace::Global();
+    workspace->Init();
+    tentry = workspace->GetThreadEntry();
 
-    if (!ExtensionStringPresent(device_id)) {
+    if (!ExtensionStringPresent()) {
       LOG(WARNING) << "CLML Runtime Init: Qualcomm extn not present.\n";
       return;
     }
 
-    // Reuse the OpenCl work space from TVM Device API.
-    auto func = tvm::runtime::Registry::Get("device_api.opencl");
-    ICHECK(func != nullptr) << "Cannot find OpenCL device_api in registry";
-    auto device_api = static_cast<cl::OpenCLWorkspace*>(((*func)()).operator void*());
-    this->context = device_api->context;
-    bool queue_found = false;
-    for (size_t i = 0; i < device_api->devices.size(); ++i) {
-      if (device_api->devices[i] == device_id) {
-        this->queue = device_api->queues[i];
-        this->evts = &(device_api->events[i]);
-        queue_found = true;
-      }
-    }
-    ICHECK(queue_found != false) << "Device queue not found in OpenCL Workspace";
-
     // Query and Get CLML Interface
     static const cl_uint MAX_VERSIONS = 256;
     cl_int majorVersions[MAX_VERSIONS];
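
For reference, the new initialization path condensed from the hunk above (TVM-internal API exactly as it appears in this diff): the runtime no longer creates its own platform, device, and queue, but borrows TVM's shared OpenCL workspace, so CLML ops land on the same per-device queue the OpenCL timer instruments.

    // Condensed from this diff; names are those used in clml_runtime.cc.
    workspace = cl::OpenCLWorkspace::Global();  // process-wide OpenCL workspace
    workspace->Init();
    tentry = workspace->GetThreadEntry();       // per-thread device entry
    // All later enqueues go through the shared per-device queue:
    cl_command_queue queue = workspace->GetQueue(tentry->device);
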
@@ -220,8 +197,8 @@ class CLMLRuntime : public JSONRuntimeBase {
                          cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM) {
     cl_int result = 0;
     cl_event evt = NULL;
-    result = h_ClmlIntf->clEnqueueWriteMLTensorDataQCOM(queue, data, layout, tensor->tensor,
-                                                        tensor->memory,
+    result = h_ClmlIntf->clEnqueueWriteMLTensorDataQCOM(workspace->GetQueue(tentry->device), data,
+                                                        layout, tensor->tensor, tensor->memory,
                                                         0,      // n waitlist
                                                         NULL,   // waitlist
                                                         &evt);  // event
@@ -233,8 +210,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_int result = 0;
     cl_event readEvent = NULL;
     // Read the output tensor
-    result = h_ClmlIntf->clEnqueueReadMLTensorDataQCOM(queue, tensor->tensor, tensor->memory, data,
-                                                       layout,
+    result = h_ClmlIntf->clEnqueueReadMLTensorDataQCOM(workspace->GetQueue(tentry->device),
+                                                       tensor->tensor, tensor->memory, data, layout,
                                                        0,            // n waitlist
                                                        NULL,         // waitlist
                                                        &readEvent);  // event
@@ -253,6 +230,8 @@ class CLMLRuntime : public JSONRuntimeBase {
    */
   void Run() override {
     cl_int result = 0;
+    cl_command_queue queue = workspace->GetQueue(tentry->device);
+    std::vector<cl_event>& evts = workspace->GetEventQueue(tentry->device);
     for (size_t i = 0; i < input_nodes_.size(); ++i) {
       auto nid = input_nodes_[i];
       uint32_t eid = EntryID(nid, 0);
@@ -286,22 +265,26 @@ class CLMLRuntime : public JSONRuntimeBase {
     }
 
     for (size_t i = 0; i < this->layer_.function.size(); ++i) {
-      this->evts->resize(this->evts->size() + 1);
-      cl_event* evt = &(this->evts->back());
-      result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i],
-                                             this->layer_.descriptorSet, 0, NULL, evt);
+      if (getenv("CLML_PROFILING")) {
+        evts.resize(evts.size() + 1);
+        cl_event* evt = &(evts.back());
+        result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i],
+                                               this->layer_.descriptorSet, 0, NULL, evt);
+      } else {
+        result = h_ClmlIntf->clEnqueueMLOpQCOM(queue, this->layer_.function[i],
+                                               this->layer_.descriptorSet, 0, NULL, NULL);
+      }
       ICHECK(result == CL_SUCCESS) << "clEnqueueMLOpQCOM:" << result;
     }
 
     if (getenv("CLML_PROFILING")) {
       cl_ulong start, end;
       cl_ulong duration = 0;
-      clWaitForEvents(1, &(this->evts->back()));
+      clWaitForEvents(1, &(evts.back()));
       for (size_t i = 0; i < this->layer_.layer_names.size(); ++i) {
-        clGetEventProfilingInfo((*this->evts)[i], CL_PROFILING_COMMAND_START, sizeof(cl_ulong),
-                                &start, nullptr);
-        clGetEventProfilingInfo((*this->evts)[i], CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end,
+        clGetEventProfilingInfo(evts[i], CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start,
                                 nullptr);
+        clGetEventProfilingInfo(evts[i], CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, nullptr);
         duration += (end - start);
         LOG(WARNING) << "Layer:" << this->layer_.layer_names[i] << " Duration:" << (end - start);
       }
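
For reference, the profiling branch above follows the standard OpenCL event-profiling pattern. A minimal standalone sketch, assuming a non-empty event list and a queue created with CL_QUEUE_PROFILING_ENABLE (the helper name SumEventDurationsNs is hypothetical):

    // Sum per-op GPU time from a list of events enqueued on one in-order queue.
    #include <CL/cl.h>
    #include <vector>

    cl_ulong SumEventDurationsNs(const std::vector<cl_event>& evts) {
      clWaitForEvents(1, &evts.back());  // in-order queue: last event done => all done
      cl_ulong total = 0;
      for (cl_event e : evts) {
        cl_ulong start = 0, end = 0;
        clGetEventProfilingInfo(e, CL_PROFILING_COMMAND_START, sizeof(start), &start, nullptr);
        clGetEventProfilingInfo(e, CL_PROFILING_COMMAND_END, sizeof(end), &end, nullptr);
        total += end - start;  // profiling timestamps are in nanoseconds
      }
      return total;
    }
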
@@ -425,7 +408,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       JSONGraphNode node = it->second.second;
       void* node_data = nullptr;
 
-      allocateTensorMemory(h_ClmlIntf, context, tensor_desc);
+      allocateTensorMemory(h_ClmlIntf, workspace->context, tensor_desc);
 
       if (node.GetOpType() == "const") {
         node_data = data_entry_[EntryID(it->first, 0)]->data;
@@ -449,8 +432,9 @@ class CLMLRuntime : public JSONRuntimeBase {
     LOG(WARNING) << "CLML Tunning In Progress:";
     for (size_t i = 0; i < this->layer_.function.size(); ++i) {
       LOG(WARNING) << "CLML Tunning:" << i;
-      result = h_ClmlIntf->clTuneMLOpQCOM(queue, this->layer_.function[i],
-                                          this->layer_.descriptorSet, this->tuning_cache, NULL);
+      result = h_ClmlIntf->clTuneMLOpQCOM(workspace->GetQueue(tentry->device),
+                                          this->layer_.function[i], this->layer_.descriptorSet,
+                                          this->tuning_cache, NULL);
       ICHECK(result == CL_SUCCESS) << "clTuneMLOpQCOM:" << result;
     }

@@ -499,10 +483,13 @@ class CLMLRuntime : public JSONRuntimeBase {
     uint32_t n, c, h, w;
   };
 
-  bool ExtensionStringPresent(cl_device_id device_id) {
+  bool ExtensionStringPresent(void) {
     cl_int result = 0;
-
+    if (workspace->platform_id == nullptr) {
+      return 0;
+    }
     size_t reqd_size = 0;
+    cl_device_id device_id = workspace->devices[workspace->GetThreadEntry()->device.device_id];
     result = clGetDeviceInfo(device_id, CL_DEVICE_EXTENSIONS, 0, NULL, &reqd_size);
     ICHECK(reqd_size > 0u && result == CL_SUCCESS) << "clGetDeviceInfo:" << result;

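For reference, ExtensionStringPresent() relies on the usual two-call clGetDeviceInfo idiom: query the required size first, then fetch the space-separated extension list. A hedged sketch (DeviceHasExtension is a hypothetical helper; the exact Qualcomm extension string checked by CLML is not shown in this hunk):

    // Check whether a device advertises a given OpenCL extension.
    #include <CL/cl.h>
    #include <string>

    bool DeviceHasExtension(cl_device_id dev, const char* name) {
      size_t size = 0;  // first call: discover the buffer size
      if (clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &size) != CL_SUCCESS || size == 0)
        return false;
      std::string exts(size, '\0');  // second call: fetch the extension list
      if (clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, size, &exts[0], nullptr) != CL_SUCCESS)
        return false;
      return exts.find(name) != std::string::npos;
    }
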
@@ -525,7 +512,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_ml_tensor_desc_qcom desc = {
         dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, { 0 }};
     CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
-    result = clmlIntf->clCreateMLTensorQCOM(context, NULL, &desc, &tensor);
+    result = clmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor);
     ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
     (void)result;
     return tensor;
@@ -538,10 +525,11 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_mem buffer = NULL;
 
     CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
-    result = clmlIntf->clGetMLTensorMemorySizeQCOM(context, pTensorMemDesc->tensor, &size);
+    result =
+        clmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size);
     ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result;
 
-    buffer = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &result);
+    buffer = clCreateBuffer(workspace->context, CL_MEM_READ_WRITE, size, NULL, &result);
     ICHECK(result == CL_SUCCESS) << "clCreateBuffer:" << result;
 
     pTensorMemDesc->memory = buffer;
@@ -592,7 +580,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
 
     auto tensor_dsc = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
-    tensor_dsc->tensor = DeviceMakeCLMLTensor(h_ClmlIntf, context, dims, layout, cl_dtype);
+    tensor_dsc->tensor =
+        DeviceMakeCLMLTensor(h_ClmlIntf, workspace->context, dims, layout, cl_dtype);
     return tensor_dsc;
   }

@@ -703,7 +692,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     } else {
       cl_ml_tensor_desc_qcom desc = {};
       desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-      result = h_ClmlIntf->clCreateMLTensorQCOM(context, NULL, &desc, &layer_.unusedTensor);
+      result =
+          h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &layer_.unusedTensor);
       ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
       bias->tensor = layer_.unusedTensor;
     }
@@ -723,13 +713,13 @@ class CLMLRuntime : public JSONRuntimeBase {
     if (!has_bn) {
       if (!has_act) {
         result = h_ClmlIntf->clCreateMLOpConvolutionForwardQCOM(
-            context, 0, &conv_desc, input->tensor, weight->tensor, bias->tensor, output->tensor,
-            &op, NULL);
+            workspace->context, 0, &conv_desc, input->tensor, weight->tensor, bias->tensor,
+            output->tensor, &op, NULL);
         ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result;
       } else {
         result = h_ClmlIntf->clCreateMLOpFusedConvolutionActivationForwardQCOM(
-            context, 0, &conv_desc, &act_desc, input->tensor, weight->tensor, bias->tensor, NULL,
-            output->tensor, &op, tuning_cache);
+            workspace->context, 0, &conv_desc, &act_desc, input->tensor, weight->tensor,
+            bias->tensor, NULL, output->tensor, &op, tuning_cache);
         ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result;
       }
       layer_.func_ins.push_back(input);
@@ -753,13 +743,13 @@ class CLMLRuntime : public JSONRuntimeBase {
                                          CL_ARITHMETIC_MODE_FP32_QCOM};
       if (!has_act) {
         result = h_ClmlIntf->clCreateMLOpFusedConvolutionBatchNormForwardQCOM(
-            context, 0, &conv_desc, &bn_desc, input->tensor, weight->tensor, bias->tensor,
-            output->tensor, bn_mean->tensor, bn_var->tensor, bn_scale->tensor, bn_bias->tensor, &op,
-            tuning_cache);
+            workspace->context, 0, &conv_desc, &bn_desc, input->tensor, weight->tensor,
+            bias->tensor, output->tensor, bn_mean->tensor, bn_var->tensor, bn_scale->tensor,
+            bn_bias->tensor, &op, tuning_cache);
         ICHECK(op && result == CL_SUCCESS) << "Convolution Error:" << result;
       } else {
         result = h_ClmlIntf->clCreateMLOpFusedConvolutionBatchNormActivationForwardQCOM(
-            context, 0, &conv_desc, &bn_desc, &act_desc, input->tensor, weight->tensor,
+            workspace->context, 0, &conv_desc, &bn_desc, &act_desc, input->tensor, weight->tensor,
             bias->tensor, output->tensor, NULL, bn_mean->tensor, bn_var->tensor, bn_scale->tensor,
             bn_bias->tensor, &op, tuning_cache);

@@ -790,12 +780,13 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     cl_ml_tensor_desc_qcom desc = {};
     desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-    result = h_ClmlIntf->clCreateMLTensorQCOM(context, NULL, &desc, &layer_.unusedTensor);
+    result =
+        h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &layer_.unusedTensor);
     ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << ":" << result;
 
-    result = h_ClmlIntf->clCreateMLOpActivationForwardQCOM(context, 0, &act_desc, input->tensor,
-                                                           layer_.unusedTensor, output->tensor, &op,
-                                                           tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpActivationForwardQCOM(workspace->context, 0, &act_desc,
+                                                           input->tensor, layer_.unusedTensor,
+                                                           output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Activation Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -834,8 +825,8 @@ class CLMLRuntime : public JSONRuntimeBase {
                                       CL_ARITHMETIC_MODE_FP32_QCOM};
 
     result = h_ClmlIntf->clCreateMLOpBatchNormForwardQCOM(
-        context, 0, &bn_desc, input->tensor, bn_mean->tensor, bn_var->tensor, bn_scale->tensor,
-        bn_bias->tensor, output->tensor, &op, tuning_cache);
+        workspace->context, 0, &bn_desc, input->tensor, bn_mean->tensor, bn_var->tensor,
+        bn_scale->tensor, bn_bias->tensor, output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Batchnorm Error:" << result;
 
     layer->function.push_back(op);
@@ -872,12 +863,13 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     cl_ml_tensor_desc_qcom desc = {};
     desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
-    result = h_ClmlIntf->clCreateMLTensorQCOM(context, NULL, &desc, &layer_.unusedTensor);
+    result =
+        h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &layer_.unusedTensor);
     ICHECK(layer_.unusedTensor && result == CL_SUCCESS) << ":" << result;
 
-    result = h_ClmlIntf->clCreateMLOpPoolingForwardQCOM(context, 0, &pool_desc, input->tensor,
-                                                        layer_.unusedTensor, output->tensor, &op,
-                                                        tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpPoolingForwardQCOM(workspace->context, 0, &pool_desc,
+                                                        input->tensor, layer_.unusedTensor,
+                                                        output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result;
 
     layer_.func_ins.push_back(input);
Expand All @@ -904,8 +896,8 @@ class CLMLRuntime : public JSONRuntimeBase {
CL_SOFTMAX_MODE_INSTANCE_QCOM,
CL_ARITHMETIC_MODE_FP32_QCOM};

result = h_ClmlIntf->clCreateMLOpSoftmaxQCOM(context, 0, &softmax_desc, input->tensor,
output->tensor, &op, tuning_cache);
result = h_ClmlIntf->clCreateMLOpSoftmaxQCOM(workspace->context, 0, &softmax_desc,
input->tensor, output->tensor, &op, tuning_cache);
ICHECK(op && result == CL_SUCCESS) << "SoftMax Error:" << result;

layer_.func_ins.push_back(input);
@@ -946,8 +938,8 @@ class CLMLRuntime : public JSONRuntimeBase {
         {clml_padding[0], clml_padding[1], clml_padding[2], clml_padding[3], 0, 0, 0, 0},
         CL_ARITHMETIC_MODE_FP32_QCOM};
 
-    result = h_ClmlIntf->clCreateMLOpPadQCOM(context, 0, &pad_desc, input->tensor, output->tensor,
-                                             &op, tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpPadQCOM(workspace->context, 0, &pad_desc, input->tensor,
+                                             output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Pad Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -968,8 +960,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
     auto output = MakeCLMLTensorFromJSONNode(node);
 
-    result = h_ClmlIntf->clCreateMLOpReshapeQCOM(context, 0, input->tensor, output->tensor, &op,
-                                                 tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->context, 0, input->tensor,
+                                                 output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -1004,13 +996,13 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     auto output = MakeCLMLTensorFromJSONNode(node);
     if (has_bias) {
-      result = h_ClmlIntf->clCreateMLOpFullyConnectedQCOM(context, 0, &fc_desc, input->tensor,
-                                                          weight->tensor, bias->tensor,
-                                                          output->tensor, &op, tuning_cache);
+      result = h_ClmlIntf->clCreateMLOpFullyConnectedQCOM(
+          workspace->context, 0, &fc_desc, input->tensor, weight->tensor, bias->tensor,
+          output->tensor, &op, tuning_cache);
     } else {
-      result = h_ClmlIntf->clCreateMLOpFullyConnectedQCOM(context, 0, &fc_desc, input->tensor,
-                                                          weight->tensor, NULL, output->tensor, &op,
-                                                          tuning_cache);
+      result = h_ClmlIntf->clCreateMLOpFullyConnectedQCOM(workspace->context, 0, &fc_desc,
+                                                          input->tensor, weight->tensor, NULL,
+                                                          output->tensor, &op, tuning_cache);
     }
     ICHECK(op && result == CL_SUCCESS) << "Fully Connected Error:" << result;

@@ -1039,8 +1031,8 @@ class CLMLRuntime : public JSONRuntimeBase {
                                     {{a_min}, CL_FLOAT},
                                     CL_ARITHMETIC_MODE_FP32_QCOM};
 
-    result = h_ClmlIntf->clCreateMLOpClipQCOM(context, 0, &clip_desc, input->tensor, output->tensor,
-                                              &op, tuning_cache);
+    result = h_ClmlIntf->clCreateMLOpClipQCOM(workspace->context, 0, &clip_desc, input->tensor,
+                                              output->tensor, &op, tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Clip Error:" << result;
 
     layer_.func_ins.push_back(input);
@@ -1056,11 +1048,8 @@ class CLMLRuntime : public JSONRuntimeBase {
   CachedLayer layer_;
   // CLML Context
   CLMLInterfaceV2QCOM* h_ClmlIntf = NULL;
-  cl_platform_id platform = NULL;
-  cl_context context = NULL;
-  cl_device_id device_id = NULL;
-  cl_command_queue queue = NULL;
-  std::vector<cl_event>* evts;
+  cl::OpenCLWorkspace* workspace = NULL;
+  cl::OpenCLThreadEntry* tentry = NULL;
   cl_ml_tuningcache_qcom tuning_cache = NULL;
   bool is_tuning_run;
   char* tuning_file;
6 changes: 3 additions & 3 deletions tests/python/contrib/test_clml/infrastructure.py
@@ -73,11 +73,11 @@ class Device:
 
     connection_type = "tracker"
     host = "localhost"
-    port = 9090
+    port = 9150
     target = "opencl"
     target_host = "llvm -mtriple=aarch64-linux-gnu"
-    device_key = ""
-    cross_compile = ""
+    device_key = "android"
+    cross_compile = "aarch64-linux-android-g++"
 
     def __init__(self):
         """Keep remote device for lifetime of object."""
4 changes: 1 addition & 3 deletions tests/python/contrib/test_clml/test_network.py
@@ -22,8 +22,7 @@
 from tvm import relay
 
 import tvm
-from test_clml.infrastructure import skip_runtime_test, build_and_run
-from test_clml.infrastructure import Device
+from test_clml.infrastructure import skip_runtime_test, build_and_run, Device
 
 
 def _build_and_run_network(mod, params, inputs, data, device, atol, rtol):
@@ -86,7 +85,6 @@ def get_model():
         mobilenet = MobileNet(
             include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
         )
-        mobilenet.load_weights("mobilenet_1_0_224_tf.h5")
         inputs = {mobilenet.input_names[0]: ((1, 3, 224, 224), "float32")}
 
         data = {}
2 changes: 1 addition & 1 deletion tests/python/contrib/test_clml/test_ops.py
@@ -212,5 +212,5 @@ def test_batchnorm():
 
 
 if __name__ == "__main__":
-    # test_conv2d()
+    test_conv2d()
     test_batchnorm()
