Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Refactor JIT and AOT build script #567

Merged
merged 9 commits into from
Oct 30, 2024

Conversation

abcdabcd987
Copy link
Member

@abcdabcd987 abcdabcd987 commented Oct 29, 2024

Previously, JIT and AOT packaging is a bit broken. This PR produces good sdist for JIT mode, and wheel for AOT mode.

Changes

Common changes:

  1. Remove the symlinks. Symlinks causes lots of duplication when search in VSCode.
  2. In package distribution (sdist or wheel), add data files to python/flashinfer/data/, i.e. inside the python package folder. This is strongly recommended by setuptools.
    • Data files include: version.txt, FlashInfer headers, Cutlass headers.
    • Symlinks will be created when building wheel, and will be removed when finished unless it's using develop command.
  3. Exclude unneeded cutlass docs and files from wheel and sdist.

AOT changes:

  1. Remove flashinfer-aot dir. Contents are moved to python/.
  2. Merge all kernels into one pybind. This is good for compilation speed. (_kernels_sm90 is preserved as a separated .so file.)
  3. AOT wheel can now be built with the following command:
    cd flashinfer/python
    TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py bdist_wheel
    ls -la dist/
  4. AOT wheel can also be built for editable install (develop purpose)
    cd flashinfer/python
    TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" python3 aot_setup.py develop

JIT changes:

  1. JIT mode can now be installed in various ways:
    cd flashinfer/python
    pip install -v .     # Regular install from source
    pip install -v -e .  # Editable install
    
    python -m build --sdist               # Build sdist
    pip install dist/flashinfer-*.tar.gz  # Install from sdist

Directory structure of built package

See attached.
dir-wheel.txt
dir-sdist.txt

Tests

I was able to pass pytest -sv test_norm.py test_bmm_fp8.py using various way of installation:

  1. Editable install
  2. Regular install from source
  3. Install from sdist
  4. Install from wheel

@abcdabcd987 abcdabcd987 requested a review from yzh119 October 29, 2024 19:41
Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still tend to separate the compilation of sm_90a sources codes with others.

@abcdabcd987
Copy link
Member Author

Separated out _kernels_sm90

@zhyncs
Copy link
Member

zhyncs commented Oct 30, 2024

QQ What are the benefits of making these changes?

@abcdabcd987
Copy link
Member Author

