[SYNC] Sync with CentML/hidet -> hidet-org/hidet #486

Closed
wants to merge 9 commits
2 changes: 1 addition & 1 deletion .github/workflows/publish-centml-pypi.yaml
@@ -82,7 +82,7 @@ jobs:
matrix: ${{ fromJSON(needs.list-test-dirs.outputs.matrix) }}
runs-on: arc-runner-set
container:
image: nvidia/cuda:12.4.0-devel-ubuntu20.04
image: nvidia/cuda:12.6.2-devel-ubuntu22.04
steps:
- name: Install dependencies via apt
run: |
30 changes: 29 additions & 1 deletion CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.19)

project(hidet C CXX)

set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# config hidet
@@ -34,6 +34,7 @@ add_library(hidet_runtime SHARED
src/hidet/runtime/llm/tokenizer/pretokenizers.cpp
src/hidet/runtime/llm/tokenizer/tokenizer.cpp
src/hidet/runtime/llm/tokenizer/utf8.cpp
src/hidet/runtime/torch/stream.cpp
)
target_include_directories(hidet_runtime PRIVATE ${CMAKE_SOURCE_DIR}/include /usr/include)
set_target_properties(hidet_runtime PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -44,4 +45,31 @@ add_library(hidet SHARED
)
target_include_directories(hidet PRIVATE ${CMAKE_SOURCE_DIR}/include)
set_target_properties(hidet PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)

execute_process(
COMMAND python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))"
OUTPUT_VARIABLE TORCH_LIBRARY
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(TORCH_INCLUDE_DIR "${TORCH_LIBRARY}/include")
execute_process(
COMMAND python3 -c "import nvidia; import os; print(os.path.dirname(nvidia.__file__))"
OUTPUT_VARIABLE NVIDIA_LIBRARY
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(CUDARUNTIME_INCLUDE_DIR "${NVIDIA_LIBRARY}/cuda_runtime/include")
execute_process(
COMMAND python3 -c "import triton; import os; print(os.path.dirname(triton.__file__))"
OUTPUT_VARIABLE TRITON_LIBRARY
OUTPUT_STRIP_TRAILING_WHITESPACE
)
set(TRITON_INCLUDE_DIR "${TRITON_LIBRARY}/backends/nvidia/include/")

add_library(hidet_torch_wrapper SHARED
src/hidet/runtime/hidet_torch/cuda_stream.cpp
)
target_include_directories(hidet_torch_wrapper PRIVATE ${CMAKE_SOURCE_DIR}/include ${TORCH_INCLUDE_DIR} ${CUDARUNTIME_INCLUDE_DIR} ${TRITON_INCLUDE_DIR})
target_link_libraries(hidet_torch_wrapper ${TORCH_LIBRARY}/lib/libc10_cuda.so)
set_target_properties(hidet_torch_wrapper PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)

target_link_libraries(hidet "-Wl,--no-as-needed" hidet_runtime)
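Reviewer note: the three execute_process calls above resolve include directories from whichever Python environment runs the build. A minimal sketch of the same lookups, runnable as a sanity check (assumes torch, nvidia, and triton are importable in that environment):

    import os
    import torch
    import nvidia
    import triton

    torch_root = os.path.dirname(torch.__file__)
    nvidia_root = os.path.dirname(nvidia.__file__)
    triton_root = os.path.dirname(triton.__file__)

    print(os.path.join(torch_root, 'include'))                         # TORCH_INCLUDE_DIR
    print(os.path.join(nvidia_root, 'cuda_runtime', 'include'))        # CUDARUNTIME_INCLUDE_DIR
    print(os.path.join(triton_root, 'backends', 'nvidia', 'include'))  # TRITON_INCLUDE_DIR
    print(os.path.join(torch_root, 'lib', 'libc10_cuda.so'))           # library linked by hidet_torch_wrapper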
2 changes: 1 addition & 1 deletion apps/compile_server/Dockerfile
@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.2.0-devel-ubuntu22.04
FROM nvidia/cuda:12.6.2-devel-ubuntu22.04

COPY ./run.py /app/run.py
WORKDIR /app
2 changes: 1 addition & 1 deletion apps/compile_server/requirements.txt
@@ -52,4 +52,4 @@ lark
scipy

# for torch runtime api dependency
torch>=2.3.0
torch>=2.3.0
4 changes: 4 additions & 0 deletions docs/source/getting-started/build-from-source.rst
@@ -16,6 +16,10 @@ First clone the repository to local:

Build shared libraries
~~~~~~~~~~~~~~~~~~~~~~
Before building the runtime library, make sure you have ``torch`` installed in your python environment:

.. code-block:: console

   $ pip install torch

The runtime library is written in C++ and compiled into a shared library. To build the shared library, you need to have
a C++ compiler installed (as well as build tools like ``cmake`` and ``make``). The following command will build the
13 changes: 13 additions & 0 deletions include/hidet/runtime/cuda/context.h
@@ -19,6 +19,9 @@ struct CudaContext: BaseContext {
/* The cuda stream the kernels will be launched on. */
void *stream = nullptr;

/* Whether to use the torch stream. */
bool use_torch_stream = true;

/* NCCL communicators. */
void **nccl_comms = nullptr;

@@ -35,6 +38,16 @@ struct CudaContext: BaseContext {
*/
DLL void set_cuda_stream(void *stream);

/**
* Get the flag indicating whether to use the torch stream.
*/
DLL bool get_use_torch_cuda_stream();

/**
* Set the flag indicating whether to use the torch stream.
*/
DLL void use_torch_cuda_stream(bool use);

/**
* Get the cuda stream of cuda context.
*/
15 changes: 15 additions & 0 deletions include/hidet/runtime/torch/stream.h
@@ -0,0 +1,15 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <hidet/runtime/common.h>

DLL void *hidet_get_current_torch_stream();
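Reviewer note: assuming the DLL macro gives this entry point C linkage like the other runtime symbols, it can be reached directly via ctypes once libc10_cuda.so is preloaded (load_library() in python/hidet/ffi/ffi.py performs this preload). A hedged sketch with a placeholder path for the wrapper library:

    import ctypes
    import os
    import torch

    # libc10_cuda.so must be loaded first, mirroring load_library() in ffi.py
    ctypes.CDLL(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libc10_cuda.so'))

    # placeholder path; hidet locates the real library via its search dirs
    wrapper = ctypes.CDLL('build/lib/libhidet_torch_wrapper.so')
    wrapper.hidet_get_current_torch_stream.restype = ctypes.c_void_p
    # raw handle of torch's current CUDA stream (requires a CUDA-enabled torch)
    print(wrapper.hidet_get_current_torch_stream())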
11 changes: 9 additions & 2 deletions python/hidet/cuda/graph.py
@@ -13,6 +13,7 @@
from typing import List, Sequence, Optional, Any, Callable
from cuda import cudart
from cuda.cudart import cudaGraphExec_t
from hidet.option import use_torch_stream, is_use_torch_stream
from hidet.graph.tensor import Tensor
from hidet.runtime.storage import MemoryPool, CudaMemoryAPI, memory_pool
from hidet.runtime.device import Device
@@ -116,7 +117,6 @@ def __init__(
self._inputs: List[Tensor] = []
self._outputs: List[Tensor] = []
self._ref_objs: List[Any] = ref_objs

with memory_pool(self._memory_pool):
# create the input tensors
self._inputs = f_create_inputs()
@@ -126,11 +126,18 @@
for _ in range(num_warmup):
f_run(self._inputs)

# There are two scenarios:
# 1. If torch or hidet is using the default stream, we switch to a new stream created by hidet.
# 2. If torch is using its own non-default stream, we still switch to a hidet-created stream
#    to avoid interfering with the torch stream.
# In both cases we capture on the hidet stream and restore the previous flag afterwards.
prev_flag = is_use_torch_stream()
use_torch_stream(False)
# capture the cuda graph
self._memory_api.freeze()
with self._graph_capture:
self._outputs = f_run(self._inputs)

use_torch_stream(prev_flag)
# instantiate the cuda graph
self._graph_exec: cudaGraphExec_t = self._graph_capture.instantiate()

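Reviewer note: the save/toggle/restore around capture is a reusable pattern; a sketch of the same logic as a context manager (the helper name is hypothetical and not part of this PR; the try/finally additionally restores the flag if capture raises):

    from contextlib import contextmanager
    from hidet.option import is_use_torch_stream, use_torch_stream

    @contextmanager
    def hidet_stream_scope():  # hypothetical helper, not in this PR
        prev_flag = is_use_torch_stream()
        use_torch_stream(False)  # graph capture must run on a hidet-created stream
        try:
            yield
        finally:
            use_torch_stream(prev_flag)  # restore even if capture fails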
9 changes: 8 additions & 1 deletion python/hidet/cuda/stream.py
@@ -210,8 +210,15 @@ def current_stream(device=None) -> Stream:
stream: Stream
The current stream.
"""
from hidet.ffi import runtime_api

device_id = _get_device_id(device)
if device_id not in _current_streams:
c_stream = runtime_api.get_current_stream()
if c_stream is not None:
# return the current stream, whether it is a hidet or a torch stream
_current_streams[device_id] = ExternalStream(handle=c_stream, device_id=device_id)
else:
# if no current stream is set, we use the default stream
_current_streams[device_id] = ExternalStream(handle=0, device_id=device_id)
return _current_streams[_get_device_id(device)]

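Reviewer note: the fallback now consults the runtime before assuming the default stream; condensed, the decision is:

    from hidet.ffi import runtime_api

    c_stream = runtime_api.get_current_stream()       # None when no stream has been set
    handle = c_stream if c_stream is not None else 0  # handle 0 is the CUDA default stream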
21 changes: 18 additions & 3 deletions python/hidet/ffi/ffi.py
@@ -9,33 +9,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Optional
import os
import os.path
import ctypes
import torch
from hidet.libinfo import get_library_search_dirs

_LIB: Optional[ctypes.CDLL] = None
_LIB_RUNTIME: Optional[ctypes.CDLL] = None
_LIB_HIDET_TORCH_WRAPPER: Optional[ctypes.CDLL] = None


library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None}
library_paths: Dict[str, Optional[str]] = {'hidet': None, 'hidet_runtime': None, 'hidet_torch_wrapper': None}


def load_library():
global _LIB, _LIB_RUNTIME
global _LIB, _LIB_RUNTIME, _LIB_HIDET_TORCH_WRAPPER
if _LIB:
return
libc10_path = os.path.join(os.path.dirname(torch.__file__), 'lib/libc10_cuda.so')
ctypes.cdll.LoadLibrary(libc10_path)

library_dirs = get_library_search_dirs()
for library_dir in library_dirs:
libhidet_path = os.path.join(library_dir, 'libhidet.so')
libhidet_runtime_path = os.path.join(library_dir, 'libhidet_runtime.so')
if not os.path.exists(libhidet_path) or not os.path.exists(libhidet_runtime_path):
libhidet_torch_wrapper_path = os.path.join(library_dir, 'libhidet_torch_wrapper.so')
if (
not os.path.exists(libhidet_path)
or not os.path.exists(libhidet_runtime_path)
or not os.path.exists(libhidet_torch_wrapper_path)
):
continue
_LIB_RUNTIME = ctypes.cdll.LoadLibrary(libhidet_runtime_path)
_LIB_HIDET_TORCH_WRAPPER = ctypes.cdll.LoadLibrary(libhidet_torch_wrapper_path)
_LIB = ctypes.cdll.LoadLibrary(libhidet_path)
library_paths['hidet_runtime'] = libhidet_runtime_path
library_paths['hidet'] = libhidet_path
library_paths['hidet_torch_wrapper'] = libhidet_torch_wrapper_path
break
if _LIB is None:
raise OSError('Can not find library in the following directory: \n' + '\n'.join(library_dirs))
@@ -71,6 +84,8 @@ def get_func(func_name, arg_types: List, restype, lib=None):
func = getattr(_LIB, func_name)
elif func_exists(func_name, _LIB_RUNTIME):
func = getattr(_LIB_RUNTIME, func_name)
elif func_exists(func_name, _LIB_HIDET_TORCH_WRAPPER):
func = getattr(_LIB_HIDET_TORCH_WRAPPER, func_name)
elif func_exists(func_name, lib):
func = getattr(lib, func_name)
else:
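Reviewer note: with the third library, get_func now probes the loaded libraries in a fixed order; a condensed, self-contained restatement of that lookup:

    import ctypes
    from typing import Optional

    def func_exists(func_name: str, lib: Optional[ctypes.CDLL]) -> bool:
        return lib is not None and hasattr(lib, func_name)

    def get_func_sketch(func_name: str, libs):
        # search order in ffi.py: _LIB, _LIB_RUNTIME, _LIB_HIDET_TORCH_WRAPPER, caller-supplied lib
        for lib in libs:
            if func_exists(func_name, lib):
                return getattr(lib, func_name)
        raise ValueError(f'Function {func_name} does not exist in the loaded libraries')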
16 changes: 13 additions & 3 deletions python/hidet/ffi/runtime_api.py
@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from ctypes import c_void_p, c_char_p, c_uint64, c_int32
from ctypes import c_void_p, c_char_p, c_uint64, c_int32, c_bool
from hidet.cuda import Stream
from .ffi import get_func
from .array import Array
@@ -26,15 +26,17 @@ class RuntimeAPI:
_get_symbol_value = get_func('get_symbol_value', [c_char_p], c_int32)
_set_symbol_value = get_func('set_symbol_value', [c_char_p, c_int32], None)
_set_nccl_comms = get_func('set_nccl_comms', [c_int32, c_void_p], None)
_get_use_torch_stream = get_func('get_use_torch_cuda_stream', [], c_bool)
_use_torch_cuda_stream = get_func('use_torch_cuda_stream', [c_bool], None)

@staticmethod
def set_current_stream(stream: Union[Stream, int]) -> None:
RuntimeAPI._set_current_stream(c_void_p(int(stream)))

@staticmethod
def get_current_stream() -> int:
def get_current_stream() -> Union[int, None]:
p = RuntimeAPI._get_current_stream()
return p.value
return p if p else None

@staticmethod
def register_callback(name: str, cfunc):
@@ -68,5 +70,13 @@ def set_nccl_comms(comms: Array) -> None:
comms_array_t = c_void_p * comms.length
RuntimeAPI._set_nccl_comms(comms.length, comms_array_t.from_buffer(comms.buffer))

@staticmethod
def get_use_torch_cuda_stream() -> bool:
return RuntimeAPI._get_use_torch_stream()

@staticmethod
def use_torch_cuda_stream(use: bool) -> None:
RuntimeAPI._use_torch_cuda_stream(use)


runtime_api = RuntimeAPI()
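Reviewer note: the new flag pair mirrors set_current_stream/get_current_stream; a minimal usage sketch (the flag defaults to true, per the use_torch_stream field in context.h):

    from hidet.ffi import runtime_api

    prev = runtime_api.get_use_torch_cuda_stream()  # True by default (see CudaContext)
    runtime_api.use_torch_cuda_stream(False)        # launch kernels on hidet-managed streams
    runtime_api.use_torch_cuda_stream(prev)         # restore the previous setting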
2 changes: 0 additions & 2 deletions python/hidet/graph/frontend/torch/dynamo_backends.py
@@ -195,8 +195,6 @@ def __call__(self, *args):
else:
# ignore constant
pass
# Inherited cuda stream from torch
runtime_api.set_current_stream(torch.cuda.current_stream().cuda_stream)
# Prepare inputs
tensor_args = preprocess_inputs(tensor_args)
# Run graph/model
10 changes: 9 additions & 1 deletion python/hidet/graph/ops/matmul/resolve.py
@@ -21,7 +21,8 @@

from .matmul import MatmulOp
from .batch_matmul import batch_matmul
from .matmul_f16_cute import matmul_f16_cute
from .matmul_f16_cute_experimental import matmul_f16_cute as matmul_f16_cute_experimental
from .matmul_f16_cute import matmul_f16_cute as matmul_f16_cute_stable
from ..transform import broadcast, flatten
from ..utils import broadcast_shapes

@@ -215,6 +216,13 @@ def resolve_f16(self, op: Operator) -> Optional[List[Tensor]]:
return None

parallel_k = hidet.option.get_parallel_k()
hexcute_matmul = hidet.option.get_hexcute_matmul()
if hexcute_matmul == 'enable':
matmul_f16_cute = matmul_f16_cute_experimental
elif hexcute_matmul == 'disable':
matmul_f16_cute = matmul_f16_cute_stable
else:
raise NotImplementedError('The heuristic for hexcute_matmul is not implemented.')

if op.task.has_symbolic_shape():
k_parts = 1
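Reviewer note: the resolver only accepts 'enable' and 'disable'; any heuristic value currently raises NotImplementedError. A hedged configuration sketch (the setter name is assumed from the get_hexcute_matmul getter and is not shown in this diff):

    import hidet

    # hypothetical setter paired with hidet.option.get_hexcute_matmul()
    hidet.option.hexcute_matmul('enable')    # select the experimental CuTe matmul kernel
    # hidet.option.hexcute_matmul('disable')  # fall back to the stable matmul_f16_cute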