Skip to content

Commit

Permalink
[OpenCL] Refactor cl_program generation (apache#7834)
Browse files Browse the repository at this point in the history
* Refactor OpenCL runtime module to build separate cl_programs
for each kernel. This can avoid pathological bugs in the
vendor specific OpenCL compiler that may be triggered
with large programs.

* clang-format

* Remove check on program size when deconstructing.

* Refactor into SplitKernels method.

* Limit number of loops for kernel parsing

* Add return doc for SplitKernels per CR.
  • Loading branch information
csullivan authored and Trevor Morris committed May 6, 2021
1 parent 2e703c9 commit fffd391
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 29 deletions.
16 changes: 12 additions & 4 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,14 @@ class OpenCLModuleNode : public ModuleNode {
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e);

/*
* \brief Splits the provided serialized source file into separate
* source for each kernel primitive.
* \param source The serialized program source file (fmt: cl)
* \return Mapping from primitive name to kernel source
*/
std::unordered_map<std::string, std::string> SplitKernels(std::string source) const;

private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
Expand All @@ -340,14 +348,14 @@ class OpenCLModuleNode : public ModuleNode {
std::mutex build_lock_;
// The OpenCL source.
std::string source_;
// the binary data
cl_program program_{nullptr};
// build info
std::vector<bool> device_built_flag_;
// Mapping from primitive name to cl program for each device.
std::unordered_map<std::string, std::vector<cl_program>> programs_;
// kernel id cache
std::unordered_map<std::string, KTRefEntry> kid_map_;
// kernels build so far.
std::vector<cl_kernel> kernels_;
// parsed kernel data
std::unordered_map<std::string, std::string> parsed_kernels_;
};

} // namespace runtime
Expand Down
80 changes: 62 additions & 18 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,13 @@ OpenCLModuleNode::~OpenCLModuleNode() {
for (cl_kernel k : kernels_) {
OPENCL_CALL(clReleaseKernel(k));
}
if (program_) {
OPENCL_CALL(clReleaseProgram(program_));
// free the programs
for (auto& kv : programs_) {
for (auto& program : kv.second) {
if (program) {
OPENCL_CALL(clReleaseProgram(program));
}
}
}
}

Expand Down Expand Up @@ -166,7 +171,6 @@ std::string OpenCLModuleNode::GetSource(const std::string& format) {
void OpenCLModuleNode::Init() {
workspace_ = GetGlobalWorkspace();
workspace_->Init();
device_built_flag_.resize(workspace_->devices.size(), false);
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
Expand All @@ -181,56 +185,96 @@ void OpenCLModuleNode::Init() {
e.version = workspace_->timestamp++;
kid_map_[key] = e;
}

// split into source artifacts for each kernel
parsed_kernels_ = SplitKernels(GetSource("cl"));
// zero initialize cl_program pointers for each device kernel
for (auto& kv : parsed_kernels_) {
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
}
}

cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int device_id = t->device.device_id;
if (!device_built_flag_[device_id]) {
if (programs_[func_name][device_id] == nullptr) {
// create program
if (fmt_ == "cl") {
if (program_ == nullptr) {
const char* s = data_.c_str();
size_t len = data_.length();
cl_int err;
program_ = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
}
const char* s = parsed_kernels_[func_name].c_str();
size_t len = parsed_kernels_[func_name].length();
cl_int err;
programs_[func_name][device_id] = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
} else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") {
const unsigned char* s = (const unsigned char*)data_.c_str();
size_t len = data_.length();
cl_int err;
cl_device_id dev = w->devices[device_id];
program_ = clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
programs_[func_name][device_id] =
clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
OPENCL_CHECK_ERROR(err);
} else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_;
}
// build program
cl_int err;
cl_device_id dev = w->devices[device_id];
err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr);
if (err != CL_SUCCESS) {
size_t len;
std::string log;
clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr,
&len);
log.resize(len);
clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << log;
clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len,
&log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log;
}
device_built_flag_[device_id] = true;
}
// build kernel
cl_int err;
cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err);
OPENCL_CHECK_ERROR(err);
t->kernel_table[e.kernel_id].kernel = kernel;
t->kernel_table[e.kernel_id].version = e.version;
kernels_.push_back(kernel);
return kernel;
}

std::unordered_map<std::string, std::string> OpenCLModuleNode::SplitKernels(
std::string source) const {
std::unordered_map<std::string, std::string> split_kernels;
if (source.size()) {
std::string del{"// Function: "};
size_t end;
size_t begin = source.find(del);
ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited "
<< "source from code generation, but no kernel "
<< "delimiter was found.";
for (size_t num_kernels = 0; num_kernels < workspace_->num_registered_kernels; num_kernels++) {
begin += del.size();
end = source.find('\n', begin);
std::string func_name = source.substr(begin, end - begin);
begin = ++end;
// std::string::substr returns either start of next kernel
// or std::string::npos, in the latter case substr returns
// all characters until the end of the source string.
end = source.find(del, begin);
std::string func_source =
source.substr(begin, (end == std::string::npos) ? end : end - begin);
split_kernels.insert({func_name, func_source});
begin = end;
if (end == std::string::npos) {
break;
}
}
}
ICHECK_EQ(workspace_->num_registered_kernels, split_kernels.size())
<< "The number of registered kernels does not match number of parsed kernel sources";
return split_kernels;
}

Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
Expand Down
18 changes: 11 additions & 7 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,23 +283,27 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N
runtime::Module BuildOpenCL(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);

std::stringstream code;
const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
code << "// Function: " << kv.first->name_hint << std::endl;
CodeGenOpenCL cg;
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
std::string fsource = cg.Finish();
if (fpostproc) {
fsource = (*fpostproc)(fsource).operator std::string();
}
code << fsource;
}

std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string();
}
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code);
return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str());
}

TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL);
Expand Down

0 comments on commit fffd391

Please sign in to comment.