What are the benefits of making these changes?

  1. Fix the broken wheel and sdist (I'll explain this later).
  2. Better directory structure -- No more symlink floating everywhere. This makes vscode work much better.
  3. Faster AOT compilation -- Merging _kernels, _decode_kernels, and _prefill_kernels allows parallel compilation for all kernels. Otherwise, these three PyTorch extensions will compile one by one. Also, this reduces the pybind11 compilation times from three times to once.

How current distribution is broken

Based on the current main (3e104bc)

AOT

Archive:  dist/flashinfer-0.1.6-cp310-cp310-linux_x86_64.whl
  Length      Date    Time    Name
---------  ---------- -----   ----
     3447  2024-10-30 03:26   flashinfer/__init__.py
      145  2024-10-30 03:27   flashinfer/_build_meta.py
 50296136  2024-10-30 03:31   flashinfer/_decode_kernels.cpython-310-x86_64-linux-gnu.so
  2444160  2024-10-30 03:28   flashinfer/_kernels.cpython-310-x86_64-linux-gnu.so
   511384  2024-10-30 03:30   flashinfer/_kernels_sm90.cpython-310-x86_64-linux-gnu.so
240830648  2024-10-30 03:32   flashinfer/_prefill_kernels.cpython-310-x86_64-linux-gnu.so
     5777  2024-10-30 03:26   flashinfer/activation.py
    39177  2024-10-30 03:26   flashinfer/cascade.py
    39614  2024-10-30 03:26   flashinfer/decode.py
    15113  2024-10-30 03:26   flashinfer/gemm.py
     5336  2024-10-30 03:26   flashinfer/norm.py
     6684  2024-10-30 03:26   flashinfer/page.py
    69984  2024-10-30 03:26   flashinfer/prefill.py
        0  2024-10-30 03:26   flashinfer/py.typed
     4739  2024-10-30 03:26   flashinfer/quantization.py
    29416  2024-10-30 03:26   flashinfer/rope.py
    44031  2024-10-30 03:26   flashinfer/sampling.py
    19672  2024-10-30 03:26   flashinfer/sparse.py
     7176  2024-10-30 03:26   flashinfer/utils.py
     5101  2024-10-30 03:26   flashinfer/jit/__init__.py
     2380  2024-10-30 03:26   flashinfer/jit/activation.py
    91088  2024-10-30 03:27   flashinfer/jit/aot_config.py
    13531  2024-10-30 03:26   flashinfer/jit/attention.py
     7114  2024-10-30 03:26   flashinfer/jit/batch_decode_templ.py
    12751  2024-10-30 03:26   flashinfer/jit/batch_prefill_templ.py
     1450  2024-10-30 03:26   flashinfer/jit/env.py
     7650  2024-10-30 03:26   flashinfer/jit/single_decode_templ.py
     8102  2024-10-30 03:26   flashinfer/jit/single_prefill_templ.py
     1708  2024-10-30 03:26   flashinfer/jit/utils.py
       22  2024-10-30 03:26   flashinfer/triton/__init__.py
     3870  2024-10-30 03:26   flashinfer/triton/cascade.py
      784  2024-10-30 03:26   flashinfer/triton/utils.py
      237  2024-10-30 03:32   flashinfer-0.1.6.dist-info/METADATA
      104  2024-10-30 03:32   flashinfer-0.1.6.dist-info/WHEEL
       11  2024-10-30 03:32   flashinfer-0.1.6.dist-info/top_level.txt
     3070  2024-10-30 03:33   flashinfer-0.1.6.dist-info/RECORD
---------                     -------
294531612                     36 files

JIT sources (FlashInfer headers, cutlass headers) is missing from wheel.

JIT

dir-3e104bc-sdist.txt

JIT sdist does have the JIT headers, but it's not in the package folder. When pip install the sdist, the headers are not installed.

# ls -la /root/miniforge3/envs/flashinfer-main/lib/python3.10/site-packages/flashinfer/
total 332
drwxr-xr-x  5 root root  4096 Oct 30 03:41 .
drwxr-xr-x 52 root root  4096 Oct 30 03:41 ..
-rw-r--r--  1 root root  3447 Oct 30 03:41 __init__.py
drwxr-xr-x  2 root root  4096 Oct 30 03:41 __pycache__
-rw-r--r--  1 root root    22 Oct 30 03:41 _build_meta.py
-rw-r--r--  1 root root  5777 Oct 30 03:41 activation.py
-rw-r--r--  1 root root 39177 Oct 30 03:41 cascade.py
-rw-r--r--  1 root root 39614 Oct 30 03:41 decode.py
-rw-r--r--  1 root root 15113 Oct 30 03:41 gemm.py
drwxr-xr-x  3 root root  4096 Oct 30 03:41 jit
-rw-r--r--  1 root root  5336 Oct 30 03:41 norm.py
-rw-r--r--  1 root root  6684 Oct 30 03:41 page.py
-rw-r--r--  1 root root 69984 Oct 30 03:41 prefill.py
-rw-r--r--  1 root root     0 Oct 30 03:41 py.typed
-rw-r--r--  1 root root  4739 Oct 30 03:41 quantization.py
-rw-r--r--  1 root root 29416 Oct 30 03:41 rope.py
-rw-r--r--  1 root root 44031 Oct 30 03:41 sampling.py
-rw-r--r--  1 root root 19672 Oct 30 03:41 sparse.py
drwxr-xr-x  3 root root  4096 Oct 30 03:41 triton
-rw-r--r--  1 root root  7176 Oct 30 03:41 utils.py

Notably, it's also not in the parent directory.

# ls -la /root/miniforge3/envs/flashinfer-main/lib/python3.10/site-packages/
total 344
drwxr-xr-x 52 root root   4096 Oct 30 03:41 .
drwxr-xr-x 36 root root  12288 Oct 30 03:38 ..
drwxr-xr-x  2 root root   4096 Oct 30 03:40 Jinja2-3.1.3.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 MarkupSafe-2.1.5.dist-info
-rw-rw-r--  3 root root    119 Oct 16 01:41 README.txt
drwxr-xr-x  2 root root   4096 Oct 30 03:39 __pycache__
drwxr-xr-x  3 root root   4096 Oct 30 03:38 _distutils_hack
-rw-rw-r--  3 root root    151 Sep 25 07:31 distutils-precedence.pth
drwxr-xr-x  3 root root   4096 Oct 30 03:39 filelock
drwxr-xr-x  3 root root   4096 Oct 30 03:39 filelock-3.13.1.dist-info
drwxr-xr-x  5 root root   4096 Oct 30 03:41 flashinfer
drwxr-xr-x  2 root root   4096 Oct 30 03:41 flashinfer-0.1.6.dist-info
drwxr-xr-x  5 root root   4096 Oct 30 03:39 fsspec
drwxr-xr-x  2 root root   4096 Oct 30 03:39 fsspec-2024.2.0.dist-info
drwxr-xr-x  8 root root   4096 Oct 30 03:40 functorch
-rw-r--r--  1 root root  11207 Oct 30 03:39 isympy.py
drwxr-xr-x  3 root root   4096 Oct 30 03:40 jinja2
drwxr-xr-x  3 root root   4096 Oct 30 03:39 markupsafe
drwxr-xr-x  8 root root   4096 Oct 30 03:39 mpmath
drwxr-xr-x  2 root root   4096 Oct 30 03:39 mpmath-1.3.0.dist-info
drwxr-xr-x 11 root root   4096 Oct 30 03:39 networkx
drwxr-xr-x  2 root root   4096 Oct 30 03:39 networkx-3.2.1.dist-info
drwxr-xr-x  4 root root   4096 Oct 30 03:40 ninja
drwxr-xr-x  2 root root   4096 Oct 30 03:40 ninja-1.11.1.1.dist-info
drwxr-xr-x 25 root root   4096 Oct 30 03:40 numpy
drwxr-xr-x  2 root root   4096 Oct 30 03:40 numpy-2.1.2.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:40 numpy.libs
drwxr-xr-x 15 root root   4096 Oct 30 03:40 nvidia
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_cublas_cu12-12.4.5.8.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_cuda_cupti_cu12-12.4.127.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_cuda_nvrtc_cu12-12.4.127.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_cuda_runtime_cu12-12.4.127.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:40 nvidia_cudnn_cu12-9.1.0.70.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_cufft_cu12-11.2.1.3.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_curand_cu12-10.3.5.147.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:40 nvidia_cusolver_cu12-11.6.1.9.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_cusparse_cu12-12.3.1.170.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_nccl_cu12-2.21.5.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_nvjitlink_cu12-12.4.127.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 nvidia_nvtx_cu12-12.4.127.dist-info
drwxr-xr-x  5 root root   4096 Oct 30 03:38 pip
drwxr-xr-x  2 root root   4096 Oct 30 03:38 pip-24.3.1.dist-info
drwxr-xr-x  4 root root   4096 Oct 30 03:38 pkg_resources
drwxr-xr-x  9 root root   4096 Oct 30 03:38 setuptools
drwxr-xr-x  2 root root   4096 Oct 30 03:38 setuptools-75.1.0-py3.12.egg-info
drwxr-xr-x 43 root root   4096 Oct 30 03:39 sympy
drwxr-xr-x  2 root root   4096 Oct 30 03:39 sympy-1.13.1.dist-info
drwxr-xr-x 62 root root   4096 Oct 30 03:40 torch
drwxr-xr-x  2 root root   4096 Oct 30 03:40 torch-2.5.1+cu124.dist-info
drwxr-xr-x 11 root root   4096 Oct 30 03:40 torchgen
drwxr-xr-x 11 root root   4096 Oct 30 03:39 triton
drwxr-xr-x  2 root root   4096 Oct 30 03:39 triton-3.1.0.dist-info
drwxr-xr-x  2 root root   4096 Oct 30 03:39 typing_extensions-4.9.0.dist-info
-rw-r--r--  1 root root 110125 Oct 30 03:39 typing_extensions.py
drwxr-xr-x  5 root root   4096 Oct 30 03:38 wheel
drwxr-xr-x  2 root root   4096 Oct 30 03:38 wheel-0.44.0.dist-info

Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the great work, now I think the structure looks more clear.

@yzh119 yzh119 changed the title Refactor JIT and AOT build script refactor: Refactor JIT and AOT build script Oct 30, 2024
@yzh119 yzh119 merged commit 7df90dd into flashinfer-ai:main Oct 30, 2024
tsu-bin added a commit to tsu-bin/flashinfer_dev that referenced this pull request Oct 30, 2024
zhyncs pushed a commit that referenced this pull request Oct 30, 2024
Hi, cpp integration was broken again by #567, please be aware that there
are cpp test, cpp benchmark and also tvm integration, they all relay on
cmake build.

Co-authored-by: tsu-bin <tsubin@gmail.com>
yzh119 pushed a commit that referenced this pull request Nov 2, 2024
My last minute change in #567 changed `link_data_files()` to a context
manager. I didn't properly test it. It would fail due to `this_dir /
"flashinfer" / "data" / "version.txt"` not exist.

This PR fixes the issue. I tested that it works, with `pip install -v -e
.`.

BTW, AOT wheel does not have this issue. I'm able to build AOT wheel.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants