
[COMPSERVER] Compilation server speed up #487

Closed
wants to merge 3 commits into from

Conversation

vadiklyutiy
Collaborator

The compilation server was not running at full speed. This PR fixes one of the problems.
Previously we spawned a separate process for every compilation request. Inside that process we ran import hidet, which takes 16 seconds (https://github.com/CentML/hidet/issues/669). Compared with a compilation that might take just 2 seconds, that was a huge overhead. This PR runs the compilation in the same process.

It also adds a forced flush to the prints in the compilation server; otherwise they do not appear in the console.
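
A minimal sketch of the forced flush, assuming the server logs with plain `print`. The helper name `log` is hypothetical and not taken from the PR. When stdout is a pipe or a file rather than a terminal, Python buffers it in blocks, so output can stay in the buffer until it fills up or the process exits.

```python
import contextlib
import functools
import io

# Hypothetical helper: a print that flushes its stream on every call.
log = functools.partial(print, flush=True)


def demo():
    """Capture what `log` writes so the effect can be inspected."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        log("compilation job started")
    return buffer.getvalue()
```

Binding `flush=True` once avoids having to remember the argument at every call site.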

@vadiklyutiy vadiklyutiy added the Task Or "Story" in JIRA's term. label Dec 24, 2024
@vadiklyutiy vadiklyutiy self-assigned this Dec 24, 2024
@vadiklyutiy vadiklyutiy force-pushed the vadim/fix-comp-server branch from aeb9d8c to 50ac5c2 Compare December 24, 2024 20:13
vadiklyutiy added a commit that referenced this pull request Dec 26, 2024
…487)

Fix for #486

The rule below in `rule_based_simplifier`
`((e1 // c1) // c2, e1 // (c1 * c2))`
is applied in cases where `e1` is an int Var, `c1` is an int const, and `c2` is an fp const. But the rule is incorrect in this case.

Apply the rule for `int` constants only.
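
A small worked check of why the rule needs integer constants. It uses Python's floor division as a stand-in, which is an assumption: hidet's `//` may treat negative operands differently. The function names `nested` and `merged` are illustrative only.

```python
def nested(e1, c1, c2):
    """Left-hand side of the rule: divide twice."""
    return (e1 // c1) // c2


def merged(e1, c1, c2):
    """Right-hand side of the rule: divide once by the product."""
    return e1 // (c1 * c2)


# With positive integer constants the two sides agree.
ints_agree = all(
    nested(e, c1, c2) == merged(e, c1, c2)
    for e in range(-50, 51)
    for c1 in range(1, 7)
    for c2 in range(1, 7)
)

# With a floating-point c2 they can differ:
# (3 // 2) // 0.5 == 1 // 0.5 == 2.0, but 3 // (2 * 0.5) == 3 // 1.0 == 3.0
float_lhs = nested(3, 2, 0.5)
float_rhs = merged(3, 2, 0.5)
```

The inner integer division discards a remainder that a fractional `c2` can still amplify, which is why merging the divisors changes the result.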
commit_id: str = job['commit_id']
commit_dir = os.path.join(commits_dir, commit_id)
sys.path.insert(0, os.path.join(commit_dir, 'python'))
import hidet # import the hidet from the commit
Member

As far as I know, when a module has already been imported, a later import xxx is simply ignored. If that's true, this import hidet will be ignored for the second and later jobs, so we might compile a job with the wrong hidet version. Could you double-check the behaviour of import?

Collaborator Author

@vadiklyutiy vadiklyutiy Jan 6, 2025

Yes, I think you are right.
I am fixing it with importlib.reload.

Collaborator Author

Here is a confirmation:

ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ tree
.
├── a
│   └── mymodule.py
├── b
│   └── mymodule.py
└── import_test.py

2 directories, 3 files
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ cat a/mymodule.py
def f():
    print("Module A")
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ cat b/mymodule.py
def f():
    print("Module b")
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ cat import_test.py
import os
import sys

sys.path.insert(0, os.path.abspath("a"))
import mymodule
mymodule.f()

sys.path.pop(0)

sys.path.insert(0, os.path.abspath("b"))
import mymodule
mymodule.f()
ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ python import_test.py
Module A
Module A
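
The transcript above shows the second import being served from the module cache. A self-contained sketch of the same experiment, plus one way to force a fresh import by dropping the entry from `sys.modules`, might look like this. The helper names are illustrative, and the temporary directories stand in for the `a` and `b` folders.

```python
import importlib
import os
import sys
import tempfile


def write_module(root, subdir, label):
    """Create <root>/<subdir>/mymodule.py whose f() returns a label."""
    directory = os.path.join(root, subdir)
    os.makedirs(directory)
    with open(os.path.join(directory, "mymodule.py"), "w") as fh:
        fh.write(f"def f():\n    return 'Module {label}'\n")
    return directory


def demo():
    seen = []
    with tempfile.TemporaryDirectory() as root:
        dir_a = write_module(root, "a", "A")
        dir_b = write_module(root, "b", "B")

        sys.path.insert(0, dir_a)
        seen.append(importlib.import_module("mymodule").f())
        sys.path.remove(dir_a)

        sys.path.insert(0, dir_b)
        # Served from sys.modules: still the copy from directory a.
        seen.append(importlib.import_module("mymodule").f())

        # Drop the cached entry, then import again: now directory b is used.
        del sys.modules["mymodule"]
        importlib.invalidate_caches()
        seen.append(importlib.import_module("mymodule").f())

        # Leave the interpreter as we found it.
        sys.path.remove(dir_b)
        del sys.modules["mymodule"]
    return seen


print(demo())
```

For a package with many submodules, every `sys.modules` entry under the package name would need the same treatment.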

Member

When we reload hidet, it might perform similarly to launching a new process if importing hidet takes most of the time. I am also not sure whether there are side effects when we reload a different version of hidet after one has already been imported. For example, the hidet runtime library (libhidet_runtime.so) might not be freed automatically, and dlopen would then use the previously loaded libhidet_runtime.so instead of reloading the one from the later-imported hidet.

We could spend some effort on reducing the import time of hidet as a standalone job. My impression is that most of the time is spent adding primitive functions (we construct all primitive functions at import time).

Collaborator Author

I will check the runtime lib reload as well.

> When we reload hidet, it might perform similarly to launching a new process if importing hidet takes most of the time.

Most of the overhead comes from loading hidet, not from process creation. But the main idea is to avoid process creation because it unconditionally causes import hidet. Within one process we can reload hidet only if it actually changed.

> We could spend some effort on reducing the import time of hidet as a standalone job. My impression is that most of the time is spent adding primitive functions (we construct all primitive functions at import time).

The long import is definitely caused by registering primitive functions, and the situation became really bad after the wgmma registration. Bolin is working on speeding it up, but there are 5K+ different wgmma primitives to register.
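
One general way to cut import-time registration cost is to register cheap factories and build each primitive on first lookup. The sketch below illustrates that idea only. It is not how hidet registers primitives and not necessarily the approach being worked on; `LazyRegistry` and `make_registry` are hypothetical names.

```python
class LazyRegistry:
    """Store factories at import time; construct objects on first use."""

    def __init__(self):
        self._factories = {}
        self._built = {}

    def register(self, name, factory):
        # Cheap: only remembers how to build the primitive.
        self._factories[name] = factory

    def lookup(self, name):
        # Expensive construction happens here, once per name.
        if name not in self._built:
            self._built[name] = self._factories[name]()
        return self._built[name]

    @property
    def built_count(self):
        return len(self._built)


def make_registry(count):
    """Register `count` placeholder primitives without building any."""
    registry = LazyRegistry()
    for i in range(count):
        registry.register(f"prim_{i}", lambda i=i: {"name": f"prim_{i}"})
    return registry
```

With this layout, registering thousands of entries costs one dictionary insert each, and only the primitives a job actually uses are constructed.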

Member

> Within one process we can reload hidet only if it actually changed.

That's challenging. Different commits have different versions of hidet. How do you define "actually changed"?

> The long import is definitely caused by registering primitive functions, and the situation became really bad after the wgmma registration. Bolin is working on speeding it up, but there are 5K+ different wgmma primitives to register.

Yeah. Let's work in this direction.

Collaborator Author

> That's challenging. Different commits have different versions of hidet. How do you define "actually changed"?

We have the folder containing hidet at hand, and we can remember it. If the folder has not changed the next time we want to import hidet, we can skip the import. Did I miss something?
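
The "remember the folder" idea can be sketched as follows. This is an illustration under assumptions, not the PR's code: `CachedImporter` is a hypothetical name, and the module name is a parameter so the sketch does not depend on hidet being installed.

```python
import importlib
import sys


class CachedImporter:
    """Re-import a module only when its source directory changes."""

    def __init__(self, module_name):
        self.module_name = module_name
        self.current_dir = None
        self.module = None
        self.import_count = 0

    def get(self, python_dir):
        # Same folder as last time: reuse the already imported module.
        if self.module is not None and python_dir == self.current_dir:
            return self.module

        # Different folder: forget the old module and its submodules.
        prefix = self.module_name + "."
        for name in list(sys.modules):
            if name == self.module_name or name.startswith(prefix):
                del sys.modules[name]
        importlib.invalidate_caches()

        sys.path.insert(0, python_dir)
        try:
            self.module = importlib.import_module(self.module_name)
        finally:
            sys.path.remove(python_dir)

        self.current_dir = python_dir
        self.import_count += 1
        return self.module
```

A call such as `CachedImporter("hidet").get(os.path.join(commit_dir, "python"))` would pay the import cost only when the commit directory differs from the previous job's. Shared libraries loaded by the old version stay mapped, as discussed later in the thread.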

Member

> > That's challenging. Different commits have different versions of hidet. How do you define "actually changed"?
>
> We have the folder containing hidet at hand, and we can remember it. If the folder has not changed the next time we want to import hidet, we can skip the import. Did I miss something?

I see your point. If that's what you want to support, we could

  1. maintain a pool of compilation workers, say 20 or a similar number. Each worker has already imported some version of hidet.
  2. once the compilation server receives a job whose hidet version matches one of the workers, deliver the job to the matched worker.
  3. if no worker matches, terminate the oldest one and launch a new one with the given version of hidet.

This keeps the workers and the server isolated.
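
The three steps above can be sketched as a small routing table. It models only the bookkeeping: process creation and shutdown are passed in as callables, and `WorkerTable`, `spawn` and `terminate` are hypothetical names. "Oldest" is taken to mean the worker launched first.

```python
from collections import OrderedDict


class WorkerTable:
    """Map hidet versions to workers, evicting the oldest when full."""

    def __init__(self, max_workers=20):
        self.max_workers = max_workers
        self._workers = OrderedDict()  # version -> worker, oldest first

    def route(self, version, spawn, terminate):
        # Step 2: a worker with this version already exists.
        if version in self._workers:
            return self._workers[version]

        # Step 3: no match; make room by terminating the oldest worker.
        if len(self._workers) >= self.max_workers:
            _, oldest = self._workers.popitem(last=False)
            terminate(oldest)

        # Launch a worker for the requested version (step 1 keeps it around).
        worker = spawn(version)
        self._workers[version] = worker
        return worker
```

Evicting the least recently used worker instead would only require calling `self._workers.move_to_end(version)` on a hit.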

@vadiklyutiy
Collaborator Author

I dug deeper into the question of library load/unload.

The following test

import os
import sys
import subprocess
import importlib

def unimport_hidet(prefix):
    # Drop every cached module whose name starts with `prefix` so the next import re-executes it
    for module_name in list(sys.modules.keys()):
        if module_name.startswith(prefix):
            module = sys.modules.get(module_name)
            if hasattr(module, "cleanup"):
                module.cleanup()
            del sys.modules[module_name]
    importlib.invalidate_caches()  # Ensure the import cache is cleared

def print_libs_info():
    pid = os.getpid()
    result = subprocess.run(f"ls -l /proc/{pid}/map_files | grep hidet", 
                            shell=True, capture_output=True, text=True)
    print("STDOUT:", result.stdout, flush=True)
    print("_LIB", hidet.ffi.ffi._LIB, flush=True)
    print("_LIB_RUNTIME", hidet.ffi.ffi._LIB_RUNTIME, flush=True)   
    print("_LIB_HIDET_TORCH_WRAPPER", hidet.ffi.ffi._LIB_HIDET_TORCH_WRAPPER, flush=True)      

sys.path.insert(0, os.path.abspath("/home/ubuntu/hidet/python"))
import hidet
print_libs_info()  

sys.path.pop(0) 

sys.path.insert(0, os.path.abspath("/home/ubuntu/hidet3/python"))
unimport_hidet("hidet")
import hidet
print_libs_info() 

produces the following result

ubuntu@ip-172-31-31-27:~/small_scripts/import_test$ python import_test.py
STDOUT: lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5dea5000-7fdd5deb9000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5deb9000-7fdd5def0000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5def0000-7fdd5defd000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5defd000-7fdd5defe000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5defe000-7fdd5deff000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66044000-7fdd66045000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66045000-7fdd66046000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66046000-7fdd66047000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66047000-7fdd66048000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66048000-7fdd66049000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66049000-7fdd6604a000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604a000-7fdd6604b000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604b000-7fdd6604c000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604c000-7fdd6604d000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604d000-7fdd6604e000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so

_LIB <CDLL '/home/ubuntu/hidet/build/lib/libhidet.so', handle 5930795f7410 at 0x7fdd5cba6f20>
_LIB_RUNTIME <CDLL '/home/ubuntu/hidet/build/lib/libhidet_runtime.so', handle 59307960acd0 at 0x7fdd5cba6da0>
_LIB_HIDET_TORCH_WRAPPER <CDLL '/home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so', handle 5930795f6b20 at 0x7fdd5cba6dd0>
STDOUT: lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdc8b5a5000-7fdc8b5b9000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdc8b5b9000-7fdc8b5f0000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdc8b5f0000-7fdc8b5fd000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdc8b5fd000-7fdc8b5fe000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdc8b5fe000-7fdc8b5ff000 -> /home/ubuntu/hidet3/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5b662000-7fdd5b663000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5b663000-7fdd5b664000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5b664000-7fdd5b665000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5b665000-7fdd5b666000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5b666000-7fdd5b667000 -> /home/ubuntu/hidet3/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5ce2c000-7fdd5ce2d000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5ce2d000-7fdd5ce2e000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5ce2e000-7fdd5ce2f000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5ce2f000-7fdd5ce30000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5ce30000-7fdd5ce31000 -> /home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5dea5000-7fdd5deb9000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5deb9000-7fdd5def0000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5def0000-7fdd5defd000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5defd000-7fdd5defe000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd5defe000-7fdd5deff000 -> /home/ubuntu/hidet/build/lib/libhidet_runtime.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66044000-7fdd66045000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66045000-7fdd66046000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66046000-7fdd66047000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66047000-7fdd66048000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66048000-7fdd66049000 -> /home/ubuntu/hidet/build/lib/libhidet.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd66049000-7fdd6604a000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604a000-7fdd6604b000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604b000-7fdd6604c000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604c000-7fdd6604d000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so
lr-------- 1 ubuntu ubuntu 64 Jan  7 06:39 7fdd6604d000-7fdd6604e000 -> /home/ubuntu/hidet/build/lib/libhidet_torch_wrapper.so

_LIB <CDLL '/home/ubuntu/hidet3/build/lib/libhidet.so', handle 59307f826cf0 at 0x7fdc8b66a4d0>
_LIB_RUNTIME <CDLL '/home/ubuntu/hidet3/build/lib/libhidet_runtime.so', handle 59307f825510 at 0x7fdc8b66a530>
_LIB_HIDET_TORCH_WRAPPER <CDLL '/home/ubuntu/hidet3/build/lib/libhidet_torch_wrapper.so', handle 59307f826440 at 0x7fdc8b66a470>

So the new versions of the libs are loaded, and our variables representing them refer to the new versions.
The only issue is that the previous versions are not unloaded for now. But that might be fixed simply by registering a cleanup function.
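
The same check can be done without shelling out by reading `/proc/self/maps`, which lists every mapped file in the current process on Linux. The sketch below is an alternative to the `ls | grep` call in the test, not part of it; `mapped_files` is a hypothetical helper, and it returns an empty list on systems without `/proc`.

```python
import os


def mapped_files(substring, maps_path="/proc/self/maps"):
    """Return the sorted, de-duplicated mapped file paths containing `substring`."""
    if not os.path.exists(maps_path):  # e.g. not running on Linux
        return []
    paths = set()
    with open(maps_path) as fh:
        for line in fh:
            # Fields: address perms offset dev inode [pathname]
            fields = line.split(None, 5)
            if len(fields) == 6:
                path = fields[5].strip()
                if substring in path:
                    paths.add(path)
    return sorted(paths)
```

Calling `mapped_files("hidet")` after each import would show whether the libraries of the earlier version are still mapped.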

@vadiklyutiy
Collaborator Author

I spent some time on this but was not able to unload the libraries. After deletion some references to the libs remain, so they are not unloaded.

On the other hand, all the libs together consume around 320KB. Even with 100 different commits we would waste only 32MB of memory.
@yaoyaoding Can we waive the unloading? IMO it is not great, but not a critical issue.

Regarding necessity: compilation of one candidate is frequently short, just several seconds. Even if we speed up the import so that it takes 1 second, it still makes sense to keep this speedup.

@yaoyaoding
Member

yaoyaoding commented Jan 8, 2025

Thanks @vadiklyutiy for the more detailed exploration. After thinking about it for some time, I still think "reloading" is not the right direction for speeding up the compilation server. Keeping the compilation server separate from each compilation worker isolates any unexpected behavior in a new hidet commit from the main compilation server. Thus, even if hidet crashes the interpreter, our main server is still safe.

To reload effectively, we would also need to reconstruct all the primitives, which takes the same amount of time. This reconstruction is necessary because the code that constructs the primitives might have changed.
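
The isolation argument can be illustrated with a short sketch: a job that aborts its interpreter takes down only the child process, and the parent just sees a non-zero exit code. `run_job_isolated` is a hypothetical helper and the job bodies are placeholders, not real compilation jobs.

```python
import subprocess
import sys


def run_job_isolated(code, timeout=30):
    """Run `code` in a fresh interpreter and return (ok, returncode)."""
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode == 0, proc.returncode


# A well-behaved job and one that kills its own interpreter.
good_ok, _ = run_job_isolated("print('compiled')")
bad_ok, bad_code = run_job_isolated("import os; os.abort()")
```

The calling process keeps running after the second call. Running the same `os.abort()` in-process would have terminated the server itself.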

@vadiklyutiy
Collaborator Author

> To reload effectively, we would also need to reconstruct all the primitives, which takes the same amount of time. This reconstruction is necessary because the code that constructs the primitives might have changed.

The speedup applies when nothing changes, i.e. when a fixed hidet hash is used. I believe that would be the most common case in customer use.

@vadiklyutiy
Collaborator Author

> Keeping the compilation server separate from each compilation worker isolates any unexpected behavior in a new hidet commit from the main compilation server. Thus, even if hidet crashes the interpreter, our main server is still safe.

I agree with that. Some garbage commit may kill the comp server.
As an alternative, every comp server worker could run a standalone process for every hidet hash seen and keep it alive. But this seems a significant complication, not a one-day change.

@vadiklyutiy
Collaborator Author

@yaoyaoding
What do you think about the following option?

We could add an option/mode to the comp server in which it supports only one hidet commit hash, and avoid multiple imports in that case.
I believe the case with only one hash is very frequent.

@yaoyaoding
Member

> @yaoyaoding What do you think about the following option?
>
> We could add an option/mode to the comp server in which it supports only one hidet commit hash, and avoid multiple imports in that case. I believe the case with only one hash is very frequent.

I agree that it's common for only one version of hidet to be used within a short time window. I asked ChatGPT to write some sample code implementing the worker pool:

import multiprocessing
import os
import time
import importlib
import sys

# Worker function to handle jobs for a specific version of `hidet`
def worker_process(version, job_queue, result_queue):
    sys.path.insert(0, version)  # Ensure the version path is first in sys.path
    hidet = importlib.import_module("hidet")  # Load the specific version
    importlib.reload(hidet)  # Reload to ensure correct version

    print(f"Worker {multiprocessing.current_process().name} loaded hidet version from {version}")

    while True:
        job = job_queue.get()
        if job == "STOP":
            print(f"Shutting down worker for version: {version}")
            break

        # Simulate compilation job
        job_id, task = job
        print(f"Worker {multiprocessing.current_process().name} processing job {job_id} with hidet version {version}")
        result = f"Result of {task} using hidet version {version}"
        result_queue.put((job_id, result))

# Main server to manage job distribution and worker pool
class CompilationServer:
    def __init__(self, max_workers=20):
        self.max_workers = max_workers
        self.workers = {}  # {version_path: (worker_process, job_queue)}
        self.result_queue = multiprocessing.Queue()

    def get_or_create_worker(self, version_path):
        # If a worker for the version exists, return it
        if version_path in self.workers:
            return self.workers[version_path]

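        # If the worker pool is full, remove the oldest worker
        # (dicts preserve insertion order, so the first key is the oldest;
        # popitem() would remove the newest entry instead)
        if len(self.workers) >= self.max_workers:
            old_version = next(iter(self.workers))
            worker, job_queue = self.workers.pop(old_version)
            job_queue.put("STOP")  # Send shutdown signal to the old worker
            worker.join()  # Wait for it to exit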

        # Create a new worker for the version
        job_queue = multiprocessing.Queue()
        worker = multiprocessing.Process(target=worker_process, args=(version_path, job_queue, self.result_queue))
        worker.start()
        self.workers[version_path] = (worker, job_queue)
        return self.workers[version_path]

    def submit_job(self, job_id, version_path, task):
        _, job_queue = self.get_or_create_worker(version_path)
        job_queue.put((job_id, task))

    def get_result(self):
        while not self.result_queue.empty():
            job_id, result = self.result_queue.get()
            print(f"Result for job {job_id}: {result}")

    def shutdown(self):
        for version_path, (worker, job_queue) in self.workers.items():
            job_queue.put("STOP")
            worker.join()
        print("Server shut down.")

# Example usage
if __name__ == "__main__":
    server = CompilationServer(max_workers=5)

    # Simulate receiving compilation jobs for different versions
    versions = ["hidet_v1", "hidet_v2", "hidet_v3"]
    os.makedirs("hidet_v1", exist_ok=True)
    os.makedirs("hidet_v2", exist_ok=True)
    os.makedirs("hidet_v3", exist_ok=True)

    try:
        for i in range(10):
            version = versions[i % len(versions)]
            server.submit_job(job_id=i, version_path=version, task=f"compile_task_{i}")

        time.sleep(2)
        server.get_result()
    finally:
        server.shutdown()

What do you think about adapting the above POC to our server?
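
One detail worth noting if the POC is adapted: the Python documentation describes `multiprocessing.Queue.empty()` as not reliable, so polling it as `get_result` does can miss results that are still in flight. A hedged alternative is to wait for a known number of results with a timeout. `collect_results` is a hypothetical helper that works with any queue offering `get(timeout=...)`.

```python
import queue


def collect_results(result_queue, expected, timeout=5.0):
    """Gather up to `expected` (job_id, result) pairs, waiting at most `timeout` seconds for each."""
    results = {}
    while len(results) < expected:
        try:
            job_id, result = result_queue.get(timeout=timeout)
        except queue.Empty:
            break  # A worker is late or died; return what arrived so far.
        results[job_id] = result
    return results
```

Both `queue.Queue` and `multiprocessing.Queue` raise `queue.Empty` when `get` times out, so the helper does not depend on which one the server uses.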

@vadiklyutiy
Collaborator Author

I recently asked both GPT-4 and Sonnet to change the DLL loading from ctypes.CDLL to cffi; both attempts failed even after my manual fixes :-D

To avoid a misunderstanding: you prefer a pool of workers, where every worker holds a separate version of hidet (as you described above), over an option for the comp server to support only one version of hidet. Right?

@yaoyaoding
Member

> I recently asked both GPT-4 and Sonnet to change the DLL loading from ctypes.CDLL to cffi; both attempts failed even after my manual fixes :-D
>
> To avoid a misunderstanding: you prefer a pool of workers, where every worker holds a separate version of hidet (as you described above), over an option for the comp server to support only one version of hidet. Right?

Yeah, the complexity of supporting only one version of hidet vs. a pool should be similar, since both have to run in separate processes.

@vadiklyutiy vadiklyutiy force-pushed the vadim/fix-comp-server branch from aa87120 to 50ac5c2 Compare January 12, 2025 05:36
@vadiklyutiy
Collaborator Author

Yaoyao's proposal is implemented in #489.

@vadiklyutiy vadiklyutiy deleted the vadim/fix-comp-server branch January 28, 2025 11:29