
[Research][PyTorch 2.6] Save compiled triton kernel as device binary code #1792

Open
vlad-penkin opened this issue Aug 7, 2024 · 6 comments · Fixed by #2148 · May be fixed by #2350

@vlad-penkin
Contributor

There is a plan to enable AOT Inductor for Intel GPU in PyTorch 2.6. While working on the design, the PyTorch team realized that the Triton kernel is currently saved as SPIR-V (IR), while for CUDA it is saved as a cubin (device binary), which will affect end-to-end performance:

The current deployment implementation requires loading the SPIR-V and then compiling it into device binary code through IGC. This causes more compilation time compared to CUDA when the kernel is run in the deployment environment.

The PyTorch team is asking whether Triton can save the compiled kernel as device binary code and load it with the L0 (Level Zero) runtime.

@etaf

etaf commented Aug 8, 2024

It seems we can retrieve the native binary from a Level Zero module using zeModuleGetNativeBinary.

@alexbaden
Contributor

Yes, L0 has APIs we should be able to leverage, assuming they work.
The difference is in the way the compiler/driver works. For NVIDIA, they can call ptxas to assemble the PTX into a cubin and then pass the cubin to their runtime. For us, we actually compile the SPIR-V to machine code during the driver stage. So I need to either lift the compilation from SPIR-V to native binary out of the driver and into the compiler, or find a way to get the paths to the driver without breaking Triton's layering.

@alexbaden
Contributor

I wanted to look into this to see if it could be related to #1721, but the numbers don't quite match, so I suppose I am not optimistic. Still, this could be a nice win for us, as compilation can take 100-300 ms, especially if there are register spills and we recompile.

@etaf

etaf commented Aug 8, 2024

I think this may be the solution:

We can retrieve the native binary from the Level Zero module using zeModuleGetNativeBinary.
Here is an example: https://github.com/oneapi-src/oneDNN/blob/2e7b691217ff17497aebd7e565fa1701f8a42396/src/gpu/intel/sycl/utils.cpp#L211
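A minimal sketch of that retrieval step (the helper name is illustrative and error checking is omitted; the module is assumed to have already been built from SPIR-V by the existing driver path):

```cpp
#include <level_zero/ze_api.h>
#include <vector>

// Query the required size first, then copy the device binary out of the module.
std::vector<uint8_t> get_native_binary(ze_module_handle_t module) {
  size_t size = 0;
  // First call with a null buffer reports the size of the native binary.
  zeModuleGetNativeBinary(module, &size, nullptr);
  std::vector<uint8_t> binary(size);
  // Second call fills the buffer with the device binary, which can then be cached to disk.
  zeModuleGetNativeBinary(module, &size, binary.data());
  return binary;
}
```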

Then, to reconstruct the L0 module in deployment, create the Level Zero module with the ze_module_format_t set to ZE_MODULE_FORMAT_NATIVE in zeModuleCreate.
Here is an example: https://github.com/oneapi-src/oneDNN/blob/2e7b691217ff17497aebd7e565fa1701f8a42396/src/gpu/intel/sycl/l0/utils.cpp#L184
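And a minimal sketch of the deployment side (again illustrative; the context and device handles and the error handling are assumed to come from the surrounding runtime):

```cpp
#include <level_zero/ze_api.h>
#include <cstddef>
#include <cstdint>

// Rebuild a Level Zero module from a previously saved native binary,
// skipping the SPIR-V -> device code compilation in IGC.
ze_module_handle_t load_native_binary(ze_context_handle_t context,
                                      ze_device_handle_t device,
                                      const uint8_t *binary, size_t size) {
  ze_module_desc_t desc = {};
  desc.stype = ZE_STRUCTURE_TYPE_MODULE_DESC;
  desc.format = ZE_MODULE_FORMAT_NATIVE; // input is already device code
  desc.inputSize = size;
  desc.pInputModule = binary;

  ze_module_handle_t module = nullptr;
  zeModuleCreate(context, device, &desc, &module, /*build log*/ nullptr);
  return module;
}
```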

@alexbaden
Contributor

I have a prototype working. The Level Zero APIs are the easy part; we have to make significant changes to our Triton compilation flow to fit this into Triton's architecture. Fortunately, I think I can adjust the compilation flow while preserving the existing Triton layering. I will clean up my prototype and post it as a draft PR for review tomorrow.
