Skip to content

Commit

Permalink
[Compilation] Optimization for compilation process
Browse files Browse the repository at this point in the history
  • Loading branch information
maxyanghu committed Feb 15, 2024
1 parent 5f76caf commit c5fc4db
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 31 deletions.
15 changes: 10 additions & 5 deletions python/hidet/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, List, Tuple, Dict, Union
from typing import Sequence, Optional, List, Tuple, Dict, Union
import os
import numpy as np
from hidet.ir.dialects.pattern import PlaceholderExpr
Expand Down Expand Up @@ -263,7 +263,6 @@ def visit_IRModule(self, module: IRModule) -> Doc:
doc += '} // namespace ' + module.namespace + NewLine()

doc = self.require_headers() + doc

return doc

def visit_Function(self, func: Function) -> Doc:
Expand Down Expand Up @@ -831,7 +830,7 @@ def visit_Function(self, func: Function) -> Doc:
return doc


def codegen(ir_module: IRModule, src_out_path: str, target: Union[str, Target]) -> str:
def codegen(ir_module: Union[IRModule, Sequence[IRModule]], src_out_path: str, target: Union[str, Target]) -> str:
if isinstance(target, str):
target = Target.from_string(target)

Expand All @@ -842,8 +841,14 @@ def codegen(ir_module: IRModule, src_out_path: str, target: Union[str, Target])
else:
raise ValueError(f'Unknown target: {target}')

doc = gen(ir_module)
code = str(doc)
code = ''
if isinstance(ir_module, Sequence):
for m in ir_module:
doc = gen(m)
code += str(doc) + '\n'
else:
doc = gen(ir_module)
code = str(doc)
if src_out_path is not None:
dir_path = os.path.dirname(src_out_path)
if not os.path.exists(dir_path):
Expand Down
68 changes: 50 additions & 18 deletions python/hidet/drivers/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Dict
from typing import Sequence, Dict, Union
import logging
import os
import pickle
Expand Down Expand Up @@ -37,7 +37,9 @@ def can_remote_build(ir_module: IRModule) -> bool:
return not (len(ir_module.object_files) > 0 or len(ir_module.linking_dirs) > 0 or len(ir_module.include_dirs) > 0)


