[Runtime] Extend Graph Runtime To Support Cuda Graph Launch (#7616)
* add graph runtime cuGraph poc

* lint format

* add unittest

* fix review comments

* Update CMakeLists.txt

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

* build cuda graph runtime in gpu test

* Revert "build cuda graph runtime in gpu test"

This reverts commit f286711.

* rename cuGraph to CUDA Graph

* rename cuda_graph

* rename cuda_graph

* lint format

* Update src/runtime/graph/graph_runtime_factory.cc

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

* Update python/tvm/testing.py

Co-authored-by: Cody Yu <comaniac0422@gmail.com>

* fix lint error

* remove unnecessary warn

* add test, fix lint

* fix lint W0223

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
zhuochenKIDD and comaniac authored Mar 17, 2021
1 parent c55608f commit 60ff0c7
Showing 12 changed files with 502 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -35,6 +35,7 @@ tvm_option(USE_THREADS "Build with thread support" ON)
tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" OFF)
tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
tvm_option(USE_GRAPH_RUNTIME_CUDA_GRAPH "Build with tiny graph runtime with CUDA Graph for GPUs" OFF)
tvm_option(USE_PROFILER "Build profiler for the VM and graph runtime" ON)
tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF)
tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF)
3 changes: 3 additions & 0 deletions cmake/config.cmake
@@ -99,6 +99,9 @@ set(USE_STACKVM_RUNTIME OFF)
# Whether enable tiny embedded graph runtime.
set(USE_GRAPH_RUNTIME ON)

# Whether enable tiny graph runtime with CUDA Graph
set(USE_GRAPH_RUNTIME_CUDA_GRAPH OFF)

# Whether to enable the profiler for the graph runtime and vm
set(USE_PROFILER ON)

11 changes: 11 additions & 0 deletions cmake/modules/CUDA.cmake
@@ -65,6 +65,17 @@ if(USE_CUDA)
list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
endif(USE_THRUST)

if(USE_GRAPH_RUNTIME_CUDA_GRAPH)
if(NOT USE_GRAPH_RUNTIME)
message(FATAL_ERROR "CUDA Graph is only supported by graph runtime, please set USE_GRAPH_RUNTIME=ON")
endif()
if(CUDAToolkit_VERSION_MAJOR LESS "10")
message(FATAL_ERROR "CUDA Graph requires CUDA 10 or above, got=" ${CUDAToolkit_VERSION})
endif()
message(STATUS "Build with Graph runtime with CUDA Graph support...")
file(GLOB RUNTIME_CUDA_GRAPH_SRCS src/runtime/graph/cuda_graph/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_GRAPH_SRCS})
endif()
else(USE_CUDA)
list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
endif(USE_CUDA)
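
A quick way to verify that a rebuild picked up the flag (a sketch, not part of this commit): the CUDA Graph runtime registers a global packed function at load time, so its presence can be probed from Python.

import tvm

# Resolves to None when TVM was built without USE_GRAPH_RUNTIME_CUDA_GRAPH.
fcreate = tvm.get_global_func("tvm.graph_runtime_cuda_graph.create", allow_missing=True)
print("CUDA Graph runtime available:", fcreate is not None)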
16 changes: 16 additions & 0 deletions python/tvm/contrib/cuda_graph/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
134 changes: 134 additions & 0 deletions python/tvm/contrib/cuda_graph/cuda_graph_runtime.py
@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Graph runtime with CUDA Graph"""
import tvm._ffi

from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime


def create(graph_json_str, libmod, ctx):
"""Create a runtime executor module given a graph and module.
Parameters
----------
graph_json_str : str
The graph to be deployed in json format output by json graph.
The graph can contain operator(tvm_op) that points to the name
of PackedFunc in the libmod.
libmod : tvm.runtime.Module
The module of the corresponding function
ctx : TVMContext
The context to deploy the module; only CUDA GPUs are supported
Returns
-------
graph_module : GraphModuleCudaGraph
CUDA graph runtime module that can be used to execute the graph.
Note
----
See also :py:class:`tvm.contrib.cuda_graph.cuda_graph_runtime.GraphModuleCudaGraph`
for examples to directly construct a GraphModuleCudaGraph from an exported
relay compiled library.
"""
assert isinstance(graph_json_str, string_types)
try:
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_cuda_graph.create")
else:
fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cuda_graph.create")
except ValueError:
raise ValueError(
"To enable CUDA graph support (experimental), please set "
"'(USE_GRAPH_RUNTIME_CUGRAPH ON)' in config.cmake and rebuild TVM"
)

return GraphModuleCudaGraph(fcreate(graph_json_str, libmod, *device_type_id))


class GraphModuleCudaGraph(graph_runtime.GraphModule):
"""CUDA graph runtime module.
This is a CUDA graph runtime wrapper over the TVM runtime.
Runtime interfaces are wrapped with CUDA graph functionalities.
Parameters
----------
module : Module
The internal tvm module that holds the actual graph functions.
"""

def __init__(self, module):
self._start_capture = module["start_capture"]
self._end_capture = module["end_capture"]
self._run_cuda_graph = module["run_cuda_graph"]
self._cuda_graph_captured = False
graph_runtime.GraphModule.__init__(self, module)

def capture_cuda_graph(self):
"""Capture a CUDA graph for tvm_op graph
This should be called before run_cuda_graph() to capture and
instantiate a CUDA graph instance.
"""
self._run() # call cuModuleLoadData before cudaStream API
self._start_capture()
self._run()
self._end_capture()
self._cuda_graph_captured = True

def run_cuda_graph(self):
"""Run the CUDA graph for tvm_op graph
Run the captured CUDA graph instance instead of the
for-loop kernel launch of default graph runtime
"""
self._run_cuda_graph()

def run(self, **input_dict):
"""A run wrapper for graph capture / launch, user can just
change default graph runtime to cuda graph runtime, and
the first call will capture a cuda graph for future launch
Parameters
----------
input_dict: dict of str to NDArray
List of input values to be feed to
"""
if input_dict:
self.set_input(**input_dict)
if not self._cuda_graph_captured:
self.capture_cuda_graph()
else:
self._run_cuda_graph()

def debug_get_output(self, node, out):
"""Run graph up to node and get the output to out
Parameters
----------
node : int / str
The node index or name
out : NDArray
The output array container
"""
raise NotImplementedError("Please use debugger.debug_runtime as graph_runtime instead.")
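
Taken together, create() and the run() wrapper behave as a drop-in replacement for the default graph runtime. A minimal usage sketch, assuming graph_json and lib come from an earlier tvm.build or relay.build step and that TVM was built with USE_GRAPH_RUNTIME_CUDA_GRAPH=ON; the input name "x" and shape are illustrative:

import numpy as np
import tvm
from tvm.contrib.cuda_graph import cuda_graph_runtime

ctx = tvm.gpu(0)
gmod = cuda_graph_runtime.create(graph_json, lib, ctx)  # graph_json / lib built elsewhere
gmod.set_input("x", tvm.nd.array(np.random.uniform(size=(1024,)).astype("float32")))
gmod.run()                # first call captures and instantiates the CUDA graph
gmod.run()                # later calls replay the captured graph
out = gmod.get_output(0)  # inherited from graph_runtime.GraphModule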
12 changes: 12 additions & 0 deletions python/tvm/contrib/nvcc.py
@@ -349,6 +349,18 @@ def have_tensorcore(compute_version=None, target=None):
return False


def have_cudagraph():
"""Either CUDA Graph support is provided"""
try:
cuda_path = find_cuda_path()
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver < 10.0:
return False
return True
except RuntimeError:
return False


def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not
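
A small sketch of using nvcc.have_cudagraph() to guard CUDA-graph-specific code; the function only inspects the locally installed CUDA toolkit version, so it is cheap to call:

from tvm.contrib import nvcc

if nvcc.have_cudagraph():
    print("CUDA 10.0+ toolkit found; CUDA Graph capture is usable")
else:
    print("CUDA toolkit missing or older than 10.0; use the default graph runtime")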
19 changes: 19 additions & 0 deletions python/tvm/testing.py
@@ -514,6 +514,25 @@ def requires_cuda(*args):
return _compose(args, _requires_cuda)


def requires_cudagraph(*args):
"""Mark a test as requiring the CUDA Graph Feature
This also marks the test as requiring cuda
Parameters
----------
f : function
Function to mark
"""
_requires_cudagraph = [
pytest.mark.skipif(
not nvcc.have_cudagraph(), reason="CUDA Graph is not supported in this environment"
),
*requires_cuda(),
]
return _compose(args, _requires_cudagraph)


def requires_opencl(*args):
"""Mark a test as requiring the OpenCL runtime.
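
The new requires_cudagraph decorator is applied like the existing requires_cuda mark; a hypothetical test sketch (name and body are illustrative, not part of this commit):

import tvm.testing

@tvm.testing.requires_cudagraph
def test_runs_with_cuda_graph():
    # Skipped automatically unless a CUDA device and a CUDA 10.0+ toolkit exist.
    ...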
135 changes: 135 additions & 0 deletions src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file graph_runtime_cuda_graph.cc
*/

#include <tvm/runtime/registry.h>

#include "../../cuda/cuda_common.h"
#include "../graph_runtime.h"

namespace tvm {
namespace runtime {

/*!
* \brief Graph runtime with CUDA Graph Support.
*
* This is an extension of the GraphRuntime class that launches a captured
* CUDA graph instead of launching kernels one by one. CUDA graph launch
* requires CUDA 10.0 or above. There are currently two ways of
* constructing a CUDA graph: (1) using the CUDA stream capture API to
* record a series of operations on a CUDA stream and generate the graph
* automatically, or (2) building the graph manually with the CUDA graph
* API. This implementation uses stream capture.
*/
class GraphRuntimeCudaGraph : public GraphRuntime {
public:
/*!
* \brief Begin CUDA graph capture on stream, the stream enters capture mode.
*/
void StartCapture() {
const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx;

TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_);
TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_);

CUDA_CALL(cudaStreamBeginCapture(static_cast<cudaStream_t>(capture_stream_),
cudaStreamCaptureModeGlobal));
}

/*!
* \brief Launch the instantiated graph on stream
*/
void RunCudaGraph() {
cudaStream_t cuStream = static_cast<cudaStream_t>(capture_stream_);
CUDA_CALL(cudaGraphLaunch(cuda_graph_exec_, cuStream));
CUDA_CALL(cudaStreamSynchronize(cuStream));
}

/*!
* \brief End CUDA graph capture on stream, a graph will be created and
* instantiated.
*/
void EndCapture() {
cudaGraph_t graph;
CUDA_CALL(cudaStreamEndCapture(static_cast<cudaStream_t>(capture_stream_), &graph));

cudaGraphNode_t* nodes = NULL;
size_t numNodes = 0;
CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes));
LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes;

CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec_, graph, NULL, NULL, 0));
}

/*!
* \brief GetFunction Get the function based on input.
* \param name The function which needs to be invoked.
* \param sptr_to_self Packed function pointer.
*/
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

private:
/*! \brief The Cuda stream on which to capture a CUDA graph. */
TVMStreamHandle capture_stream_;
/*! \brief The captured CUDA graph will be instantiated to this. */
cudaGraphExec_t cuda_graph_exec_;
};

PackedFunc GraphRuntimeCudaGraph::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
if (name == "run_cuda_graph") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->RunCudaGraph(); });
} else if (name == "start_capture") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->StartCapture(); });
} else if (name == "end_capture") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->EndCapture(); });
} else {
return GraphRuntime::GetFunction(name, sptr_to_self);
}
}

Module GraphRuntimeCudaGraphCreate(const std::string& sym_json, const tvm::runtime::Module& m,
const std::vector<TVMContext>& ctxs,
PackedFunc lookup_linked_param_func) {
auto exec = make_object<GraphRuntimeCudaGraph>();
exec->Init(sym_json, m, ctxs, lookup_linked_param_func);
return Module(exec);
}

TVM_REGISTER_GLOBAL("tvm.graph_runtime_cuda_graph.create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime_cuda_graph.create is "
"at least 4, but it has "
<< args.num_args;
PackedFunc lookup_linked_param_func;
int ctx_start_arg = 2;
if (args[2].type_code() == kTVMPackedFuncHandle) {
lookup_linked_param_func = args[2];
ctx_start_arg++;
}

*rv = GraphRuntimeCudaGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg),
lookup_linked_param_func);
});
} // namespace runtime
} // namespace tvm
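
The three packed functions registered in GetFunction surface in Python through GraphModuleCudaGraph, so capture and replay can also be driven explicitly instead of through the run() wrapper. A sketch, assuming gmod was created by cuda_graph_runtime.create and inputs (an assumed dict of str to NDArray) are already prepared:

gmod.set_input(**inputs)   # inputs: hypothetical dict of str -> tvm.nd.NDArray
gmod.capture_cuda_graph()  # warm-up run, then start_capture / run / end_capture
for _ in range(100):
    gmod.run_cuda_graph()  # replays the instantiated cudaGraphExec_t
out = gmod.get_output(0)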