Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fixbug] Fix a bug in compile server #306

Merged
merged 10 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 74 additions & 54 deletions apps/compile_server/resources/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import traceback
import threading
import requests
import subprocess
import zipfile
import logging
Expand All @@ -19,6 +20,7 @@
lock = threading.Lock()
logger = logging.Logger(__name__)

pid = os.getpid()
jobs_dir = os.path.join(os.getcwd(), 'jobs')
repos_dir = os.path.join(os.getcwd(), 'repos')
commits_dir = os.path.join(os.getcwd(), 'commits')
Expand All @@ -33,8 +35,6 @@ def should_update(repo_timestamp) -> bool:
timestamp = f.read()
return time.time() - float(timestamp) > 3 * 60 # 3 minutes
else:
with open(repo_timestamp, 'w') as f:
f.write(str(time.time()))
return True


Expand All @@ -52,10 +52,15 @@ def clone_github_repo(owner: str, repo: str, version: str) -> str:
repo = git.Repo(repo_dir)

if should_update(repo_timestamp):
repo.remotes.origin.fetch()
repo.git.fetch('--all')
repo.git.fetch('--tags')
repo.git.checkout(version)
# repo.remotes.origin.fetch()
# repo.git.fetch('--all')
# repo.git.fetch('--tags')
repo.git.checkout(version)
repo.remotes.origin.pull(version)
with open(repo_timestamp, 'w') as f:
f.write(str(time.time()))
else:
repo.git.checkout(version)
commit_id = repo.head.commit.hexsha

commit_dir = os.path.join(commits_dir, commit_id)
Expand All @@ -73,7 +78,14 @@ def clone_github_repo(owner: str, repo: str, version: str) -> str:
("make -j1", "./build")
]
for command, cwd in commands:
subprocess.run(command.split(), cwd=os.path.join(commit_dir, cwd), check=True)
subprocess.run(
command.split(),
cwd=os.path.join(commit_dir, cwd),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True
)

return commit_id

Expand All @@ -94,50 +106,58 @@ def parse_repo_url(url: str) -> Tuple[str, str]:
class CompilationResource(Resource):
@jwt_required() # Requires JWT authentication for this endpoint
def post(self):
# Retrieve the ir modules
job_data: bytes = request.data

raw_job: Dict[str, Any] = pickle.loads(job_data)

# download the repository if needed
hidet_repo_url = raw_job['hidet_repo_url']
hidet_repo_version = raw_job['hidet_repo_version']
owner, repo = parse_repo_url(hidet_repo_url)
commit_id: str = clone_github_repo(owner, repo, hidet_repo_version)

workload: bytes = raw_job['workload']

job = {
'commit_id': commit_id,
'workload': workload
}

job_id: str = sha256(commit_id.encode() + workload).hexdigest()
job_path = os.path.join(jobs_dir, job_id + '.pickle')
job_response_path = os.path.join(jobs_dir, job_id + '.response')

# check if the job is already done
if os.path.exists(job_response_path):
with open(job_response_path, 'rb') as f:
return pickle.load(f)

# write the job to the disk
job_lock = os.path.join(jobs_dir, job_id + '.lock')
with FileLock(job_lock):
if not os.path.exists(job_path):
with open(job_path, 'wb') as f:
pickle.dump(job, f)

with lock: # Only one thread can access the following code at the same time
print('job_id:', job_id)
ret = subprocess.run([sys.executable, compile_script, '--job_id', job_id])

# respond to the client
response_path = os.path.join(jobs_dir, job_id + '.response')
if not os.path.exists(response_path):
msg = '{}\n{}'.format(ret.stderr, ret.stdout)
return {'message': 'Can not find a response from the worker due to\n{}'.format(msg)}, 500
else:
with open(response_path, 'rb') as f:
response: Tuple[Dict, int] = pickle.load(f)
return response
try:
# Retrieve the ir modules
job_data: bytes = request.data

raw_job: Dict[str, Any] = pickle.loads(job_data)

# download the repository if needed
hidet_repo_url = raw_job['hidet_repo_url']
hidet_repo_version = raw_job['hidet_repo_version']
owner, repo = parse_repo_url(hidet_repo_url)
commit_id: str = clone_github_repo(owner, repo, hidet_repo_version)

workload: bytes = raw_job['workload']

job = {
'commit_id': commit_id,
'workload': workload
}

job_id: str = sha256(commit_id.encode() + workload).hexdigest()
job_path = os.path.join(jobs_dir, job_id + '.pickle')
job_response_path = os.path.join(jobs_dir, job_id + '.response')

print('[{}] Received a job: {}'.format(pid, job_id[:16]))

# check if the job is already done
if os.path.exists(job_response_path):
print('[{}] Job {} has already done before, respond directly'.format(pid, job_id[:16]))
with open(job_response_path, 'rb') as f:
return pickle.load(f)

# write the job to the disk
job_lock = os.path.join(jobs_dir, job_id + '.lock')
with FileLock(job_lock):
if not os.path.exists(job_path):
with open(job_path, 'wb') as f:
pickle.dump(job, f)

with lock: # Only one thread can access the following code at the same time
print('[{}] Start compiling: {}'.format(pid, job_id[:16]))
ret = subprocess.run([sys.executable, compile_script, '--job_id', job_id])