def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output_kind: str = '.so', force=False):
def build_ir_module(
ir_module: Union[IRModule, Sequence[IRModule]], output_dir: str, target: str, output_kind: str = '.so', force=False
):
"""
Build an IR module to a shared library or object file.
Expand All @@ -50,8 +52,8 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output
Parameters
----------
ir_module: IRModule
The IR module to be built.
ir_module: Union[IRModule, Sequence[IRModule]]
The IR module to be built. This can be a single IRModule or a sequence of IRModules.
output_dir: str
The directory to save the generated source code and the compiled library.
Expand Down Expand Up @@ -85,7 +87,7 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output
# the library has been built
return

if hidet.option.compile_server.enabled() and can_remote_build(ir_module):
if isinstance(ir_module, IRModule) and hidet.option.compile_server.enabled() and can_remote_build(ir_module):
from hidet.apps.compile_server import remote_build

remote_build(ir_module, output_dir, target=target, output_kind=output_kind)
Expand Down Expand Up @@ -118,20 +120,39 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output
if target.name == 'cpu' and 'arch' in target.attrs:
hidet.option.cpu.arch(target.attrs['arch'])
with PassContext(instruments=instruments):
ir_module = lower(ir_module)
if isinstance(ir_module, Sequence):
for i in range(len(ir_module)):
ir_module[i] = lower(ir_module[i])
else:
ir_module = lower(ir_module)

# code generation
codegen(ir_module, src_out_path=src_path, target=target)

include_dir = []
linking_dir = []
linking_lib = []
object_file = []
if isinstance(ir_module, Sequence):
for im in ir_module:
include_dir.extend(im.include_dirs)
linking_dir.extend(im.linking_dirs)
linking_lib.extend(im.linking_libs)
object_file.extend(im.object_files)
else:
include_dir.extend(ir_module.include_dirs)
linking_dir.extend(ir_module.linking_dirs)
linking_lib.extend(ir_module.linking_libs)
object_file.extend(ir_module.object_files)
# compile source code
compile_source(
src_path,
output_library_file=lib_path,
target=target,
include_dirs=ir_module.include_dirs,
linking_dirs=ir_module.linking_dirs,
linking_libraries=ir_module.linking_libs,
object_files=ir_module.object_files,
include_dirs=include_dir,
linking_dirs=linking_dir,
linking_libraries=linking_lib,
object_files=object_file,
)

# write the function types
Expand All @@ -154,8 +175,8 @@ def build_ir_module_batch(
ir_modules: Sequence[IRModule]
A sequence of ir modules to build.
output_dirs: Sequence[str]
The output directory to save the compiled library and source code (lib.so and source.cu).
output_dirs: Sequence[str]
Directories in which the compilation artifacts (generated source code and compiled library) are saved.
output_kind: str
The output kind of the compiled library. Can be '.so' or '.o'.
Expand All @@ -172,19 +193,23 @@ def build_job(args):
ir_module, output_dir = args
build_ir_module(ir_module, output_dir, output_kind=output_kind, target=target, force=force)

jobs = [(ir_module, output_dir) for ir_module, output_dir in zip(ir_modules, output_dirs)]
# Split `modules` into consecutive slices of `size` items each; the last slice is shorter when
# len(modules) is not a multiple of `size`.
# When `size` is 1 or less, `modules` is returned as-is: its items are NOT wrapped into
# one-element slices, so the result is then a flat sequence rather than a list of slices.
def regroup_modules(modules, size):
if size > 1:
return [modules[i : i + size] for i in range(0, len(modules), size)]
else:
return modules

# calculate the number of workers
cpu_count = os.cpu_count()
if hidet.option.compile_server.enabled():
num_workers = min(len(jobs), 128)
num_workers = min(len(ir_modules), 128)
else:
max_jobs, mem_for_worker = option.get_parallel_tune()
max_jobs = cpu_count if max_jobs == -1 else min(max_jobs, cpu_count)
mem_for_worker *= 1024**3
num_workers = min(max(int(psutil.virtual_memory().available // mem_for_worker), 1), max_jobs)

if num_workers > 1 and len(jobs) > 1:
if num_workers > 1 and len(ir_modules) > 1:
# Set the affinity of the current process. Some packages, such as numpy, change the affinity of the
# current process, which might limit the parallelism of compilation.
from contextlib import suppress
Expand All @@ -194,9 +219,16 @@ def build_job(args):

lazy_initialize_cuda()

per_worker_jobs = 1 if len(ir_modules) < num_workers else len(ir_modules) // num_workers
ir_modules_list = regroup_modules(ir_modules, per_worker_jobs)
jobs = [
(ir_modules, output_dir)
for ir_modules, output_dir in zip(ir_modules_list, output_dirs[: len(ir_modules_list)])
]

for _ in tqdm(parallel_imap(build_job, jobs, num_workers), desc='Compiling', total=len(jobs), ncols=80):
pass
return output_dirs[: len(ir_modules_list)]
else:
# sequential build
for job in tqdm(jobs, desc='Compiling', ncols=80, disable=len(jobs) == 1):
build_job(job)
build_ir_module(ir_modules, output_dir=output_dirs[0], output_kind=output_kind, target=target, force=force)
return [output_dirs[0]]
10 changes: 2 additions & 8 deletions python/hidet/drivers/build_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import shutil
from hashlib import sha256
from typing import List, Optional, Tuple

import hidet.cuda
from hidet import option
from hidet.ir.stmt import AssertStmt
Expand Down Expand Up @@ -111,9 +110,8 @@ def get_output_shape(idx: int32, dims: ~int32):

# generate the candidate summary
_generate_candidate_summary(candidates, task_dir)

# build each candidate to an object file (.o)
build_ir_module_batch(
objects_path_list = build_ir_module_batch(
ir_modules=candidates,
output_dirs=[os.path.join(task_dir, 'candidates', str(i)) for i in range(len(candidates))],
output_kind='.o',
Expand Down Expand Up @@ -143,9 +141,7 @@ def launch(arg: meta.types(param_types)):
ir_module = script_module.ir_module()
ir_module.add_function(get_input_shape.name, get_input_shape)
ir_module.add_function(get_output_shape.name, get_output_shape)
ir_module.object_files.extend(
[os.path.join(task_dir, 'candidates', str(i), 'lib.o') for i in range(len(candidates))]
)
ir_module.object_files.extend([os.path.join(object_path, 'lib.o') for object_path in objects_path_list])
task_ir_module = ir_module

# add assertions to the launch function
Expand All @@ -162,7 +158,6 @@ def launch(arg: meta.types(param_types)):

# build task ir module
build_ir_module(ir_module=task_ir_module, output_dir=task_dir, output_kind='.so', target=target)

# clear the candidate object files that are no longer needed
if not hidet.option.get_option('debug_cache_tuning'):
shutil.rmtree(os.path.join(task_dir, 'candidates'), ignore_errors=True)
Expand Down Expand Up @@ -277,7 +272,6 @@ def build_task(task: Task, target='cuda', load=True) -> Optional[CompiledTask]:
# write version
with open(version_path, 'w') as f:
f.write(hidet.__version__)

# implement task to IRModule, each task may produce multiple IRModules (candidates)
# they have the same functionality but different performance
candidates = task.implement(target=target, working_dir=task_dir)
Expand Down

0 comments on commit c5fc4db

Please sign in to comment.