diff --git a/Makefile b/Makefile
index bc42e8beaf9..a7487958613 100644
--- a/Makefile
+++ b/Makefile
@@ -87,21 +87,22 @@
 #
 # ==============================================================================
 
-.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-cpu whisper-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
+.PHONY: voxtral-cuda voxtral-cpu voxtral-metal whisper-cuda whisper-debug-cuda whisper-cpu whisper-metal llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
 
 help:
-	@echo "This Makefile adds targets to build runners for various models on various backends. Run using `make <target>`. Available targets:"
-	@echo "  voxtral-cuda  - Build Voxtral runner with CUDA backend"
-	@echo "  voxtral-cpu   - Build Voxtral runner with CPU backend"
-	@echo "  voxtral-metal - Build Voxtral runner with Metal backend (macOS only)"
-	@echo "  whisper-cuda  - Build Whisper runner with CUDA backend"
-	@echo "  whisper-cpu   - Build Whisper runner with CPU backend"
-	@echo "  whisper-metal - Build Whisper runner with Metal backend (macOS only)"
-	@echo "  llama-cpu     - Build Llama runner with CPU backend"
-	@echo "  llava-cpu     - Build Llava runner with CPU backend"
-	@echo "  gemma3-cuda   - Build Gemma3 runner with CUDA backend"
-	@echo "  gemma3-cpu    - Build Gemma3 runner with CPU backend"
-	@echo "  clean         - Clean build artifacts"
+	@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
+	@echo "  voxtral-cuda       - Build Voxtral runner with CUDA backend"
+	@echo "  voxtral-cpu        - Build Voxtral runner with CPU backend"
+	@echo "  voxtral-metal      - Build Voxtral runner with Metal backend (macOS only)"
+	@echo "  whisper-cuda       - Build Whisper runner with CUDA backend"
+	@echo "  whisper-debug-cuda - Build Whisper runner with CUDA backend (debug mode)"
+	@echo "  whisper-cpu        - Build Whisper runner with CPU backend"
+	@echo "  whisper-metal      - Build Whisper runner with Metal backend (macOS only)"
+	@echo "  llama-cpu          - Build Llama runner with CPU backend"
+	@echo "  llava-cpu          - Build Llava runner with CPU backend"
+	@echo "  gemma3-cuda        - Build Gemma3 runner with CUDA backend"
+	@echo "  gemma3-cpu         - Build Gemma3 runner with CPU backend"
+	@echo "  clean              - Clean build artifacts"
 
 voxtral-cuda:
 	@echo "==> Building and installing ExecuTorch with CUDA..."
@@ -139,6 +140,15 @@ whisper-cuda:
 	@echo "✓ Build complete!"
 	@echo "  Binary: cmake-out/examples/models/whisper/whisper_runner"
 
+whisper-debug-cuda:
+	@echo "==> Building and installing ExecuTorch with CUDA (debug mode)..."
+	cmake --workflow --preset llm-debug-cuda
+	@echo "==> Building Whisper runner with CUDA (debug mode)..."
+	cd examples/models/whisper && cmake --workflow --preset whisper-debug-cuda
+	@echo ""
+	@echo "✓ Build complete!"
+	@echo "  Binary: cmake-out/examples/models/whisper/whisper_runner"
+
 whisper-cpu:
 	@echo "==> Building and installing ExecuTorch..."
 	cmake --workflow --preset llm-release
diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h
index 675a9864e74..4dbc7794deb 100644
--- a/backends/aoti/common_shims.h
+++ b/backends/aoti/common_shims.h
@@ -59,11 +59,18 @@ aoti_torch_get_device_index(Tensor* tensor, int32_t* ret_device_index);
 AOTI_SHIM_EXPORT AOTITorchError
 aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
 
+// Device type query (optional; backends may implement this to return device
+// type constants such as CUDA). Declared here so common shims can dispatch
+// appropriately.
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_get_device_type(Tensor* tensor, int32_t* ret_device_type);
+
 // Utility functions for device and layout information
 AOTI_SHIM_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_SHIM_EXPORT int32_t aoti_torch_layout_strided();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16();
+AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16();
 AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32();
diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index aaaf3913381..00b7ccdf611 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -24,6 +24,7 @@ namespace executorch::backends::cuda {
 
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_dtype_bool;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::aoti_torch_get_sizes;
@@ -682,6 +683,50 @@ AOTITorchError aoti_torch__reinterpret_tensor(
   return Error::Ok;
 }
 
+// item implementation for scalar tensors
+AOTITorchError aoti_torch_item_bool(Tensor* tensor, bool* ret_value) {
+  // Validate that tensor is 0D
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor->dim() == 0,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor is not 0D");
+
+  // Validate that tensor dtype is bool
+  int32_t dtype;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(tensor, &dtype));
+
+  ET_CHECK_OR_RETURN_ERROR(
+      dtype == aoti_torch_dtype_bool(), // PyTorch dtype code for bool
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor dtype is not bool");
+
+  // Get the data pointer
+  const void* data_ptr = tensor->const_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "aoti_torch_item_bool failed: tensor data pointer is null");
+
+  // Check if tensor is on CUDA or CPU
+  cudaPointerAttributes attributes{};
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&attributes, data_ptr));
+
+  if (attributes.type == cudaMemoryTypeDevice) {
+    // CUDA memory case: copy from device to host
+    bool device_value;
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+        &device_value, data_ptr, sizeof(bool), cudaMemcpyDeviceToHost));
+    *ret_value = device_value;
+  } else {
+    // CPU memory case: direct access
+    const bool* bool_ptr = static_cast<const bool*>(data_ptr);
+    *ret_value = *bool_ptr;
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch_new_tensor_handle(
     Tensor* orig_handle,
     Tensor** new_handle) {
@@ -771,6 +816,75 @@ AOTITorchError aoti_torch_new_tensor_handle(
   return Error::Ok;
 }
 
+
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
+  if (src == nullptr || ret_dst == nullptr) {
+    return Error::InvalidArgument;
+  }
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = src->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map, if not return error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  int32_t dtype = 0;
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(src, &dtype));
+
+  std::vector<SizesType> sizes;
+  std::vector<StridesType> strides;
+
+  int64_t* view_sizes_ptr;
+  if (aoti_torch_get_sizes(src, &view_sizes_ptr) != Error::Ok) {
+    return Error::Internal;
+  }
+  int64_t* view_strides_ptr;
+  if (aoti_torch_get_strides(src, &view_strides_ptr) != Error::Ok) {
+    return Error::Internal;
+  }
+  sizes = convert_sizes_to_vector(src->dim(), view_sizes_ptr);
+  strides =
+      convert_strides_to_vector(src->dim(), view_sizes_ptr, view_strides_ptr);
+
+  // Create a new tensor view that reinterprets the same memory with different
+  // shape/strides. This creates a view, not a copy - the data pointer is
+  // shared. Uses the CUDA-specific tensor maker that supports non-contiguous
+  // tensors.
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes, // New sizes with explicit SizesType
+      data_ptr, // Reuse the same memory from source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides, // New strides with explicit StridesType
+      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
+  );
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr,
+      InvalidArgument,
+      "Failed to create reinterpreted tensor view");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_dst = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
 } // extern "C"
 
 } // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index 1a89d8b782c..ed408c10897 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -161,9 +161,30 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
  * @return Error::Ok on success, appropriate error code on failure:
  *         - Error::InvalidArgument: null pointers or invalid parameters
  */
-AOTITorchError aoti_torch_new_tensor_handle(
-    Tensor* orig_handle,
-    Tensor** new_handle);
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
+
+// Function to retrieve boolean value from a 0D boolean tensor
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_item_bool(Tensor* tensor, bool* ret_value);
+
+/**
+ * Reinterprets the destination tensor to be a view of the source tensor's
+ * data.
+ *
+ * This function makes the destination tensor a view of the source tensor's
+ * underlying data, but with the destination tensor's original shape and
+ * strides. The number of elements in both tensors must match. The destination
+ * tensor handle will be updated to point to a new tensor view.
+ *
+ * @param src The source tensor providing the data.
+ * @param ret_dst On input, a pointer to the destination tensor. On output,
+ *                this will be updated to point to the new tensor view.
+ *
+ * @return Error::Ok on success, appropriate error code on failure.
+ */
+AOTI_SHIM_EXPORT AOTITorchError
+aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);
 
 // Function to clear all tensors from internal storage
 AOTI_SHIM_EXPORT void clear_all_tensors();
diff --git a/examples/models/whisper/CMakePresets.json b/examples/models/whisper/CMakePresets.json
index a081ad5be29..a7644d14c97 100644
--- a/examples/models/whisper/CMakePresets.json
+++ b/examples/models/whisper/CMakePresets.json
@@ -28,6 +28,20 @@
         "rhs": "Linux"
       }
     },
+    {
+      "name": "whisper-debug-cuda",
+      "displayName": "Whisper runner (CUDA Debug)",
+      "inherits": ["whisper-base"],
+      "cacheVariables": {
+        "EXECUTORCH_BUILD_CUDA": "ON",
+        "CMAKE_BUILD_TYPE": "Debug"
+      },
+      "condition": {
+        "lhs": "${hostSystemName}",
+        "type": "equals",
+        "rhs": "Linux"
+      }
+    },
     {
       "name": "whisper-metal",
       "displayName": "Whisper runner (Metal)",
@@ -55,6 +69,12 @@
       "configurePreset": "whisper-cuda",
       "targets": ["whisper_runner"]
     },
+    {
+      "name": "whisper-debug-cuda",
+      "displayName": "Build Whisper runner (CUDA Debug)",
+      "configurePreset": "whisper-debug-cuda",
+      "targets": ["whisper_runner"]
+    },
     {
       "name": "whisper-metal",
       "displayName": "Build Whisper runner (Metal)",
@@ -91,6 +111,20 @@
         }
       ]
     },
+    {
+      "name": "whisper-debug-cuda",
+      "displayName": "Configure and build Whisper runner (CUDA Debug)",
+      "steps": [
+        {
+          "type": "configure",
+          "name": "whisper-debug-cuda"
+        },
+        {
+          "type": "build",
+          "name": "whisper-debug-cuda"
+        }
+      ]
+    },
     {
       "name": "whisper-metal",
       "displayName": "Configure and build Whisper runner (Metal)",
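
For orientation, here is a rough host-side usage sketch of the two new shims. It is illustrative only and not part of the patch: it assumes the same includes and type aliases (`Tensor`, `Error`, `AOTITorchError`) that `memory.cpp` above already pulls in, and the helper names below are hypothetical stand-ins for however a runner obtains its tensors.

```cpp
// Hypothetical callers, written as if they lived next to the shims in
// backends/cuda/runtime/shims (same includes/usings as memory.cpp above).

// Read a 0-D bool tensor, e.g. a "generation finished" flag produced by an
// AOTI-compiled graph. aoti_torch_item_bool copies the scalar to host with a
// blocking cudaMemcpy when the data is GPU-resident, and dereferences it
// directly when it lives on the CPU.
bool read_done_flag(Tensor* done_flag) {
  bool done = false;
  if (aoti_torch_item_bool(done_flag, &done) != Error::Ok) {
    return false;  // a real caller should propagate the error instead
  }
  return done;
}

// Publish `result` through an output handle slot. The new shim creates a
// fresh tensor view over result's storage, registers it in the internal
// `tensors` set, and bumps the reference count for that storage unless the
// storage is marked NOT_OWN.
AOTITorchError publish_result(Tensor* result, Tensor** out_slot) {
  return aoti_torch_assign_tensors_out(result, out_slot);
}
```

Since the device branch of `aoti_torch_item_bool` issues a synchronous device-to-host copy, polling a GPU-resident flag in a tight loop will stall the host on every read.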