Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/optimum-executorch.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d03e90c2cd9048e6d9a75285c0355f033cd016fc
8967fe914c252bf242b7d0ad4f5e098a007a6993
2 changes: 1 addition & 1 deletion .ci/scripts/test_model_e2e.sh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ case "$HF_MODEL" in
esac

echo "::group::Setup ExecuTorch Requirements"
./install_requirements.sh
# ./install_requirements.sh
pip list
echo "::endgroup::"

Expand Down
8 changes: 5 additions & 3 deletions backends/aoti/aoti_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def preprocess(
with open(so_path, "rb") as f:
so_data = f.read()

print("so_path: ", so_path)

# Read weights blob
with open(blob_path, "rb") as f:
blob_data = f.read()
Expand All @@ -229,9 +231,9 @@ def preprocess(
method_name + "_weights_blob", blob_data, 1, weights_blob_data_type
)

# Clean up the generated files
os.remove(so_path)
os.remove(blob_path)
# # Clean up the generated files
# os.remove(so_path)
# os.remove(blob_path)

return PreprocessResult(
processed_bytes=b"",
Expand Down
7 changes: 4 additions & 3 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ install(
set(_aoti_cuda_shim_sources
runtime/shims/memory.cpp runtime/shims/tensor_attribute.cpp
runtime/guard.cpp runtime/shims/cuda_guard.cpp runtime/shims/int4mm.cu
runtime/shims/sdpa.cu
${EXECUTORCH_ROOT}/backends/aoti/common_shims.cpp
)

Expand Down Expand Up @@ -130,12 +131,12 @@ target_link_options(
aoti_cuda_shims PUBLIC $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
)

# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and
# Link against CUDA::cudart, CUDA::cublas, common AOTI library, cuda_tensor_maker, and
# platform utilities
target_link_libraries(
aoti_cuda_shims
PRIVATE cuda_platform
PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS}
PRIVATE cuda_platform executorch_core
PUBLIC extension_tensor cuda_tensor_maker CUDA::cudart CUDA::cublas ${CMAKE_DL_LIBS}
)

if(NOT MSVC)
Expand Down
39 changes: 21 additions & 18 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def get_device_name(cls) -> str:
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
return {
"at::_ops::_weight_int4pack_mm::call": None,
"at::_ops::_scaled_dot_product_flash_attention::call": None,
"at::_ops::_scaled_dot_product_efficient_attention::call": None,
}

@classmethod
Expand Down Expand Up @@ -68,7 +70,8 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
)
triton_kernel_mode = mode

return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
return []
# return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []

@classmethod
def get_aoti_compile_options(
Expand Down Expand Up @@ -134,20 +137,20 @@ def get_aoti_compile_options(

return options

@classmethod
def get_extra_aoti_compile_context_manager(cls):
"""
Return SDPA MATH backend context manager for CUDA compilation.

This context manager plays as a fallback solution for any remaining PyTorch SDPA
operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.

Note:
- If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
this context manager will have no effect on those ops (they are no longer
PyTorch SDPA ops).
- If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
context manager will force them to use the MATH backend, causing them to
be automatically decomposed during compilation.
"""
return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])
# @classmethod
# def get_extra_aoti_compile_context_manager(cls):
# """
# Return SDPA MATH backend context manager for CUDA compilation.

# This context manager plays as a fallback solution for any remaining PyTorch SDPA
# operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.

# Note:
# - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
# this context manager will have no effect on those ops (they are no longer
# PyTorch SDPA ops).
# - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
# context manager will force them to use the MATH backend, causing them to
# be automatically decomposed during compilation.
# """
# return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])
4 changes: 4 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ runtime.cxx_library(
"shims/cuda_guard.cpp",
"shims/int4mm.cu",
"shims/memory.cpp",
"shims/sdpa.cu",
"shims/tensor_attribute.cpp",
],
headers = [
Expand All @@ -61,6 +62,8 @@ runtime.cxx_library(
"shims/int4mm.cuh",
"shims/int4mm.h",
"shims/memory.h",
"shims/sdpa.cuh",
"shims/sdpa.h",
"shims/tensor_attribute.h",
"utils.h",
],
Expand All @@ -84,6 +87,7 @@ runtime.cxx_library(
],
external_deps = [
("cuda", None, "cuda-lazy"),
("cuda", None, "cublas-lazy"),
],
)

Expand Down
Loading
Loading