# respond to the client
response_path = os.path.join(jobs_dir, job_id + '.response')
if not os.path.exists(response_path):
raise RuntimeError('Can not find the response file:\n{}{}'.format(ret.stderr, ret.stdout))
else:
print('[{}] Finish compiling: {}'.format(pid, job_id[:16]))
with open(response_path, 'rb') as f:
response: Tuple[Dict, int] = pickle.load(f)
return response
except Exception as e:
msg = traceback.format_exc()
print('[{}] Failed to compile:\n{}'.format(pid, msg))
return {'message': '[Remote] {}'.format(msg)}, 500
67 changes: 56 additions & 11 deletions python/hidet/drivers/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,54 @@ def can_remote_build(ir_module: IRModule) -> bool:
return not (len(ir_module.object_files) > 0 or len(ir_module.linking_dirs) > 0 or len(ir_module.include_dirs) > 0)


def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output_kind: str = '.so'): # '.so', '.o'
def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output_kind: str = '.so', force=False):
"""
Build an IR module to a shared library or object file.

This driver function performs the following steps to build an IR module:

1. Lower and optimize the IR module with a sequence of pre-defined passes.
2. Generate source code from the lowered IR module.
3. Call the underlying compiler (e.g., gcc or nvcc) to compile the generated source code to a shared library (when
`output_kind == '.so'`) or an object file (when `output_kind == '.o'`).

Parameters
----------
ir_module: IRModule
The IR module to be built.

output_dir: str
The directory to save the generated source code and the compiled library.

target: str
The target to build the IR module for. Two targets are currently supported: `cpu` and `cuda`. The target string
may also carry attributes (e.g., 'cuda --arch=sm_70').

output_kind: str
The output kind. Currently, we support two kinds: `'.so'` and `'.o'`. The former means that the IR module will
be compiled to a shared library, while the latter means that the IR module will be compiled to an object file.

force: bool
Whether to force a rebuild of the IR module. By default, the build is skipped when a non-empty library file already
exists in the output directory (and, for `'.so'` output, `func_types.pickle` exists alongside it).
"""
if output_kind == '.so':
lib_name = 'lib.so'
elif output_kind == '.o':
lib_name = 'lib.o'
else:
raise ValueError(f'Invalid output kind: {output_kind}')
lib_path = os.path.join(output_dir, lib_name)

if (
os.path.exists(lib_path)
and os.path.getsize(lib_path) > 0
and (output_kind != '.so' or os.path.exists(os.path.join(output_dir, 'func_types.pickle')))
and not force
):
# the library has been built
return

if hidet.option.compile_server.enabled() and can_remote_build(ir_module):
from hidet.apps.compile_server import remote_build

Expand All @@ -53,14 +100,6 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output
else:
raise ValueError(f'Invalid target: {target}')

if output_kind == '.so':
lib_name = 'lib.so'
elif output_kind == '.o':
lib_name = 'lib.o'
else:
raise ValueError(f'Invalid output kind: {output_kind}')
lib_path = os.path.join(output_dir, lib_name)

# lower ir module
instruments = []
if hidet.option.get_save_lower_ir():
Expand Down Expand Up @@ -92,7 +131,9 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output
pickle.dump(func_types, f)


def build_ir_module_batch(ir_modules: Sequence[IRModule], output_dirs: Sequence[str], output_kind: str, target: str):
def build_ir_module_batch(
ir_modules: Sequence[IRModule], output_dirs: Sequence[str], output_kind: str, target: str, force: bool = False
):
"""
Build a batch of ir modules.

Expand All @@ -109,11 +150,15 @@ def build_ir_module_batch(ir_modules: Sequence[IRModule], output_dirs: Sequence[

target: str
The target of the compilation. Can be 'cuda' or 'cpu'.

force: bool
Whether to force a rebuild of the IR modules. The flag is forwarded to `build_ir_module` for every module; by
default, a module is not rebuilt if its library has already been built in the corresponding output directory.
"""

def build_job(args):
ir_module, output_dir = args
build_ir_module(ir_module, output_dir, output_kind=output_kind, target=target)
build_ir_module(ir_module, output_dir, output_kind=output_kind, target=target, force=force)

jobs = [(ir_module, output_dir) for ir_module, output_dir in zip(ir_modules, output_dirs)]

Expand Down
6 changes: 3 additions & 3 deletions python/hidet/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def build(self):
ret: hidet.runtime.CompiledModule
The compiled module.
"""
import os
from hashlib import sha256
from hidet.drivers import build_ir_module
from hidet.runtime import load_compiled_module
from hashlib import sha256
import hidet.utils

hash_dir = sha256(str(self).encode()).hexdigest()[:16]
output_dir = os.path.join('./outs/ir_modules', hash_dir)
output_dir = hidet.utils.cache_dir('ir_modules', hash_dir)

if any(func.kind in ['cuda_kernel', 'cuda_internal'] for func in self.functions.values()):
target = 'cuda'
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/runtime/compiled_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ def save(self, path: str):
def save_compiled_graph(model: CompiledGraph, path: str):
from hidet.utils.dataclass import asdict

dirname = os.path.dirname(path)
os.makedirs(dirname, exist_ok=True)

with zipfile.ZipFile(path, 'w') as zf:

def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = None):
Expand Down