[Runtime] Extend Graph Runtime To Support Cuda Graph Launch (#7616)
* add graph runtime cuGraph poc
* lint format
* add unittest
* fix review comments
* Update CMakeLists.txt (Co-authored-by: Cody Yu <comaniac0422@gmail.com>)
* build cuda graph runtime in gpu test
* Revert "build cuda graph runtime in gpu test". This reverts commit f286711.
* rename cuGraph to CUDA Graph
* rename cuda_graph
* rename cuda_graph
* lint format
* Update src/runtime/graph/graph_runtime_factory.cc (Co-authored-by: Cody Yu <comaniac0422@gmail.com>)
* Update python/tvm/testing.py (Co-authored-by: Cody Yu <comaniac0422@gmail.com>)
* fix lint error
* remove unnecessary warn
* add test, fix lint
* fix lint W0223

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
1 parent c55608f, commit 60ff0c7
Showing 12 changed files with 502 additions and 0 deletions.
16 changes: 16 additions & 0 deletions
python/tvm/contrib/cuda_graph/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
134 changes: 134 additions & 0 deletions
python/tvm/contrib/cuda_graph/cuda_graph_runtime.py
@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Graph runtime with CUDA Graph"""
import tvm._ffi

from tvm._ffi.base import string_types
from tvm.contrib import graph_runtime


def create(graph_json_str, libmod, ctx):
    """Create a runtime executor module given a graph and module.

    Parameters
    ----------
    graph_json_str : str
        The graph to be deployed in json format output by json graph.
        The graph can contain operator(tvm_op) that points to the name
        of PackedFunc in the libmod.

    libmod : tvm.runtime.Module
        The module of the corresponding function.

    ctx : TVMContext
        The context to deploy the module on; only a CUDA GPU is supported.

    Returns
    -------
    graph_module : GraphModuleCudaGraph
        CUDA graph runtime module that can be used to execute the graph.

    Note
    ----
    See also :py:class:`tvm.contrib.cuda_graph.cuda_graph_runtime.GraphModuleCudaGraph`
    for examples to directly construct a GraphModuleCudaGraph from an exported
    relay compiled library.
    """
    assert isinstance(graph_json_str, string_types)
    try:
        ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
        if num_rpc_ctx == len(ctx):
            fcreate = ctx[0]._rpc_sess.get_function("tvm.graph_runtime_cuda_graph.create")
        else:
            fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cuda_graph.create")
    except ValueError:
        raise ValueError(
            "To enable CUDA graph support (experimental), please set "
            "'(USE_GRAPH_RUNTIME_CUGRAPH ON)' in config.cmake and rebuild TVM"
        )

    return GraphModuleCudaGraph(fcreate(graph_json_str, libmod, *device_type_id))


class GraphModuleCudaGraph(graph_runtime.GraphModule):
    """CUDA graph runtime module.

    This is a CUDA graph runtime wrapper over the TVM runtime.
    Runtime interfaces are wrapped with CUDA graph functionalities.

    Parameters
    ----------
    module : Module
        The internal tvm module that holds the actual graph functions.
    """

    def __init__(self, module):
        self._start_capture = module["start_capture"]
        self._end_capture = module["end_capture"]
        self._run_cuda_graph = module["run_cuda_graph"]
        self._cuda_graph_captured = False
        graph_runtime.GraphModule.__init__(self, module)

    def capture_cuda_graph(self):
        """Capture a CUDA graph for the tvm_op graph.

        This should be called before run_cuda_graph() to capture and
        instantiate a CUDA graph instance.
        """
        self._run()  # call cuModuleLoadData before cudaStream API
        self._start_capture()
        self._run()
        self._end_capture()
        self._cuda_graph_captured = True

    def run_cuda_graph(self):
        """Run the CUDA graph for the tvm_op graph.

        Launch the captured CUDA graph instance instead of the
        kernel-by-kernel launch loop of the default graph runtime.
        """
        self._run_cuda_graph()

    def run(self, **input_dict):
        """A run wrapper for graph capture and launch: users can simply
        swap the default graph runtime for the CUDA graph runtime, and
        the first call captures a CUDA graph that later calls replay.

        Parameters
        ----------
        input_dict : dict of str to NDArray
            The input values to be fed to the graph.
        """
        if input_dict:
            self.set_input(**input_dict)
        if not self._cuda_graph_captured:
            self.capture_cuda_graph()
        else:
            self._run_cuda_graph()

    def debug_get_output(self, node, out):
        """Run the graph up to node and copy its output into out.

        Parameters
        ----------
        node : int / str
            The node index or name.

        out : NDArray
            The output array container.
        """
        raise NotImplementedError("Please use debugger.debug_runtime as graph_runtime instead.")
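In practice the wrapper above is a drop-in replacement for the default graph runtime. The following is a minimal usage sketch, not part of the commit: graph_json, lib, and the input name "data" are hypothetical stand-ins for the artifacts of a prior relay build targeting CUDA; only create() and the GraphModule interface come from this module.

import numpy as np
import tvm
from tvm.contrib.cuda_graph import cuda_graph_runtime

# Hypothetical build artifacts: graph_json (str) and lib (tvm.runtime.Module)
# are assumed to come from a relay build with target "cuda".
ctx = tvm.gpu(0)  # only a CUDA GPU context is supported
gmod = cuda_graph_runtime.create(graph_json, lib, ctx)

data = tvm.nd.array(np.random.uniform(size=(1, 3, 224, 224)).astype("float32"))
gmod.run(data=data)       # first call captures and instantiates the CUDA graph
out = gmod.get_output(0)  # later gmod.run(...) calls replay the captured graph

Note that capture_cuda_graph() runs the graph once before starting capture, so that lazy CUDA module loading (cuModuleLoadData) happens outside of capture mode.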
135 changes: 135 additions & 0 deletions
src/runtime/graph/cuda_graph/graph_runtime_cuda_graph.cc
@@ -0,0 +1,135 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file graph_runtime_cuda_graph.cc
 */

#include <tvm/runtime/registry.h>

#include "../../cuda/cuda_common.h"
#include "../graph_runtime.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Graph runtime with CUDA Graph support.
 *
 * This is an extension of the GraphRuntime class that uses a single CUDA
 * graph launch instead of one CUDA kernel launch per operator. CUDA graph
 * launch requires CUDA 10.0 or above. There are currently two ways to
 * construct a CUDA graph: (1) capturing a series of operations issued on a
 * CUDA stream with the stream capture API, which generates the graph
 * automatically, or (2) building the graph manually with the CUDA graph
 * API. This implementation uses stream capture.
 */
class GraphRuntimeCudaGraph : public GraphRuntime {
 public:
  /*!
   * \brief Begin CUDA graph capture on stream; the stream enters capture mode.
   */
  void StartCapture() {
    const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx;

    TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_);
    TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_);

    CUDA_CALL(cudaStreamBeginCapture(static_cast<cudaStream_t>(capture_stream_),
                                     cudaStreamCaptureModeGlobal));
  }

  /*!
   * \brief Launch the instantiated graph on the capture stream.
   */
  void RunCudaGraph() {
    cudaStream_t cuStream = static_cast<cudaStream_t>(capture_stream_);
    CUDA_CALL(cudaGraphLaunch(cuda_graph_exec_, cuStream));
    CUDA_CALL(cudaStreamSynchronize(cuStream));
  }

  /*!
   * \brief End CUDA graph capture on the stream; a graph is created and
   * instantiated.
   */
  void EndCapture() {
    cudaGraph_t graph;
    CUDA_CALL(cudaStreamEndCapture(static_cast<cudaStream_t>(capture_stream_), &graph));

    // Passing nullptr for the node array queries the node count only.
    cudaGraphNode_t* nodes = nullptr;
    size_t numNodes = 0;
    CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes));
    LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes;

    CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec_, graph, nullptr, nullptr, 0));
  }

  /*!
   * \brief Get a member function by name.
   * \param name The name of the function to be invoked.
   * \param sptr_to_self The ObjectPtr that points to this module node.
   */
  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);

 private:
  /*! \brief The CUDA stream on which to capture a CUDA graph. */
  TVMStreamHandle capture_stream_;
  /*! \brief The captured CUDA graph will be instantiated to this. */
  cudaGraphExec_t cuda_graph_exec_;
};

PackedFunc GraphRuntimeCudaGraph::GetFunction(const std::string& name,
                                              const ObjectPtr<Object>& sptr_to_self) {
  if (name == "run_cuda_graph") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->RunCudaGraph(); });
  } else if (name == "start_capture") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->StartCapture(); });
  } else if (name == "end_capture") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->EndCapture(); });
  } else {
    return GraphRuntime::GetFunction(name, sptr_to_self);
  }
}

Module GraphRuntimeCudaGraphCreate(const std::string& sym_json, const tvm::runtime::Module& m,
                                   const std::vector<TVMContext>& ctxs,
                                   PackedFunc lookup_linked_param_func) {
  auto exec = make_object<GraphRuntimeCudaGraph>();
  exec->Init(sym_json, m, ctxs, lookup_linked_param_func);
  return Module(exec);
}

TVM_REGISTER_GLOBAL("tvm.graph_runtime_cuda_graph.create")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      ICHECK_GE(args.num_args, 4)
          << "The expected number of arguments for graph_runtime_cuda_graph.create is "
             "at least 4, but it has "
          << args.num_args;
      PackedFunc lookup_linked_param_func;
      int ctx_start_arg = 2;
      // An optional lookup function for linked parameters may be passed before
      // the (device_type, device_id) pairs.
      if (args[2].type_code() == kTVMPackedFuncHandle) {
        lookup_linked_param_func = args[2];
        ctx_start_arg++;
      }

      *rv = GraphRuntimeCudaGraphCreate(args[0], args[1], GetAllContext(args, ctx_start_arg),
                                        lookup_linked_param_func);
    });
}  // namespace runtime
}  // namespace tvm
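The Python create() wrapper resolves this registered global through the FFI and calls fcreate(graph_json_str, libmod, *device_type_id). A minimal sketch of making the same call directly, with the same hypothetical graph_json and lib as above; the trailing integers are one (device_type, device_id) pair, where device type 2 is kDLGPU:

import tvm

fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cuda_graph.create")
module = fcreate(graph_json, lib, 2, 0)  # (2, 0) = (kDLGPU, device 0); hypothetical inputs
run_cuda_graph = module["run_cuda_graph"]  # PackedFunc dispatched by GetFunction above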