Skip to content

Commit

Permalink
[RUNTIME] Enable OpenCL
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 18, 2017
1 parent e9ff9a8 commit c29a292
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 18 deletions.
17 changes: 16 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ endif
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
export FRAMEWORKS=

ifneq ($(ADD_CFLAGS), NONE)
CFLAGS += $(ADD_CFLAGS)
Expand All @@ -43,6 +44,20 @@ else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif


ifeq ($(USE_OPENCL), 1)
CFLAGS += -DTVM_OPENCL_RUNTIME=1
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S), Darwin)
FRAMEWORKS += -framework OpenCL
else
LDFLAGS += -lOpenCL
endif
else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif


include tests/cpp/unittest.mk

test: $(TEST)
Expand All @@ -59,7 +74,7 @@ lib/libtvm.a: $(ALL_DEP)

lib/libtvm.so: $(ALL_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

$(LIB_HALIDE_IR): LIBHALIDEIR

Expand Down
17 changes: 17 additions & 0 deletions include/tvm/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,23 @@ typedef TVMArray* TVMArrayHandle;
*/
TVM_DLL const char *TVMGetLastError(void);

/*!
* \brief Initialize certain type of devices, this may
* not be necessary for all device types. But is needed for OpenCL.
*
* \param dev_mask The device mask of device type to be initialized
* \param option_keys Additional option keys to pass.
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \param out_code 1: success, 0: already initialized
* \return Whether the function is successful.
*/
TVM_DLL int TVMDeviceInit(int dev_mask,
const char** option_keys,
const char** option_vals,
int num_options,
int *out_code);

/*!
* \brief Whether the specified context is enabled.
*
Expand Down
3 changes: 3 additions & 0 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ ADD_CFLAGS =
# whether use CUDA during compile
USE_CUDA = 1

# whether use OpenCL during compile
USE_OPENCL = 0

# add the path to CUDA library to link and compile flag
# if you have already add them to environment variable, leave it as NONE
# USE_CUDA_PATH = /usr/local/cuda
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import schedule

from . import ndarray as nd
from .ndarray import cpu, gpu, opencl
from .ndarray import cpu, gpu, opencl, init_opencl

from ._base import TVMError
from .function import *
26 changes: 25 additions & 1 deletion python/tvm/_ctypes/_runtime_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from .._base import _LIB
from .._base import c_array
from .._base import c_array, c_str
from .._base import check_call


Expand Down Expand Up @@ -182,6 +182,30 @@ def sync(ctx):
check_call(_LIB.TVMSynchronize(ctx, None))


def init_opencl(**kwargs):
"""Initialize the opencl with the options.
Parameters
----------
kwargs : dict
The options
"""
keys = []
vals = []
for k, v in kwargs.items():
keys.append(c_str(k))
vals.append(c_str(v))
dev_mask = ctypes.c_int(4)
out_code = ctypes.c_int()
check_call(_LIB.TVMDeviceInit(
dev_mask,
c_array(ctypes.c_char_p, keys),
c_array(ctypes.c_char_p, vals),
ctypes.c_int(len(keys)),
ctypes.byref(out_code)))
return out_code.value != 0


class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle"]
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
from ._ctypes._runtime_api import _init_runtime_module
from ._ctypes._runtime_api import init_opencl


class NDArray(NDArrayBase):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __getitem__(self, k):
k = k.op
if not isinstance(k, _tensor.Operation):
raise ValueError("Expect schedule key to be Tensor or Operation")
if not k in self.stage_map:
if k not in self.stage_map:
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
return self.stage_map[k]

Expand Down
17 changes: 17 additions & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ inline size_t GetDataAlignment(TVMArray* arr) {

using namespace tvm::runtime;

int TVMDeviceInit(int dev_mask,
const char** option_keys,
const char** option_vals,
int num_options,
int* out_code) {
API_BEGIN();
*out_code = 1;
switch (dev_mask) {
case kOpenCL: {
*out_code = DeviceInit<kOpenCL>(option_keys, option_vals, num_options);
break;
}
default: break;
}
API_END();
}

int TVMContextEnabled(TVMContext ctx,
int* out_enabled) {
API_BEGIN();
Expand Down
20 changes: 18 additions & 2 deletions src/runtime/device_api.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2016 by Contributors
* \file device_api.hx
* \file device_api.h
* \brief Device specific API
*/
#ifndef TVM_RUNTIME_DEVICE_API_H_
Expand All @@ -11,6 +11,21 @@

namespace tvm {
namespace runtime {
/*!
* \brief Initialize the device.
* \param option_keys Additional option keys to pass.
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \return 0 if success, 1: if already initialized
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline bool DeviceInit(const char** option_keys,
const char** option_vals,
int num_options) {
return true;
}

/*!
* \brief Whether ctx is enabled.
* \param ctx The device context to perform operation.
Expand Down Expand Up @@ -93,7 +108,8 @@ inline void StreamSync(TVMContext ctx, TVMStreamHandle stream);
} // namespace runtime
} // namespace tvm

#include "./device_api_gpu.h"
#include "./device_api_cpu.h"
#include "./device_api_gpu.h"
#include "./device_api_opencl.h"

#endif // TVM_RUNTIME_DEVICE_API_H_
11 changes: 1 addition & 10 deletions src/runtime/device_api_gpu.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2016 by Contributors
* \file ctxice_api_gpu.h
* \file device_api_gpu.h
* \brief GPU specific API
*/
#ifndef TVM_RUNTIME_DEVICE_API_GPU_H_
Expand All @@ -14,15 +14,6 @@

namespace tvm {
namespace runtime {
/*!
* \brief Check CUDA error.
* \param msg Message to print if an error occured.
*/
#define CHECK_CUDA_ERROR(msg) \
{ \
cudaError_t e = cudaGetLastError(); \
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
}

/*!
* \brief Protected CUDA call.
Expand Down
Loading

0 comments on commit c29a292

Please sign in to comment.