[Refactor] Refactor building system and add compiled products #261

Merged 2 commits on Jun 2, 2023
2 changes: 1 addition & 1 deletion docs/source/clear-hidet-cache.py
@@ -1,3 +1,3 @@
import conf

-conf.hidet.utils.hidet_clear_op_cache()
+conf.hidet.utils.clear_op_cache()
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -31,7 +31,7 @@

import hidet
hidet.option.cache_dir(os.path.join(hidet.option.get_cache_dir(), 'docs-cache'))
-hidet.utils.hidet_clear_op_cache()
+hidet.utils.clear_op_cache()
print('Build docs with under cache: {}'.format(hidet.option.get_cache_dir()))

# -- Project information -----------------------------------------------------
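For code outside the docs build that still calls the old helper, a minimal sketch of the renamed cache-clearing API, assuming only the import paths shown in this diff:

# Example (not part of the diff): clear hidet's operator cache with the renamed helper.
import hidet

hidet.utils.clear_op_cache()   # previously: hidet.utils.hidet_clear_op_cache()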
6 changes: 3 additions & 3 deletions docs/source/python_api/driver.rst
@@ -1,6 +1,6 @@
-hidet.driver
-------------
+hidet.drivers
+-------------

-.. automodule:: hidet.driver
+.. automodule:: hidet.drivers
:members:
:autosummary:
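Downstream imports need the plural package name. A minimal sketch; build_task is the only entry point this diff confirms under the new name:

# Example (not part of the diff): the driver package is now hidet.drivers.
import hidet

print(hidet.drivers.build_task)   # previously reached as hidet.driver.build_task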
2 changes: 1 addition & 1 deletion examples/gpt-2/encoder.py
@@ -113,7 +113,7 @@ def decode(self, tokens):

def get_encoder(model_name="124M"):
import hidet
-models_dir = hidet.utils.hidet_cache_dir("./examples/gpt-2")
+models_dir = hidet.utils.cache_dir("./examples/gpt-2")
with open(os.path.join(models_dir, model_name, "encoder.json"), "r") as f:
encoder = json.load(f)
with open(os.path.join(models_dir, model_name, "vocab.bpe"), "r", encoding="utf-8") as f:
2 changes: 1 addition & 1 deletion examples/gpt-2/gpt_model.py
@@ -85,7 +85,7 @@ def gpt2_forward(ids, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_v


def gpt2(model_size: str = "124M", seq_length: Optional[int] = 1000, use_fp16=False) -> FlowGraph:
-cache_dir = hidet.utils.hidet_cache_dir('./examples/gpt-2/')
+cache_dir = hidet.utils.cache_dir('./examples/gpt-2/')
model_name = 'model_{}_seq{}_{}.hf'.format(model_size, seq_length, 'fp16' if use_fp16 else 'fp32')
hf_path = os.path.join(cache_dir, model_name)
if os.path.exists(hf_path):
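The cache-directory helper loses its hidet_ prefix as well. A minimal sketch of the renamed call; the behaviour (resolving a sub-path under hidet's cache root) is inferred from the GPT-2 call sites above:

# Example (not part of the diff): resolve a cache sub-directory with the renamed helper.
import hidet

models_dir = hidet.utils.cache_dir('./examples/gpt-2')   # previously: hidet.utils.hidet_cache_dir
print(models_dir)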
6 changes: 1 addition & 5 deletions gallery/developer-guides/hidet-script-dynamic-kernel.py
@@ -132,11 +132,7 @@ def matmul_kernel(

assert isinstance(matmul_kernel, hidet.ir.Function) # matmul is a hidet.ir.Function

-ir_module = script_module.ir_module()
-compiled_function: hidet.runtime.CompiledFunction = hidet.driver.build_ir_module(
-ir_module
-)
-return compiled_function
+return script_module.build()


def main():
19 changes: 9 additions & 10 deletions gallery/how-to-guides/add-new-operator-compute-definition.py
@@ -270,16 +270,15 @@ def demo_task():
from hidet.ir.task import Task


-def run_task(task: Task, inputs: List[hidet.Tensor], outputs: List[hidet.Tensor]):
+def run_task(task: Task, inputs: List[hidet.Tensor]):
"""Run given task and print inputs and outputs"""
-from hidet.runtime import CompiledFunction
+from hidet.runtime import CompiledTask

# build the task
-func: CompiledFunction = hidet.driver.build_task(task, target_device='cpu')
-params = inputs + outputs
+func: CompiledTask = hidet.drivers.build_task(task, target='cpu')

# run the compiled task
-func(*params)
+outputs = func.run_async(inputs)

print('Task:', task.name)
print('Inputs:')
@@ -317,7 +316,7 @@ def add_example():
b: TensorNode = tensor_input(name='b', dtype='float32', shape=[5])
c: TensorNode = compute(name='c', shape=[5], fcompute=lambda i: a[i] + b[i])
task = Task(name='add', inputs=[a, b], outputs=[c])
-run_task(task, [hidet.randn([5]), hidet.randn([5])], [hidet.empty([5])])
+run_task(task, [hidet.randn([5]), hidet.randn([5])])


add_example()
@@ -350,7 +349,7 @@ def reduce_sum_example():
),
)
task = Task('reduce_sum', inputs=[a], outputs=[b])
-run_task(task, [hidet.randn([4, 3])], [hidet.empty([4])])
+run_task(task, [hidet.randn([4, 3])])


reduce_sum_example()
@@ -371,7 +370,7 @@ def arg_max_example():
),
)
task = Task('arg_max', inputs=[a], outputs=[b])
-run_task(task, [hidet.randn([4, 3])], [hidet.empty([4], dtype='int64')])
+run_task(task, [hidet.randn([4, 3])])


arg_max_example()
@@ -391,7 +390,7 @@ def matmul_example():
),
)
task = Task('matmul', inputs=[a, b], outputs=[c])
-run_task(task, [hidet.randn([3, 3]), hidet.randn([3, 3])], [hidet.empty([3, 3])])
+run_task(task, [hidet.randn([3, 3]), hidet.randn([3, 3])])


matmul_example()
@@ -411,7 +410,7 @@ def softmax_example():
softmax = compute('softmax', shape=[3], fcompute=lambda i: exp_a[i] / exp_sum)

task = Task('softmax', inputs=[a], outputs=[softmax])
-run_task(task, [hidet.randn([3])], [hidet.empty([3])])
+run_task(task, [hidet.randn([3])])


softmax_example()
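With these changes the task-level flow is: build the Task into a CompiledTask, then let run_async allocate and return the outputs. A minimal end-to-end sketch, assuming tensor_input and compute are importable from hidet.ir.compute and that run_async returns the sequence of output tensors (as run_task above relies on):

# Example (not part of the diff): build and run a Task with the new drivers API.
import hidet
from hidet.ir.compute import tensor_input, compute
from hidet.ir.task import Task

a = tensor_input(name='a', dtype='float32', shape=[5])
b = tensor_input(name='b', dtype='float32', shape=[5])
c = compute(name='c', shape=[5], fcompute=lambda i: a[i] + b[i])
task = Task(name='add', inputs=[a, b], outputs=[c])

compiled = hidet.drivers.build_task(task, target='cpu')   # returns a CompiledTask
outputs = compiled.run_async([hidet.randn([5]), hidet.randn([5])])
print(outputs[0])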
12 changes: 1 addition & 11 deletions gallery/how-to-guides/add-new-operator-template-based.py
@@ -17,7 +17,7 @@
import hidet
from hidet.ir.compute import TensorNode, compute, reduce
from hidet.ir.task import Task
-from hidet.ir.func import IRModule
+from hidet.ir.module import IRModule


class BatchMatmulFp16Task(Task):
@@ -240,16 +240,6 @@ def demo_usage():

demo_usage()

-# %%
-# Generated Source Code
-# ---------------------
-# If you are interested in the generated source code, here it is:
-
-a = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
-b = hidet.randn([1, 2, 2], dtype='float16', device='cuda')
-op = BatchMatmulFp16Op(a, b)
-print(op.task_func.source(color=True))
-
# %%
# Summary
# -------
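A one-line reminder of the relocated import used throughout this gallery script; the class itself is unchanged:

# Example (not part of the diff): IRModule now lives in hidet.ir.module.
from hidet.ir.module import IRModule   # previously: from hidet.ir.func import IRModule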
20 changes: 10 additions & 10 deletions include/hidet/runtime/cpu/complex.h
@@ -14,42 +14,42 @@
typedef std::complex<float> complex64_t;
typedef std::complex<double> complex128_t;

-float abs(const complex64_t &a) {
+static float abs(const complex64_t &a) {
return std::abs(a);
}

-double abs(const complex128_t &a) {
+static double abs(const complex128_t &a) {
return std::abs(a);
}

-float real(const complex64_t &a) {
+static float real(const complex64_t &a) {
return std::real(a);
}

-double real(const complex128_t &a) {
+static double real(const complex128_t &a) {
return std::real(a);
}

-float imag(const complex64_t &a) {
+static float imag(const complex64_t &a) {
return std::imag(a);
}

-double imag(const complex128_t &a) {
+static double imag(const complex128_t &a) {
return std::imag(a);
}

-complex64_t conj(const complex64_t &a) {
+static complex64_t conj(const complex64_t &a) {
return std::conj(a);
}

-complex128_t conj(const complex128_t &a) {
+static complex128_t conj(const complex128_t &a) {
return std::conj(a);
}

-complex64_t make_complex(float x, float y) {
+static complex64_t make_complex(float x, float y) {
return complex64_t(x, y);
}

-complex128_t make_complex(double x, double y) {
+static complex128_t make_complex(double x, double y) {
return complex128_t(x, y);
}
7 changes: 7 additions & 0 deletions include/hidet/runtime/cpu/float32.h
@@ -0,0 +1,7 @@
+#include <math.h>
+
+static inline float rsqrtf(float x)
+{
+return 1.0f / sqrtf(x);
+}
+
12 changes: 0 additions & 12 deletions include/hidet/runtime/memory_planner.h
@@ -21,8 +21,6 @@ struct MemoryPlanner {
};

static MemoryPlanner memory_planner;
-//
-//int max_segments = 0;

static void memory_planner_init() {
memory_planner.size_map.clear();
@@ -31,13 +29,6 @@ static void memory_planner_init() {
}

static int64_t memory_planner_allocate(int64_t size) {
-// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
-// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
-// memory_planner.print();
-
-// auto ret = memory_planner.regions.begin()->start;
-// memory_planner.regions.begin()->start += size;
-// return ret;
if(size == 0) {
return -1;
}
@@ -65,9 +56,6 @@ }
}

static void memory_planner_free(int64_t ptr) {
-// max_segments = std::max(max_segments, (int)memory_planner.regions.size());
-// printf("%d (%d)\n", (int)memory_planner.regions.size(), max_segments);
-// memory_planner.print();
if(ptr == -1) {
return;
}
4 changes: 2 additions & 2 deletions python/hidet/__init__.py
@@ -18,7 +18,7 @@
from . import utils
from . import graph
from . import runtime
-from . import driver
+from . import drivers
from . import logging
from . import cuda

@@ -31,7 +31,7 @@
from .ir.expr import symbol_var

from .runtime.device import Device, device
-from .runtime.model import save_model, load_model
+from .runtime.compiled_graph import save_compiled_graph, load_compiled_graph

from .graph import Tensor, Operator, Module, FlowGraph
from .graph import nn
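The top-level serialization helpers now operate on compiled graphs. A minimal sketch of the new save/load round trip; hidet.drivers.build_flow_graph and the (graph, path) argument order are assumptions on my part, not something this diff shows:

# Example (not part of the diff): save and re-load a compiled graph with the renamed helpers.
import hidet

x = hidet.symbol([8], dtype='float32', device='cpu')
y = hidet.ops.relu(x)
graph = hidet.trace_from(y, inputs=[x])

cgraph = hidet.drivers.build_flow_graph(graph)     # assumed driver entry point for flow graphs
hidet.save_compiled_graph(cgraph, 'relu.hidet')    # previously: hidet.save_model(...)
cgraph2 = hidet.load_compiled_graph('relu.hidet')  # previously: hidet.load_model(...)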
60 changes: 41 additions & 19 deletions python/hidet/backend/build.py
@@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, List, Dict
+from typing import Optional, List, Sequence
import functools
import warnings
import os
@@ -39,11 +39,10 @@ class SourceCompiler:
The base class of source compiler.
"""

-def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str, str]] = None) -> None:
+def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
raise NotImplementedError()

-@staticmethod
-def run_compile_command(command: str, src_path, out_lib_path: str):
+def run_compile_command(self, command: str, src_path, out_lib_path: str):
try:
# the directory to store the library "lib.so"
out_lib_dir = os.path.dirname(out_lib_path)
@@ -68,7 +67,8 @@ def run_compile_command(command: str, src_path, out_lib_path: str):
raise CompilationFailed(src_path, message)

# write the compilation log
-with open(os.path.join(out_lib_dir, 'compiler.log'), 'w') as f:
+log_name = self.__class__.__name__.lower() + '_output.txt'
+with open(os.path.join(out_lib_dir, log_name), 'w', encoding='utf-8') as f:
output = '\n'.join([result.stdout.decode('utf-8').strip(), result.stderr.decode('utf-8').strip()])
f.write(output.strip())

@@ -104,7 +104,10 @@ def _resolve_nvcc_path():
return path
raise FileNotFoundError('Can not find nvcc compiler.')

-def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str, str]] = None) -> None:
+def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
+if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
+raise ValueError('Can not compile multiple objects into a single object file.')
+
cc = hidet.cuda.compute_capability()
cc_code = '{}{}'.format(cc[0], cc[1])

@@ -148,7 +151,9 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
# supress warning no 39 like: "warning #39-D: division by zero"
'--diag-suppress 39',
# generate shared library (lib.so).
-'--shared',
+'--shared' if out_lib_path.endswith('.so') else '--compile',
+# the linking objects.
+' '.join(linking_objects),
# the source path.
src_path,
# the output library path.
@@ -174,7 +179,9 @@ def _resolve_gcc_path():
return path
raise FileNotFoundError('Can not find g++ compiler.')

-def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str, str]] = None) -> None:
+def compile(self, src_path: str, out_lib_path: str, linking_objects: Sequence[str]) -> None:
+if len(linking_objects) > 0 and out_lib_path.endswith('.o'):
+raise ValueError('Can not compile multiple objects into a single object file.')
command = [
# the path to nvcc compiler
self.gcc_path,
@@ -193,9 +200,11 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
# enable OpenMP.
'-fopenmp',
# link the hidet runtime, all APIs for communication between kernels and host system are in hidet runtime.
-'-lhidet_runtime',
+'-Wl,--no-as-needed -lhidet_runtime',
# generate shared library (lib.so).
-'-shared',
+'-shared' if out_lib_path.endswith('.so') else '--compile',
+# the linking objects.
+' '.join(linking_objects),
# the source path.
src_path,
# the output library path.
@@ -206,23 +215,36 @@ def compile(self, src_path: str, out_lib_path: str, options: Optional[Dict[str,
self.run_compile_command(" ".join(command), src_path, out_lib_path)


-def compile_source(src_path: str, out_lib_path: str) -> None:
+def compile_source(
+source_file: str, output_library_file: str, target: str, object_files: Optional[Sequence[str]]
+) -> None:
"""
Compile the source code in 'src_path' file and output the library to 'out_lib_path'.

Parameters
----------
-src_path: str
+source_file: str
The path to source code.
-out_lib_path: str
+output_library_file: str
The path to output library.
+target: str
+The target platform. Currently only support 'cpu' and 'gpu'.
+object_files: Optional[Sequence[str]]
+The path to object files. If not None, the object files will be linked to the output library.
"""
-src_path = os.path.abspath(src_path)
-out_lib_path = os.path.abspath(out_lib_path)
-
-if hidet.cuda.available():
+source_file = os.path.abspath(source_file)
+output_library_file = os.path.abspath(output_library_file)
+if object_files is not None:
+object_files = [os.path.abspath(object_file) for object_file in object_files]
+
+if target == 'cuda':
+if not hidet.cuda.available():
+raise RuntimeError('CUDA is not available.')
compiler = NVCC()
-else:
+elif target == 'cpu':
compiler = GCC()
+else:
+raise ValueError('Unknown target platform: {}'.format(target))

-compiler.compile(src_path, out_lib_path)
+object_files = object_files or []
+compiler.compile(source_file, output_library_file, object_files)
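To tie the new signature together, a minimal sketch of calling the updated compile_source with placeholder paths. Note that the docstring above mentions 'cpu' and 'gpu', while the dispatch code accepts 'cpu' and 'cuda'; the example follows the code:

# Example (not part of the diff): invoking the updated compile_source helper.
from hidet.backend.build import compile_source

compile_source(
    source_file='./outs/source.cu',        # placeholder path
    output_library_file='./outs/lib.so',   # placeholder path
    target='cuda',                         # 'cpu' selects GCC, 'cuda' selects NVCC
    object_files=None,                     # or a list of pre-built '.o' files to link in
)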