[COMPSERVER] Compilation server speed up #487
commit_id: str = job['commit_id']
commit_dir = os.path.join(commits_dir, commit_id)
sys.path.insert(0, os.path.join(commit_dir, 'python'))
import hidet  # import the hidet from the commit
As far as I know, when a module has already been imported, a later `import xxx` is simply ignored. If that's true, this `import hidet` will be ignored for the second and all later jobs, so we might compile a job with the wrong hidet version. Could you double check the behaviour of `import`?
yes, I think you are right.
Fixing it with importlib.reload
This is a confirmation
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ tree
.
├── a
│ └── mymodule.py
├── b
│ └── mymodule.py
└── import_test.py
2 directories, 3 files
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ cat a/mymodule.py
def f():
    print("Module A")
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ cat b/mymodule.py
def f():
    print("Module b")
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ cat import_test.py
import os
import sys
sys.path.insert(0, os.path.abspath("a"))
import mymodule
mymodule.f()
sys.path.pop(0)
sys.path.insert(0, os.path.abspath("b"))
import mymodule
mymodule.f()
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ python import_test.py
Module A
Module A
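The same behaviour can be reproduced without creating files by hand. The sketch below is only an illustration: the helper names `write_module` and `import_from` are made up for this example, and it uses a throwaway module rather than hidet. It shows that the second import returns the cached entry from `sys.modules`, and that removing the entry first lets the import system search `sys.path` again.

```python
import importlib
import os
import sys
import tempfile


def write_module(root, label):
    """Create <root>/<label>/mymodule.py whose f() reports its label."""
    directory = os.path.join(root, label)
    os.makedirs(directory)
    with open(os.path.join(directory, "mymodule.py"), "w") as fh:
        fh.write(f"def f():\n    return 'Module {label}'\n")
    return directory


def import_from(directory, name="mymodule", purge=False):
    """Import `name` with `directory` first on sys.path (hypothetical helper)."""
    if purge:
        sys.modules.pop(name, None)    # forget the cached module object
        importlib.invalidate_caches()  # make the path finders rescan directories
    sys.path.insert(0, directory)
    try:
        return importlib.import_module(name)
    finally:
        sys.path.remove(directory)


with tempfile.TemporaryDirectory() as root:
    dir_a = write_module(root, "a")
    dir_b = write_module(root, "b")
    first = import_from(dir_a).f()               # loaded from a/
    cached = import_from(dir_b).f()              # still the module from a/
    fresh = import_from(dir_b, purge=True).f()   # now loaded from b/
    print(first, cached, fresh, sep=" | ")
```

For a package such as hidet, every `hidet.*` submodule is cached separately, so all of those entries would have to be removed as well.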
If we reload hidet, the cost might be similar to launching a new process, since importing hidet takes most of the time. I am also not sure whether reloading a different version of hidet on top of an already imported one has side effects. For example, the hidet runtime library (libhidet_runtime.so) might not be freed automatically, and dlopen might reuse the previously loaded libhidet_runtime.so instead of loading the one from the newly imported hidet.
We could spend some effort on reducing the import time of hidet as a standalone task. My impression is that most of the time is spent adding primitive functions (we construct all primitive functions at import time).
I will check the runtime lib reload as well.
> When we reload hidet, it might have similar performance as launching a new process if 'importing' hidet takes most of the time
Most of the overhead comes from loading hidet, not from process creation. The main idea is to avoid creating a new process because that unconditionally triggers import hidet. Within one process we can reload hidet only if it has actually changed.
> We could spend some efforts to reduce the import time of hidet as a standalone job. In my impression, most of the time is spent in adding primitive functions (we construct all primitive functions at import-time).
The long import is definitely caused by registering primitive functions, and the situation became really bad after the wgmma registration. Bolin is working on speeding it up, but there are 5K+ different wgmma primitives to register.
> In one process we can reload hidet only if it actually changed.
That's challenging. Different commits have different versions of hidet. How do you define "actually changed"?
> The reason of long import is definitely register primitive functions and situation stayed really bad after wgmma registration. Bolin is working on speed up it. But there 5K+ different wgmma to register.
Yeah. Let's work in this direction.
> That's challenging. Different commits have different versions of hidet. How do you define "actually changed"?
We have the folder containing hidet at hand, so we can remember it. If the next request to import hidet points at the same folder, we can skip the import. Did I miss something?
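One way to read "remember the folder" is sketched below. This is only an assumption about what is meant, not the code from the PR: `import_for_commit` and `_loaded_from` are hypothetical names, and the check compares folder paths only, so it would not notice files edited in place inside the same folder.

```python
import importlib
import sys

_loaded_from = None  # hypothetical cache: folder the current copy was imported from


def import_for_commit(python_dir, name="hidet"):
    """Return module `name` from `python_dir`, re-importing only if the folder changed."""
    global _loaded_from
    if _loaded_from == python_dir and name in sys.modules:
        return sys.modules[name]  # same folder as last time: skip the import

    # Different folder: drop the package and its submodules from the module cache.
    stale = [m for m in sys.modules if m == name or m.startswith(name + ".")]
    for module_name in stale:
        del sys.modules[module_name]
    importlib.invalidate_caches()

    sys.path.insert(0, python_dir)
    try:
        module = importlib.import_module(name)
    finally:
        sys.path.remove(python_dir)
    _loaded_from = python_dir
    return module
```

Shared libraries that the previous copy loaded stay mapped in the process, so this only swaps the Python-level modules.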
> > That's challenging. Different commits have different versions of hidet. How do you define "actually changed"?
> we have a folder with hidet in hands. We can remember this folder. If on next "wish" to import hidet folder not changed we can skip import. Did I miss something?
I see your point. If that's what you want to support, we could
- maintain a pool of compilation workers, say 20 or a similar number, where each worker has already imported some version of hidet.
- once the compilation server receives a job whose hidet version matches one of the workers, deliver the job to the matched worker.
- if no worker matches, terminate the oldest one and launch a new one with the given version of hidet.
This keeps the workers and the server isolated.
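The routing rule in the list above can be sketched independently of how workers are actually started. In this illustration `WorkerPool`, `start_worker` and `stop_worker` are hypothetical names, and "oldest" is interpreted as least recently used, which is an assumption; evicting by creation time would simply drop the `move_to_end` call.

```python
from collections import OrderedDict


class WorkerPool:
    """Route jobs to per-version workers, evicting the least recently used one."""

    def __init__(self, start_worker, stop_worker, max_workers=20):
        self._start = start_worker     # hypothetical callback: version -> worker handle
        self._stop = stop_worker       # hypothetical callback: worker handle -> None
        self._max = max_workers
        self._workers = OrderedDict()  # version -> worker, oldest entry first

    def worker_for(self, version):
        if version in self._workers:
            self._workers.move_to_end(version)  # mark as most recently used
            return self._workers[version]
        if len(self._workers) >= self._max:
            _, oldest = self._workers.popitem(last=False)  # remove the oldest entry
            self._stop(oldest)
        worker = self._start(version)
        self._workers[version] = worker
        return worker


# Usage with stand-in callbacks instead of real processes.
stopped = []
pool = WorkerPool(lambda v: f"worker-{v}", stopped.append, max_workers=2)
pool.worker_for("commit-a")
pool.worker_for("commit-b")
pool.worker_for("commit-a")  # reuse: commit-a becomes the most recent
pool.worker_for("commit-c")  # pool is full: commit-b's worker is stopped
print(stopped)
```

In a real server the callbacks would start and terminate worker processes that have already imported the requested hidet version.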
I dug deeper into the question of library load/unload. The following test
import os
import sys
import subprocess
import importlib

def unimport_hidet(prefix):
    # Drop the package and all of its submodules from the module cache
    for module_name in list(sys.modules.keys()):
        if module_name.startswith(prefix):
            module = sys.modules.get(module_name)
            if hasattr(module, "cleanup"):
                module.cleanup()
            del sys.modules[module_name]
    importlib.invalidate_caches()  # Ensure the import cache is cleared

def print_libs_info():
    # List the hidet shared libraries currently mapped into this process
    pid = os.getpid()
    result = subprocess.run(f"ls -l /proc/{pid}/map_files | grep hidet",
                            shell=True, capture_output=True, text=True)
    print("STDOUT:", result.stdout, flush=True)
    print("_LIB", hidet.ffi.ffi._LIB, flush=True)
    print("_LIB_RUNTIME", hidet.ffi.ffi._LIB_RUNTIME, flush=True)
    print("_LIB_HIDET_TORCH_WRAPPER", hidet.ffi.ffi._LIB_HIDET_TORCH_WRAPPER, flush=True)

sys.path.insert(0, os.path.abspath("/home/ubuntu/hidet/python"))
import hidet
print_libs_info()
sys.path.pop(0)
sys.path.insert(0, os.path.abspath("/home/ubuntu/hidet3/python"))
unimport_hidet("hidet")
import hidet
print_libs_info()
produces the following result
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ python import_test.py
STDOUT: lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5dea5000-7fdd5deb9000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5deb9000-7fdd5def0000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5def0000-7fdd5defd000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5defd000-7fdd5defe000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5defe000-7fdd5deff000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66044000-7fdd66045000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66045000-7fdd66046000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66046000-7fdd66047000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66047000-7fdd66048000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66048000-7fdd66049000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66049000-7fdd6604a000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604a000-7fdd6604b000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604b000-7fdd6604c000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604c000-7fdd6604d000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604d000-7fdd6604e000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
_LIB <CDLL '/home/ubuntu/hidet/build/lib/libhidet.so', handle 5930795f7410 at 0x7fdd5cba6f20>
_LIB_RUNTIME <CDLL '/home/ubuntu/hidet/build/lib/libhidet_runtime.so', handle 59307960acd0 at 0x7fdd5cba6da0>
_LIB_HIDET_TORCH_WRAPPER <CDLL '/home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so', handle 5930795f6b20 at 0x7fdd5cba6dd0>
STDOUT: lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdc8b5a5000-7fdc8b5b9000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdc8b5b9000-7fdc8b5f0000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdc8b5f0000-7fdc8b5fd000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdc8b5fd000-7fdc8b5fe000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdc8b5fe000-7fdc8b5ff000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5b662000-7fdd5b663000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5b663000-7fdd5b664000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5b664000-7fdd5b665000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5b665000-7fdd5b666000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5b666000-7fdd5b667000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5ce2c000-7fdd5ce2d000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5ce2d000-7fdd5ce2e000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5ce2e000-7fdd5ce2f000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5ce2f000-7fdd5ce30000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5ce30000-7fdd5ce31000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5dea5000-7fdd5deb9000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5deb9000-7fdd5def0000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5def0000-7fdd5defd000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5defd000-7fdd5defe000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd5defe000-7fdd5deff000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66044000-7fdd66045000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66045000-7fdd66046000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66046000-7fdd66047000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66047000-7fdd66048000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66048000-7fdd66049000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd66049000-7fdd6604a000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604a000-7fdd6604b000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604b000-7fdd6604c000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604c000-7fdd6604d000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan 7 06:39 7fdd6604d000-7fdd6604e000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
_LIB <CDLL '/home/ubuntu/hidet3/build/lib/libhidet.so', handle 59307f826cf0 at 0x7fdc8b66a4d0>
_LIB_RUNTIME <CDLL '/home/ubuntu/hidet3/build/lib/libhidet_runtime.so', handle 59307f825510 at 0x7fdc8b66a530>
_LIB_HIDET_TORCH_WRAPPER <CDLL '/home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so', handle 59307f826440 at 0x7fdc8b66a470>
So the new versions of the libs are loaded, and the variables representing them refer to the new versions. The old versions stay mapped in the process as well.
I spent some time but was not able to unload the libraries. After deletion there are still some references to each lib, so it is not unloaded. On the other hand, all the libs together consume around 320KB, so even with 100 different commits we waste only 32MB of memory. Regarding necessity: compiling one candidate often does not take long, just several seconds. Even if we speed up
Thanks @vadiklyutiy for the more detailed exploration. After thinking about it for a while, I still think "reloading" is not the right direction for speeding up the compilation server. Isolating the compilation server from the individual compilation workers keeps any unexpected behavior in a new hidet commit away from the main server. Thus, even if hidet crashes the interpreter, our main server is still safe. To reload effectively, we would also need to reconstruct all the primitives, which takes the same amount of time. This reconstruction is necessary because the code that constructs the primitives might have changed.
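The isolation argument can be illustrated with a small, self-contained sketch. It does not use hidet or the server code; `run_isolated` is a hypothetical helper, and `os.abort()` merely stands in for a commit that crashes the interpreter.

```python
import subprocess
import sys


def run_isolated(code, timeout=60):
    """Run `code` in a fresh interpreter and report whether it exited cleanly."""
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        timeout=timeout,
    )
    return proc.returncode == 0


ok = run_isolated("print('compiled')")
crashed_ok = run_isolated("import os; os.abort()")  # simulated interpreter crash
print("clean job ok:", ok, "| crashing job ok:", crashed_ok)
```

The parent process keeps running after the child aborts and only sees a non-zero exit status, which is the property the separate worker process provides for the main server.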
The speedup applies when nothing changes, that is, when a fixed hidet hash is used. I believe that will be the most common case in customer use.
Agree with that. Some garbage commit may kill the comp server.
@yaoyaoding We could add an option/mode to the comp server in which it supports only one hidet commit hash and avoids repeated imports in that case.
I agree that it's common that only one version of hidet will be used in a short time window. I asked ChatGPT to write some sample code to implement the worker pool:
import multiprocessing
import os
import time
import importlib
import sys

# Worker function to handle jobs for a specific version of `hidet`
def worker_process(version, job_queue, result_queue):
    sys.path.insert(0, version)  # Ensure the version path is first in sys.path
    # A fresh process has no cached hidet, so a plain import loads this version;
    # an extra importlib.reload here would only repeat the slow import
    hidet = importlib.import_module("hidet")
    print(f"Worker {multiprocessing.current_process().name} loaded hidet version from {version}")
    while True:
        job = job_queue.get()
        if job == "STOP":
            print(f"Shutting down worker for version: {version}")
            break
        # Simulate compilation job
        job_id, task = job
        print(f"Worker {multiprocessing.current_process().name} processing job {job_id} with hidet version {version}")
        result = f"Result of {task} using hidet version {version}"
        result_queue.put((job_id, result))

# Main server to manage job distribution and worker pool
class CompilationServer:
    def __init__(self, max_workers=20):
        self.max_workers = max_workers
        self.workers = {}  # {version_path: (worker_process, job_queue)}
        self.result_queue = multiprocessing.Queue()

    def get_or_create_worker(self, version_path):
        # If a worker for the version exists, return it
        if version_path in self.workers:
            return self.workers[version_path]
        # If the worker pool is full, remove the oldest worker
        if len(self.workers) >= self.max_workers:
            # dicts keep insertion order, so the first key is the oldest worker
            # (dict.popitem() would remove the newest one instead)
            old_version = next(iter(self.workers))
            worker, job_queue = self.workers.pop(old_version)
            job_queue.put("STOP")  # Send shutdown signal to the old worker
            worker.join()  # Wait for it to exit
        # Create a new worker for the version
        job_queue = multiprocessing.Queue()
        worker = multiprocessing.Process(target=worker_process, args=(version_path, job_queue, self.result_queue))
        worker.start()
        self.workers[version_path] = (worker, job_queue)
        return self.workers[version_path]

    def submit_job(self, job_id, version_path, task):
        _, job_queue = self.get_or_create_worker(version_path)
        job_queue.put((job_id, task))

    def get_result(self):
        while not self.result_queue.empty():
            job_id, result = self.result_queue.get()
            print(f"Result for job {job_id}: {result}")

    def shutdown(self):
        for version_path, (worker, job_queue) in self.workers.items():
            job_queue.put("STOP")
            worker.join()
        print("Server shut down.")

# Example usage
if __name__ == "__main__":
    server = CompilationServer(max_workers=5)
    # Simulate receiving compilation jobs for different versions
    versions = ["hidet_v1", "hidet_v2", "hidet_v3"]
    os.makedirs("hidet_v1", exist_ok=True)
    os.makedirs("hidet_v2", exist_ok=True)
    os.makedirs("hidet_v3", exist_ok=True)
    try:
        for i in range(10):
            version = versions[i % len(versions)]
            server.submit_job(job_id=i, version_path=version, task=f"compile_task_{i}")
        time.sleep(2)
        server.get_result()
    finally:
        server.shutdown()
What do you think about adopting the above POC for our server?
I recently asked both GPT-4 and Sonnet to change the DLL loading from ctypes.CDLL to cffi; both attempts failed even after my manual fixes :-D
To avoid a misunderstanding: you prefer a pool of workers where every worker holds a separate version of hidet (as you described above), rather than an option for the comp server that supports only one version of hidet. Right?
Yeah, the complexity of supporting only one version of hidet vs. a pool should be similar, since both need to run in separate processes.
Yaoyao's proposal is implemented in #489.
The compilation server was not running on all cylinders. This PR fixes one of the problems.
Previously, we spawned a separate process for every compilation request. Inside that process we ran import hidet, which takes 16 sec (https://github.com/CentML/hidet/issues/669). Compared to a compilation that might take just 2 seconds, that was a huge overhead. Now the compilation runs in the same process.
This PR also forces flush for print in the compilation server; otherwise the messages do not appear in the console.
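As a minimal sketch of the forced flush (not the code from this PR; `log` and `handle_job` are hypothetical names): when stdout is not attached to a terminal, Python buffers it in blocks, so messages from a long-running server can sit in the buffer unless each print is flushed.

```python
import functools
import sys

# Hypothetical helper: a print that flushes after every call.
log = functools.partial(print, flush=True)


def handle_job(job_id, out=sys.stdout):
    """Stand-in for a compilation request handler that reports progress."""
    log(f"[compserver] job {job_id} started", file=out)
    # ... the actual compilation would run here ...
    log(f"[compserver] job {job_id} finished", file=out)


handle_job(1)
```

Running the interpreter with `python -u` or setting `PYTHONUNBUFFERED=1` has a similar effect without touching the print calls.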