-
Notifications
You must be signed in to change notification settings - Fork 164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor: Refactor JIT and AOT build script #567
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still tend to separate the compilation of sm_90a
sources codes with others.
Separated out |
QQ What are the benefits of making these changes? |
How current distribution is brokenBased on the current main ( AOT
JIT sources (FlashInfer headers, cutlass headers) is missing from wheel. JITJIT sdist does have the JIT headers, but it's not in the package folder. When # ls -la /root/miniforge3/envs/flashinfer-main/lib/python3.10/site-packages/flashinfer/
total 332
drwxr-xr-x 5 root root 4096 Oct 30 03:41 .
drwxr-xr-x 52 root root 4096 Oct 30 03:41 ..
-rw-r--r-- 1 root root 3447 Oct 30 03:41 __init__.py
drwxr-xr-x 2 root root 4096 Oct 30 03:41 __pycache__
-rw-r--r-- 1 root root 22 Oct 30 03:41 _build_meta.py
-rw-r--r-- 1 root root 5777 Oct 30 03:41 activation.py
-rw-r--r-- 1 root root 39177 Oct 30 03:41 cascade.py
-rw-r--r-- 1 root root 39614 Oct 30 03:41 decode.py
-rw-r--r-- 1 root root 15113 Oct 30 03:41 gemm.py
drwxr-xr-x 3 root root 4096 Oct 30 03:41 jit
-rw-r--r-- 1 root root 5336 Oct 30 03:41 norm.py
-rw-r--r-- 1 root root 6684 Oct 30 03:41 page.py
-rw-r--r-- 1 root root 69984 Oct 30 03:41 prefill.py
-rw-r--r-- 1 root root 0 Oct 30 03:41 py.typed
-rw-r--r-- 1 root root 4739 Oct 30 03:41 quantization.py
-rw-r--r-- 1 root root 29416 Oct 30 03:41 rope.py
-rw-r--r-- 1 root root 44031 Oct 30 03:41 sampling.py
-rw-r--r-- 1 root root 19672 Oct 30 03:41 sparse.py
drwxr-xr-x 3 root root 4096 Oct 30 03:41 triton
-rw-r--r-- 1 root root 7176 Oct 30 03:41 utils.py Notably, it's also not in the parent directory. # ls -la /root/miniforge3/envs/flashinfer-main/lib/python3.10/site-packages/
total 344
drwxr-xr-x 52 root root 4096 Oct 30 03:41 .
drwxr-xr-x 36 root root 12288 Oct 30 03:38 ..
drwxr-xr-x 2 root root 4096 Oct 30 03:40 Jinja2-3.1.3.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 MarkupSafe-2.1.5.dist-info
-rw-rw-r-- 3 root root 119 Oct 16 01:41 README.txt
drwxr-xr-x 2 root root 4096 Oct 30 03:39 __pycache__
drwxr-xr-x 3 root root 4096 Oct 30 03:38 _distutils_hack
-rw-rw-r-- 3 root root 151 Sep 25 07:31 distutils-precedence.pth
drwxr-xr-x 3 root root 4096 Oct 30 03:39 filelock
drwxr-xr-x 3 root root 4096 Oct 30 03:39 filelock-3.13.1.dist-info
drwxr-xr-x 5 root root 4096 Oct 30 03:41 flashinfer
drwxr-xr-x 2 root root 4096 Oct 30 03:41 flashinfer-0.1.6.dist-info
drwxr-xr-x 5 root root 4096 Oct 30 03:39 fsspec
drwxr-xr-x 2 root root 4096 Oct 30 03:39 fsspec-2024.2.0.dist-info
drwxr-xr-x 8 root root 4096 Oct 30 03:40 functorch
-rw-r--r-- 1 root root 11207 Oct 30 03:39 isympy.py
drwxr-xr-x 3 root root 4096 Oct 30 03:40 jinja2
drwxr-xr-x 3 root root 4096 Oct 30 03:39 markupsafe
drwxr-xr-x 8 root root 4096 Oct 30 03:39 mpmath
drwxr-xr-x 2 root root 4096 Oct 30 03:39 mpmath-1.3.0.dist-info
drwxr-xr-x 11 root root 4096 Oct 30 03:39 networkx
drwxr-xr-x 2 root root 4096 Oct 30 03:39 networkx-3.2.1.dist-info
drwxr-xr-x 4 root root 4096 Oct 30 03:40 ninja
drwxr-xr-x 2 root root 4096 Oct 30 03:40 ninja-1.11.1.1.dist-info
drwxr-xr-x 25 root root 4096 Oct 30 03:40 numpy
drwxr-xr-x 2 root root 4096 Oct 30 03:40 numpy-2.1.2.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:40 numpy.libs
drwxr-xr-x 15 root root 4096 Oct 30 03:40 nvidia
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_cublas_cu12-12.4.5.8.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_cuda_cupti_cu12-12.4.127.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_cuda_nvrtc_cu12-12.4.127.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_cuda_runtime_cu12-12.4.127.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:40 nvidia_cudnn_cu12-9.1.0.70.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_cufft_cu12-11.2.1.3.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_curand_cu12-10.3.5.147.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:40 nvidia_cusolver_cu12-11.6.1.9.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_cusparse_cu12-12.3.1.170.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_nccl_cu12-2.21.5.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_nvjitlink_cu12-12.4.127.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 nvidia_nvtx_cu12-12.4.127.dist-info
drwxr-xr-x 5 root root 4096 Oct 30 03:38 pip
drwxr-xr-x 2 root root 4096 Oct 30 03:38 pip-24.3.1.dist-info
drwxr-xr-x 4 root root 4096 Oct 30 03:38 pkg_resources
drwxr-xr-x 9 root root 4096 Oct 30 03:38 setuptools
drwxr-xr-x 2 root root 4096 Oct 30 03:38 setuptools-75.1.0-py3.12.egg-info
drwxr-xr-x 43 root root 4096 Oct 30 03:39 sympy
drwxr-xr-x 2 root root 4096 Oct 30 03:39 sympy-1.13.1.dist-info
drwxr-xr-x 62 root root 4096 Oct 30 03:40 torch
drwxr-xr-x 2 root root 4096 Oct 30 03:40 torch-2.5.1+cu124.dist-info
drwxr-xr-x 11 root root 4096 Oct 30 03:40 torchgen
drwxr-xr-x 11 root root 4096 Oct 30 03:39 triton
drwxr-xr-x 2 root root 4096 Oct 30 03:39 triton-3.1.0.dist-info
drwxr-xr-x 2 root root 4096 Oct 30 03:39 typing_extensions-4.9.0.dist-info
-rw-r--r-- 1 root root 110125 Oct 30 03:39 typing_extensions.py
drwxr-xr-x 5 root root 4096 Oct 30 03:38 wheel
drwxr-xr-x 2 root root 4096 Oct 30 03:38 wheel-0.44.0.dist-info |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the great work, now I think the structure looks more clear.
My last minute change in #567 changed `link_data_files()` to a context manager. I didn't properly test it. It would fail due to `this_dir / "flashinfer" / "data" / "version.txt"` not exist. This PR fixes the issue. I tested that it works, with `pip install -v -e .`. BTW, AOT wheel does not have this issue. I'm able to build AOT wheel.
Previously, JIT and AOT packaging is a bit broken. This PR produces good sdist for JIT mode, and wheel for AOT mode.
Changes
Common changes:
python/flashinfer/data/
, i.e. inside the python package folder. This is strongly recommended by setuptools.version.txt
, FlashInfer headers, Cutlass headers.develop
command.AOT changes:
flashinfer-aot
dir. Contents are moved topython/
._kernels_sm90
is preserved as a separated.so
file.)JIT changes:
Directory structure of built package
See attached.
dir-wheel.txt
dir-sdist.txt
Tests
I was able to pass
pytest -sv test_norm.py test_bmm_fp8.py
using various way of installation: