Add stable parallel_for #161320
base: gh/mikaylagawarecki/337/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161320
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit 19ab872 with merge base f06e669.
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! One of PyTorch's C-stable API files was changed. You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version. Caused by:
The current state of the world is that there are two implementations of torch's parallel interface: OpenMP and ParallelNative. `INTRA_OP_PARALLEL` (which gates whether the "parallel logic" in `parallel_for` is used, see below) is defined if:

1. `AT_PARALLEL_OPENMP=1` (at libtorch build time, per the generated ATen/Config.h) and `_OPENMP` is defined at extension build time (meaning that **both libtorch and the extension compile/link against OpenMP**; for example see https://github.com/pytorch/audio/pull/1761/files and pytorch/vision#2783), or
2. `AT_PARALLEL_OPENMP=0 && AT_PARALLEL_NATIVE=1` (at libtorch build time, per the generated ATen/Config.h)

https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel-inl.h#L9-L43

The approach taken in this PR is to paste the implementation of `parallel_for` from `ATen/Parallel-inl.h` into `torch/csrc/stable/ops.h` with the following modifications. For perf, we want the function passed to `parallel_for` to be inlined all the way into `invoke_parallel`:

- This is possible for the OpenMP implementation, which templates `F` just like `parallel_for` does: https://github.com/pytorch/pytorch/blob/001e1d263746ae9d121d9c8cf55bc87f777d9dba/aten/src/ATen/ParallelOpenMP.h#L14-L53 We paste the implementation of `invoke_parallel` into `torch::stable::internal`, **with the modification that `ThreadIdGuard` (the only non-header-only piece) uses the shimmed version. We do not move it to torch/headeronly due to the ThreadIdGuard limitation.**

  When I compile the call I added in `kernel.cpp`,

  ```cpp
  torch::stable::parallel_for(
      0, size, grain_size, [data_ptr](int64_t begin, int64_t end) {
        for (int64_t i = begin; i < end; i++) {
          int thread_id = aoti_torch_get_thread_num();
          data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32);
        }
      });
  ```

  I can find the following in the objdump, which I think indicates that the function is getting inlined correctly into `invoke_parallel` for this path:

  <img width="1194" height="260" alt="Screenshot 2025-10-08 at 7 03 49 PM" src="https://github.com/user-attachments/assets/32982cfc-8b5f-4765-84db-4aaeb5e77591" />

- This is not possible for the ParallelNative implementation, which
  - takes an `std::function` for `f`
  - is defined in a .cpp (and relies on other non-header-only functions)

  For these two reasons, we shim the ParallelNative version of `invoke_parallel`: https://github.com/pytorch/pytorch/blob/71aefd5595834dd97f38aa978ee32abbd13ac3d6/aten/src/ATen/ParallelNative.cpp#L144-L199

The rest of the APIs are shimmed:

- `at::internal::lazy_init_num_threads()` --> `aoti_torch_lazy_init_num_threads`
  Reason for shimming: the [implementation of `at::internal::lazy_init_num_threads()`](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel.h#L32-L38) calls `at::init_num_threads`, which is not header-only.
- `at::in_parallel_region` --> `aoti_torch_in_parallel_region`
  Reason for shimming: the [OpenMP implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L94-L100) is defined in a .cpp and depends on whether OpenMP is linked against at libtorch build time. The [ParallelNative implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L266-L276) is defined in a .cpp and depends on whether `C10_MOBILE` is defined at libtorch build time.
- `at::get_num_threads` --> `aoti_torch_get_num_threads`
  Reason for shimming: same story as `in_parallel_region` ([`ParallelNative` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L241-L260), [`OpenMP` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L75-L82)).
- `ThreadIdGuard` --> `aoti_torch_create_thread_id_guard` and `aoti_torch_delete_thread_id_guard`, with a C++ wrapper `torch::stable::ThreadIdGuard`
  Reason for shimming: depends on `set_thread_num`, which is not header-only ([ThreadIdGuard impl](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Parallel.h#L42-L50)).
- `c10::ParallelGuard` --> `aoti_torch_create_parallel_guard`, `aoti_torch_delete_parallel_guard`, `aoti_torch_parallel_guard_is_enabled`, with a C++ wrapper `torch::stable::ParallelGuard`
  Reason for shimming: has a .cpp file ([ParallelGuard.cpp](https://github.com/pytorch/pytorch/blob/main/c10/util/ParallelGuard.cpp)).
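To make the "C shim + thin C++ RAII wrapper" pattern above concrete, here is a minimal sketch (mine, not the PR's code) of what a wrapper like `torch::stable::ThreadIdGuard` could look like. The shim names come from the list above, but the handle type, signatures, and error handling are assumptions and will differ from the actual declarations; the same shape would apply to the `ParallelGuard` wrapper.

```cpp
// Sketch only: assumed shim signatures, not the actual declarations in shim.h.
#include <cstdint>

extern "C" {
using ThreadIdGuardHandle = void*;
// Assumed: creates a guard that sets the current thread id until deleted.
ThreadIdGuardHandle aoti_torch_create_thread_id_guard(int64_t thread_id);
void aoti_torch_delete_thread_id_guard(ThreadIdGuardHandle guard);
}

namespace torch::stable {

// RAII wrapper so extension code never touches the raw create/delete pair.
class ThreadIdGuard {
 public:
  explicit ThreadIdGuard(int64_t thread_id)
      : guard_(aoti_torch_create_thread_id_guard(thread_id)) {}
  ~ThreadIdGuard() { aoti_torch_delete_thread_id_guard(guard_); }
  ThreadIdGuard(const ThreadIdGuard&) = delete;
  ThreadIdGuard& operator=(const ThreadIdGuard&) = delete;

 private:
  ThreadIdGuardHandle guard_;
};

} // namespace torch::stable
```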
An aside before I get to reviewing the actual code here: in ExecuTorch I got a significant code size win by using a port of llvm::function_ref instead of std::function, because parallel_for blocks until the computation is done and so doesn't need to own the function it's calling. If we're thinking of stabilizing and committing to these interfaces forever, we should probably first investigate whether function_ref is also appropriate in PyTorch core.
Could you help me understand this better please? My impression was that templating F was desirable because we want to allow the function passed to `parallel_for` to be inlined. I see that smaller code size would be beneficial when binary size is a concern, but I'm not sure that is the goal here (my impression is that an extension using torch/csrc/stable/ops.h would need to depend on libtorch.so, or on a binary that implements a large part of the AOTI C shim, which might not be so lightweight anyway).
This is what I get for commenting based on the description without looking at the code, sorry. I see now that we aren't exposing std::function anywhere, so we needn't worry about committing to it.
test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
// matches the existing semantic.
#ifdef _OPENMP
template <typename F>
inline void invoke_parallel(
OK, so if I understand correctly, this implementation isn't actually a stable ABI to anything, it's just an OpenMP-based invoke_parallel implementation that will get built into clients' binaries. That's fine, so long as there are no concerns about clients using gcc/clang when pytorch uses the other one (they have different OpenMP implementations that they use by default), and the overall process getting linked against two different OpenMP support libraries, and that somehow causing problems. Those problems, to the extent they are real (which I don't know), already exist for customers that want to use OpenMP in their extension anyway; however here we are sort of advertising/pushing OpenMP to them and so we should make sure that that's actually a reasonable thing to do.
Yes I think your understanding is right, the only reason I call it "stable" is because it uses a shimmed-ThreadIdGuard (though perhaps that is a misnomer)
Those problems, to the extent they are real (which I don't know), already exist for customers that want to use OpenMP in their extension anyway
Yes exactly! 😄
however here we are sort of advertising/pushing OpenMP to them and so we should make sure that that's actually a reasonable thing to do.
Makes sense, my intent was mostly for "extension writers to be able to use parallel_for in the same way they used to". But I see what you mean that having this here might incentivize them to try to use OpenMP.
I think this is the first op we've added in torch/csrc/stable that the compiler toolchain issues apply to (cc @janeyx99 is that right?). Would it be reasonable if I:
- add more explicit documentation here + in our user facing docs re the potential issues users might run into if they use different compilers
- (if helpful) put all the openmp related code in some "more private" looking file, rather than exposing it directly in ops.h
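For context on what "built into clients' binaries" means in the thread above: below is a simplified sketch (mine, not the code in this PR) of the shape of an OpenMP-templated `invoke_parallel`. It omits the ThreadIdGuard, exception propagation, and the grain-size/thread-count clamping that the real ATen implementation does; the point is only that `F` stays a template parameter, so the user's lambda can inline into the parallel region.

```cpp
#include <omp.h>
#include <algorithm>
#include <cstdint>

template <typename F>
inline void invoke_parallel_sketch(int64_t begin, int64_t end, const F& f) {
#pragma omp parallel
  {
    // Split [begin, end) evenly across the OpenMP team.
    const int64_t num_threads = omp_get_num_threads();
    const int64_t tid = omp_get_thread_num();
    const int64_t chunk = (end - begin + num_threads - 1) / num_threads;
    const int64_t chunk_begin = begin + tid * chunk;
    const int64_t chunk_end = std::min(chunk_begin + chunk, end);
    if (chunk_begin < chunk_end) {
      f(chunk_begin, chunk_end);  // F is known at compile time, so this can inline
    }
  }
}
```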
// For the ParallelNative path, this helps with converting C++ lambdas
// etc. to a C-style function pointer expected by the C-shim
template <typename F>
struct Trampoline {
Sure, this is basically a minimal llvm::function_ref so I'm fine with it. Note that they moved to intptr_t from void* 11 years ago because some compilers warn on these casts: llvm/llvm-project@36e1295
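To illustrate the trampoline idea discussed in this thread, here is a minimal, self-contained sketch (not the PR's code; `fake_invoke_parallel_shim` is a made-up serial stand-in for the real C shim) showing how a C++ lambda can be passed across a C-style boundary as a function pointer plus a context pointer without being owned:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for the C shim entry point; the real one dispatches onto
// libtorch's intra-op thread pool instead of calling fn serially.
extern "C" void fake_invoke_parallel_shim(
    int64_t begin, int64_t end, void (*fn)(int64_t, int64_t, void*), void* ctx) {
  fn(begin, end, ctx);
}

template <typename F>
struct Trampoline {
  // C-compatible callback: recover the callable from ctx and call it.
  static void call(int64_t begin, int64_t end, void* ctx) {
    (*static_cast<F*>(ctx))(begin, end);
  }
};

template <typename F>
void invoke_via_trampoline(int64_t begin, int64_t end, F& f) {
  // Passing &f is fine because the call blocks until the work is done,
  // the same property that makes a non-owning function_ref sufficient here.
  fake_invoke_parallel_shim(begin, end, &Trampoline<F>::call, &f);
}

int main() {
  int64_t sum = 0;
  auto body = [&sum](int64_t b, int64_t e) {
    for (int64_t i = b; i < e; ++i) sum += i;
  };
  invoke_via_trampoline(0, 10, body);
  std::printf("sum = %lld\n", static_cast<long long>(sum));  // prints 45
}
```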
AOTI_TORCH_EXPORT bool aoti_torch_get_intra_op_parallel_enabled();

// Value of AT_PARALLEL_OPENMP
AOTI_TORCH_EXPORT bool aoti_torch_get_parallel_openmp_enabled();
Should we try to mimic naming patterns of existing AOTI APIs (see the pattern on line 172) as well as `torch.backends.openmp.is_available()`? (A `get_` prefix implies there must be a matching `set_` API, but OpenMP is either enabled or not, isn't it?)
-AOTI_TORCH_EXPORT bool aoti_torch_get_parallel_openmp_enabled();
+AOTI_TORCH_EXPORT bool aoti_torch_openmp_is_available();
Good point, thank you!
test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
// If using a parallel path, the thread id is encoded in the upper 32 bits
torch::stable::parallel_for(
    0, size, grain_size, [data_ptr](int64_t begin, int64_t end) {
      for (int64_t i = begin; i < end; i++) {
Why not use `c10::irange` there? (Or at the very least use `auto` for `i`.)
-for (int64_t i = begin; i < end; i++) {
+for (const auto i : c10::irange(begin, end)) {
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  Tensor res = test_parallel_for(to<int64_t>(stack[0]), to<int64_t>(stack[1]));
-Tensor res = test_parallel_for(to<int64_t>(stack[0]), to<int64_t>(stack[1]));
+auto& res = test_parallel_for(to<int64_t>(stack[0]), to<int64_t>(stack[1]));
# always use OPENMP path, OpenMP path will only be used if (1) AND (2)
# (1) libtorch was built with OpenMP
# (2) extension compiles and links with -fopenmp
# macOS clang does not support -fopenmp so we need to skip it
This is an incorrect statement: it does, the option simply needs to be wrapped with `-Xcompiler -fopenmp`. And indeed, if we want some sort of abstraction there, shouldn't we have a helper function, say `torch.utils.cpp_extensions.get_openmp_flags()` (which I think already exists)?
ooh thank you, let me try this
Sorry, I saw other errors in the codebase that seemed to corroborate this so I assumed it was the case
pytorch/torch/_inductor/cpp_builder.py
Lines 606 to 615 in ffe3cb2
if openmp_problem and sys.platform == "darwin":
    instruction = (
        "\n\nOpenMP support not found. Please try one of the following solutions:\n"
        "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ "
        "that has builtin OpenMP support;\n"
        "(2) install OpenMP via conda: `conda install llvm-openmp`;\n"
        "(3) install libomp via brew: `brew install libomp`;\n"
        "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path"
        " with `include/omp.h` under it."
    )
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self);

// parallel utilities
AOTI_TORCH_EXPORT void aoti_torch_lazy_init_num_threads();
This API looks quite confusing; maybe some explanation of why it is needed would be good? (For example, I don't know myself why: on all non-embedded OSes it's pretty safe to query the number of cores at initialization time.)
Will add a comment that it's only for use by parallel_for in torch/csrc/stable!
# (2) extension compiles and links with -fopenmp
# macOS clang does not support -fopenmp so we need to skip it
if sys.platform != "darwin":
    extra_compile_args["cxx"].extend(["-fopenmp", "-D_OPENMP"])
Why do you need to define _OPENMP here, isn't it a compiler's job? See https://godbolt.org/z/7dEcn1vfP
Also, you'll need to pass a different flag if you are on Windows
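A quick standalone way to check malfet's point that `_OPENMP` is defined by the compiler itself when OpenMP is enabled (compile once with and once without the OpenMP flag and compare the output):

```cpp
// e.g. `g++ -fopenmp check.cpp` vs `g++ check.cpp` (or the platform equivalent)
#include <cstdio>

int main() {
#ifdef _OPENMP
  // The value encodes the supported OpenMP spec date, e.g. 201511 for 4.5.
  std::printf("_OPENMP defined as %d\n", _OPENMP);
#else
  std::printf("_OPENMP not defined\n");
#endif
  return 0;
}
```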
torch/csrc/stable/parallel_utils.h
Outdated
@@ -0,0 +1,55 @@
#pragma once

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
I'm somewhat new to the codebase, but this header name makes me a bit uncomfortable. Are the stable APIs simply piggybacking on some of the AOTI shim definitions? And is the intention to move it later into the torch/csrc/stable folder?
Yes, `torch/csrc/stable` mainly provides C++ wrappers around the AOTI shim definitions (my understanding is some of these are intended to be more ergonomic, e.g. do memory management that the C header does not, provide kwarg defaults that the C header can't, etc.). I don't think there is an intention to move it later into `torch/csrc/stable`. cc @janeyx99 for why
Yea, I can see why the headers give discomfort. I think based on offline discussion, we should start moving our shims to a non-aoti file and go from there.
torch/csrc/stable/ops.h
Outdated
namespace internal {

// Copied from aten/src/ATen/Parallel.h
inline int64_t divup(int64_t x, int64_t y) {
I think this is the 5th copy of this function that I'm seeing in the codebase. Why not move it to, say, torch/csrc/stable/utils.h? Also, what's wrong with the template

template <typename T>
inline T divup(T x, T y) {
  ...
}
Good point, let me add it to `torch/headeronly` and make the rest of libtorch include that.
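For reference, a self-contained version of the templated `divup` suggested above (the `constexpr` and the compile-time checks are my additions, not necessarily what the PR will adopt):

```cpp
#include <cstdint>

// Templated ceil-division, per the suggestion above.
template <typename T>
constexpr T divup(T x, T y) {
  return (x + y - 1) / y;
}

static_assert(divup<int64_t>(10, 3) == 4, "ceil(10/3) == 4");
static_assert(divup<int64_t>(9, 3) == 3, "exact division stays exact");

// Typical use when sizing chunks for parallel_for-style splitting:
// int64_t chunk_size = divup(end - begin, static_cast<int64_t>(num_threads));
```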
Looks like `torch::stable::internal::parallel_for` is just a copy of the respective implementations from `ParallelOpenMP.h` and `ParallelNative.h`. If this is the case, why aren't you deleting the implementation there and just making them use this "stable" implementation?
@malfet Are you referring to "For …"?
@janeyx99 If this is an accurate representation / there is any other fundamental reasoning that I'm missing …
For "…", see my comment in the PR description.
The current state of the world is that there are two implementations of torch's parallel interface, OpenMP and ParallelNative. `INTRA_OP_PARALLEL` (which is used to is gate whether the "parallel logic" in `parallel_for` is used, see below) is defined if 1. `AT_PARALLEL_OPENMP = 1` (at libtorch build time per the generated ATen/Config.h) + `_OPENMP` is defined at extension build time (meaning that **both libtorch and extension compile/link against OPENMP**) 2. `AT_PARALLEL_OPENMP=0 && AT_PARALLEL_NATIVE = 1` (at libtorch build time per the generated ATen/Config.h) https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel-inl.h#L9-L43 The approach taken in this PR is to paste the implementation of `parallel_for` from `ATen/Parallel-inl.h `into `torch/csrc/stable/ops.h` with the following modifications: For perf, we want the function passed to `parallel_for` to be inlined all the way into `invoke_parallel` - This is possible for the OpenMP implementation which templates F just like how `parallel_for` does https://github.com/pytorch/pytorch/blob/001e1d263746ae9d121d9c8cf55bc87f777d9dba/aten/src/ATen/ParallelOpenMP.h#L14-L53 We paste the implementation of `invoke_parallel` into `torch::stable::internal` **with a modification that `ThreadIdGuard` (the only non-headeronly piece) uses the shimmed version. We do not move it to torch/headeronly due to the ThreadIdGuard limitation** When I compile the call I added in `kernel.cpp`, ```cpp torch::stable::parallel_for( 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { int thread_id = aoti_torch_get_thread_num(); data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32); } }); ``` I can find the following in the objdump, which I think indicates that the function is getting inlined correctly into `invoke_parallel` for this path <img width="1194" height="260" alt="Screenshot 2025-10-08 at 7 03 49 PM" src="https://github.com/user-attachments/assets/32982cfc-8b5f-4765-84db-4aaeb5e77591" /> - This is not possible for the ParallelNative implementation - takes in an `std::function` for `f` - Is defined in a cpp (and relies on other non-headeronly functions) For the above two reasons, we shim the ParallelNative version of `invoke_parallel` https://github.com/pytorch/pytorch/blob/71aefd5595834dd97f38aa978ee32abbd13ac3d6/aten/src/ATen/ParallelNative.cpp#L144-L199 The rest of the APIs are shimmed - `at::internal::lazy_init_num_threads()` --> `aoti_torch_lazy_init_num_threads` Reason for shimming: The [implementation of `at::internal::lazy_init_num_threads()`](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel.h#L32-L38) calls `at::init_num_threads` which is not header-only - `at::in_parallel_region` --> `aoti_torch_in_parallel_region` Reason for shimming: The [OpenMP implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L94-L100) is defined in a .cpp and depends on whether OPENMP is linked against at libtorch build time. 
The [ParallelNative implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L266-L276) is defined in a .cpp and depends on whether `c10_MOBILE` is defined at libtorch build time - `at::get_num_threads` --> `aoti_torch_get_num_threads` Reason for shimming: Similar story to `in_parallel_region`, [`ParallelNative` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L241-L260), [`OpenMP` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L75-L82) - `ThreadIdGuard` --> `aoti_torch_create_thread_id_guard` and `aoti_torch_delete_thread_id_guard`, with a C++ wrapper `torch::stable::ThreadIdGuard` Reason for shimming: Depends on `set_thread_num` which is not header-only [ThreadIdGuard impl](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Parallel.h?fbclid=IwY2xjawNTpHBleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR6LlnxdN6zn2HJlVDeoUyYJBHLZKidAmH_wiEJ7CbBE5bF56_4-WaltmlBOEw_aem_iArU_QX6AZQZeBizz6EEJQ#L42-L50) - `c10::ParallelGuard` --> `aoti_torch_create_parallel_guard`, `aoti_torch_delete_parallel_guard`, `aoti_torch_parallel_guard_is_enabled` with a C++ wrapper `torch::stable::ParallelGuard` Reason for shimming: Has a cpp file [[ParallelGuard.cpp](https://github.com/pytorch/pytorch/blob/main/c10/util/ParallelGuard.cpp?fbclid=IwY2xjawNTpLxleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR7k_w2Ob695Dy7w_WYPK9vsiMEycutaGMeNkwsp_m0x8Y2FbyWMA1QVtvsH7Q_aem_07IlbboUfzDpqB07wU27IA)] [ghstack-poisoned]
The current state of the world is that there are two implementations of torch's parallel interface, OpenMP and ParallelNative. `INTRA_OP_PARALLEL` (which is used to is gate whether the "parallel logic" in `parallel_for` is used, see below) is defined if 1. `AT_PARALLEL_OPENMP = 1` (at libtorch build time per the generated ATen/Config.h) + `_OPENMP` is defined at extension build time (meaning that **both libtorch and extension compile/link against OPENMP**) 2. `AT_PARALLEL_OPENMP=0 && AT_PARALLEL_NATIVE = 1` (at libtorch build time per the generated ATen/Config.h) https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel-inl.h#L9-L43 The approach taken in this PR is to paste the implementation of `parallel_for` from `ATen/Parallel-inl.h `into `torch/csrc/stable/ops.h` with the following modifications: For perf, we want the function passed to `parallel_for` to be inlined all the way into `invoke_parallel` - This is possible for the OpenMP implementation which templates F just like how `parallel_for` does https://github.com/pytorch/pytorch/blob/001e1d263746ae9d121d9c8cf55bc87f777d9dba/aten/src/ATen/ParallelOpenMP.h#L14-L53 We paste the implementation of `invoke_parallel` into `torch::stable::internal` **with a modification that `ThreadIdGuard` (the only non-headeronly piece) uses the shimmed version. We do not move it to torch/headeronly due to the ThreadIdGuard limitation** When I compile the call I added in `kernel.cpp`, ```cpp torch::stable::parallel_for( 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { int thread_id = aoti_torch_get_thread_num(); data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32); } }); ``` I can find the following in the objdump, which I think indicates that the function is getting inlined correctly into `invoke_parallel` for this path <img width="1194" height="260" alt="Screenshot 2025-10-08 at 7 03 49 PM" src="https://github.com/user-attachments/assets/32982cfc-8b5f-4765-84db-4aaeb5e77591" /> - This is not possible for the ParallelNative implementation - takes in an `std::function` for `f` - Is defined in a cpp (and relies on other non-headeronly functions) For the above two reasons, we shim the ParallelNative version of `invoke_parallel` https://github.com/pytorch/pytorch/blob/71aefd5595834dd97f38aa978ee32abbd13ac3d6/aten/src/ATen/ParallelNative.cpp#L144-L199 The rest of the APIs are shimmed - `at::internal::lazy_init_num_threads()` --> `aoti_torch_lazy_init_num_threads` Reason for shimming: The [implementation of `at::internal::lazy_init_num_threads()`](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel.h#L32-L38) calls `at::init_num_threads` which is not header-only - `at::in_parallel_region` --> `aoti_torch_in_parallel_region` Reason for shimming: The [OpenMP implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L94-L100) is defined in a .cpp and depends on whether OPENMP is linked against at libtorch build time. 
The [ParallelNative implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L266-L276) is defined in a .cpp and depends on whether `c10_MOBILE` is defined at libtorch build time - `at::get_num_threads` --> `aoti_torch_get_num_threads` Reason for shimming: Similar story to `in_parallel_region`, [`ParallelNative` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L241-L260), [`OpenMP` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L75-L82) - `ThreadIdGuard` --> `aoti_torch_create_thread_id_guard` and `aoti_torch_delete_thread_id_guard`, with a C++ wrapper `torch::stable::ThreadIdGuard` Reason for shimming: Depends on `set_thread_num` which is not header-only [ThreadIdGuard impl](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Parallel.h?fbclid=IwY2xjawNTpHBleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR6LlnxdN6zn2HJlVDeoUyYJBHLZKidAmH_wiEJ7CbBE5bF56_4-WaltmlBOEw_aem_iArU_QX6AZQZeBizz6EEJQ#L42-L50) - `c10::ParallelGuard` --> `aoti_torch_create_parallel_guard`, `aoti_torch_delete_parallel_guard`, `aoti_torch_parallel_guard_is_enabled` with a C++ wrapper `torch::stable::ParallelGuard` Reason for shimming: Has a cpp file [[ParallelGuard.cpp](https://github.com/pytorch/pytorch/blob/main/c10/util/ParallelGuard.cpp?fbclid=IwY2xjawNTpLxleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR7k_w2Ob695Dy7w_WYPK9vsiMEycutaGMeNkwsp_m0x8Y2FbyWMA1QVtvsH7Q_aem_07IlbboUfzDpqB07wU27IA)] [ghstack-poisoned]
The current state of the world is that there are two implementations of torch's parallel interface, OpenMP and ParallelNative. `INTRA_OP_PARALLEL` (which is used to is gate whether the "parallel logic" in `parallel_for` is used, see below) is defined if 1. `AT_PARALLEL_OPENMP = 1` (at libtorch build time per the generated ATen/Config.h) + `_OPENMP` is defined at extension build time (meaning that **both libtorch and extension compile/link against OPENMP**) 2. `AT_PARALLEL_OPENMP=0 && AT_PARALLEL_NATIVE = 1` (at libtorch build time per the generated ATen/Config.h) https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel-inl.h#L9-L43 The approach taken in this PR is to paste the implementation of `parallel_for` from `ATen/Parallel-inl.h `into `torch/csrc/stable/ops.h` with the following modifications: For perf, we want the function passed to `parallel_for` to be inlined all the way into `invoke_parallel` - This is possible for the OpenMP implementation which templates F just like how `parallel_for` does https://github.com/pytorch/pytorch/blob/001e1d263746ae9d121d9c8cf55bc87f777d9dba/aten/src/ATen/ParallelOpenMP.h#L14-L53 We paste the implementation of `invoke_parallel` into `torch::stable::internal` **with a modification that `ThreadIdGuard` (the only non-headeronly piece) uses the shimmed version. We do not move it to torch/headeronly due to the ThreadIdGuard limitation** When I compile the call I added in `kernel.cpp`, ```cpp torch::stable::parallel_for( 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { int thread_id = aoti_torch_get_thread_num(); data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32); } }); ``` I can find the following in the objdump, which I think indicates that the function is getting inlined correctly into `invoke_parallel` for this path <img width="1194" height="260" alt="Screenshot 2025-10-08 at 7 03 49 PM" src="https://github.com/user-attachments/assets/32982cfc-8b5f-4765-84db-4aaeb5e77591" /> - This is not possible for the ParallelNative implementation - takes in an `std::function` for `f` - Is defined in a cpp (and relies on other non-headeronly functions) For the above two reasons, we shim the ParallelNative version of `invoke_parallel` https://github.com/pytorch/pytorch/blob/71aefd5595834dd97f38aa978ee32abbd13ac3d6/aten/src/ATen/ParallelNative.cpp#L144-L199 The rest of the APIs are shimmed - `at::internal::lazy_init_num_threads()` --> `aoti_torch_lazy_init_num_threads` Reason for shimming: The [implementation of `at::internal::lazy_init_num_threads()`](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel.h#L32-L38) calls `at::init_num_threads` which is not header-only - `at::in_parallel_region` --> `aoti_torch_in_parallel_region` Reason for shimming: The [OpenMP implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L94-L100) is defined in a .cpp and depends on whether OPENMP is linked against at libtorch build time. 
The [ParallelNative implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L266-L276) is defined in a .cpp and depends on whether `c10_MOBILE` is defined at libtorch build time - `at::get_num_threads` --> `aoti_torch_get_num_threads` Reason for shimming: Similar story to `in_parallel_region`, [`ParallelNative` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L241-L260), [`OpenMP` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L75-L82) - `ThreadIdGuard` --> `aoti_torch_create_thread_id_guard` and `aoti_torch_delete_thread_id_guard`, with a C++ wrapper `torch::stable::ThreadIdGuard` Reason for shimming: Depends on `set_thread_num` which is not header-only [ThreadIdGuard impl](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Parallel.h?fbclid=IwY2xjawNTpHBleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR6LlnxdN6zn2HJlVDeoUyYJBHLZKidAmH_wiEJ7CbBE5bF56_4-WaltmlBOEw_aem_iArU_QX6AZQZeBizz6EEJQ#L42-L50) - `c10::ParallelGuard` --> `aoti_torch_create_parallel_guard`, `aoti_torch_delete_parallel_guard`, `aoti_torch_parallel_guard_is_enabled` with a C++ wrapper `torch::stable::ParallelGuard` Reason for shimming: Has a cpp file [[ParallelGuard.cpp](https://github.com/pytorch/pytorch/blob/main/c10/util/ParallelGuard.cpp?fbclid=IwY2xjawNTpLxleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR7k_w2Ob695Dy7w_WYPK9vsiMEycutaGMeNkwsp_m0x8Y2FbyWMA1QVtvsH7Q_aem_07IlbboUfzDpqB07wU27IA)] [ghstack-poisoned]
The current state of the world is that there are two implementations of torch's parallel interface, OpenMP and ParallelNative. `INTRA_OP_PARALLEL` (which is used to is gate whether the "parallel logic" in `parallel_for` is used, see below) is defined if 1. `AT_PARALLEL_OPENMP = 1` (at libtorch build time per the generated ATen/Config.h) + `_OPENMP` is defined at extension build time (meaning that **both libtorch and extension compile/link against OPENMP**) 2. `AT_PARALLEL_OPENMP=0 && AT_PARALLEL_NATIVE = 1` (at libtorch build time per the generated ATen/Config.h) https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel-inl.h#L9-L43 The approach taken in this PR is to paste the implementation of `parallel_for` from `ATen/Parallel-inl.h `into `torch/csrc/stable/ops.h` with the following modifications: For perf, we want the function passed to `parallel_for` to be inlined all the way into `invoke_parallel` - This is possible for the OpenMP implementation which templates F just like how `parallel_for` does https://github.com/pytorch/pytorch/blob/001e1d263746ae9d121d9c8cf55bc87f777d9dba/aten/src/ATen/ParallelOpenMP.h#L14-L53 We paste the implementation of `invoke_parallel` into `torch::stable::internal` **with a modification that `ThreadIdGuard` (the only non-headeronly piece) uses the shimmed version. We do not move it to torch/headeronly due to the ThreadIdGuard limitation** When I compile the call I added in `kernel.cpp`, ```cpp torch::stable::parallel_for( 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { int thread_id = aoti_torch_get_thread_num(); data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32); } }); ``` I can find the following in the objdump, which I think indicates that the function is getting inlined correctly into `invoke_parallel` for this path <img width="1194" height="260" alt="Screenshot 2025-10-08 at 7 03 49 PM" src="https://github.com/user-attachments/assets/32982cfc-8b5f-4765-84db-4aaeb5e77591" /> - This is not possible for the ParallelNative implementation - takes in an `std::function` for `f` - Is defined in a cpp (and relies on other non-headeronly functions) For the above two reasons, we shim the ParallelNative version of `invoke_parallel` https://github.com/pytorch/pytorch/blob/71aefd5595834dd97f38aa978ee32abbd13ac3d6/aten/src/ATen/ParallelNative.cpp#L144-L199 The rest of the APIs are shimmed - `at::internal::lazy_init_num_threads()` --> `aoti_torch_lazy_init_num_threads` Reason for shimming: The [implementation of `at::internal::lazy_init_num_threads()`](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel.h#L32-L38) calls `at::init_num_threads` which is not header-only - `at::in_parallel_region` --> `aoti_torch_in_parallel_region` Reason for shimming: The [OpenMP implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L94-L100) is defined in a .cpp and depends on whether OPENMP is linked against at libtorch build time. 
The [ParallelNative implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L266-L276) is defined in a .cpp and depends on whether `c10_MOBILE` is defined at libtorch build time - `at::get_num_threads` --> `aoti_torch_get_num_threads` Reason for shimming: Similar story to `in_parallel_region`, [`ParallelNative` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L241-L260), [`OpenMP` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L75-L82) - `ThreadIdGuard` --> `aoti_torch_create_thread_id_guard` and `aoti_torch_delete_thread_id_guard`, with a C++ wrapper `torch::stable::ThreadIdGuard` Reason for shimming: Depends on `set_thread_num` which is not header-only [ThreadIdGuard impl](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Parallel.h?fbclid=IwY2xjawNTpHBleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR6LlnxdN6zn2HJlVDeoUyYJBHLZKidAmH_wiEJ7CbBE5bF56_4-WaltmlBOEw_aem_iArU_QX6AZQZeBizz6EEJQ#L42-L50) - `c10::ParallelGuard` --> `aoti_torch_create_parallel_guard`, `aoti_torch_delete_parallel_guard`, `aoti_torch_parallel_guard_is_enabled` with a C++ wrapper `torch::stable::ParallelGuard` Reason for shimming: Has a cpp file [[ParallelGuard.cpp](https://github.com/pytorch/pytorch/blob/main/c10/util/ParallelGuard.cpp?fbclid=IwY2xjawNTpLxleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR7k_w2Ob695Dy7w_WYPK9vsiMEycutaGMeNkwsp_m0x8Y2FbyWMA1QVtvsH7Q_aem_07IlbboUfzDpqB07wU27IA)] [ghstack-poisoned]
The current state of the world is that there are two implementations of torch's parallel interface, OpenMP and ParallelNative. `INTRA_OP_PARALLEL` (which is used to is gate whether the "parallel logic" in `parallel_for` is used, see below) is defined if 1. `AT_PARALLEL_OPENMP = 1` (at libtorch build time per the generated ATen/Config.h) + `_OPENMP` is defined at extension build time (meaning that **both libtorch and extension compile/link against OPENMP**) 2. `AT_PARALLEL_OPENMP=0 && AT_PARALLEL_NATIVE = 1` (at libtorch build time per the generated ATen/Config.h) https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel-inl.h#L9-L43 The approach taken in this PR is to paste the implementation of `parallel_for` from `ATen/Parallel-inl.h `into `torch/csrc/stable/ops.h` with the following modifications: For perf, we want the function passed to `parallel_for` to be inlined all the way into `invoke_parallel` - This is possible for the OpenMP implementation which templates F just like how `parallel_for` does https://github.com/pytorch/pytorch/blob/001e1d263746ae9d121d9c8cf55bc87f777d9dba/aten/src/ATen/ParallelOpenMP.h#L14-L53 We paste the implementation of `invoke_parallel` into `torch::stable::internal` **with a modification that `ThreadIdGuard` (the only non-headeronly piece) uses the shimmed version. We do not move it to torch/headeronly due to the ThreadIdGuard limitation** When I compile the call I added in `kernel.cpp`, ```cpp torch::stable::parallel_for( 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; i++) { int thread_id = aoti_torch_get_thread_num(); data_ptr[i] = i | (static_cast<int64_t>(thread_id) << 32); } }); ``` I can find the following in the objdump, which I think indicates that the function is getting inlined correctly into `invoke_parallel` for this path <img width="1194" height="260" alt="Screenshot 2025-10-08 at 7 03 49 PM" src="https://github.com/user-attachments/assets/32982cfc-8b5f-4765-84db-4aaeb5e77591" /> - This is not possible for the ParallelNative implementation - takes in an `std::function` for `f` - Is defined in a cpp (and relies on other non-headeronly functions) For the above two reasons, we shim the ParallelNative version of `invoke_parallel` https://github.com/pytorch/pytorch/blob/71aefd5595834dd97f38aa978ee32abbd13ac3d6/aten/src/ATen/ParallelNative.cpp#L144-L199 The rest of the APIs are shimmed - `at::internal::lazy_init_num_threads()` --> `aoti_torch_lazy_init_num_threads` Reason for shimming: The [implementation of `at::internal::lazy_init_num_threads()`](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/Parallel.h#L32-L38) calls `at::init_num_threads` which is not header-only - `at::in_parallel_region` --> `aoti_torch_in_parallel_region` Reason for shimming: The [OpenMP implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L94-L100) is defined in a .cpp and depends on whether OPENMP is linked against at libtorch build time. 
The [ParallelNative implementation](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L266-L276) is defined in a .cpp and depends on whether `c10_MOBILE` is defined at libtorch build time - `at::get_num_threads` --> `aoti_torch_get_num_threads` Reason for shimming: Similar story to `in_parallel_region`, [`ParallelNative` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelNative.cpp#L241-L260), [`OpenMP` impl](https://github.com/pytorch/pytorch/blob/e0cb1848d0fd9fb4467ad8b844c565aea5071838/aten/src/ATen/ParallelOpenMP.cpp#L75-L82) - `ThreadIdGuard` --> `aoti_torch_create_thread_id_guard` and `aoti_torch_delete_thread_id_guard`, with a C++ wrapper `torch::stable::ThreadIdGuard` Reason for shimming: Depends on `set_thread_num` which is not header-only [ThreadIdGuard impl](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/Parallel.h?fbclid=IwY2xjawNTpHBleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR6LlnxdN6zn2HJlVDeoUyYJBHLZKidAmH_wiEJ7CbBE5bF56_4-WaltmlBOEw_aem_iArU_QX6AZQZeBizz6EEJQ#L42-L50) - `c10::ParallelGuard` --> `aoti_torch_create_parallel_guard`, `aoti_torch_delete_parallel_guard`, `aoti_torch_parallel_guard_is_enabled` with a C++ wrapper `torch::stable::ParallelGuard` Reason for shimming: Has a cpp file [[ParallelGuard.cpp](https://github.com/pytorch/pytorch/blob/main/c10/util/ParallelGuard.cpp?fbclid=IwY2xjawNTpLxleHRuA2FlbQIxMQBicmlkETBhRjNHRG5BZEZIcjRTdHVzAR7k_w2Ob695Dy7w_WYPK9vsiMEycutaGMeNkwsp_m0x8Y2FbyWMA1QVtvsH7Q_aem_07IlbboUfzDpqB07wU27IA)] [ghstack-poisoned]
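Putting the pieces together, the control flow that the stable `parallel_for` keeps from `ATen/Parallel-inl.h` looks roughly like the sketch below: run serially on the calling thread unless the range is large enough and we are not already inside a parallel region, otherwise hand off to `invoke_parallel`. The shim signatures here are simplified assumptions (plain return values rather than the shim's likely error-code/out-parameter convention), and the guard installation from the real code is elided, so treat this as a shape sketch rather than the PR's exact header.

```cpp
#include <cstdint>

// Assumed, simplified shim signatures for illustration.
extern "C" {
void aoti_torch_lazy_init_num_threads();
bool aoti_torch_in_parallel_region();
int32_t aoti_torch_get_num_threads();
}

namespace torch::stable {
namespace internal {
// OpenMP build: header-only and templated on F, so f can be inlined.
// ParallelNative build: forwards to the shimmed invoke_parallel instead.
template <class F>
void invoke_parallel(int64_t begin, int64_t end, int64_t grain_size, const F& f);
}  // namespace internal

template <class F>
void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F& f) {
  if (begin >= end) {
    return;
  }
  aoti_torch_lazy_init_num_threads();
  const int64_t numiter = end - begin;
  const bool use_parallel = numiter > grain_size && numiter > 1 &&
      !aoti_torch_in_parallel_region() && aoti_torch_get_num_threads() > 1;
  if (!use_parallel) {
    // Serial fallback on the calling thread; the real code also installs
    // ThreadIdGuard(0) and ParallelGuard here (see the wrapper sketch above).
    f(begin, end);
    return;
  }
  internal::invoke_parallel(begin, end, grain_size, f);
}
}  // namespace torch::stable
```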
if (begin_tid < end) {
  try {
    ThreadIdGuard tid_guard(static_cast<uint64_t>(tid));
    f(begin_tid, std::min(end, chunk_size + begin_tid));
@swolchok This is where I thought we could inline f to
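For reference on what inlining `f` at this call site buys: the chunking around it, following the ATen OpenMP `invoke_parallel` linked in the description, looks roughly like the sketch below (exception propagation and the shimmed `ThreadIdGuard` are elided, and it assumes the OpenMP build). Because `F` is a template parameter, the call to `f` compiles down to the loop body itself rather than an indirect call.

```cpp
#include <algorithm>
#include <cstdint>
#include <omp.h>

namespace {
inline int64_t divup(int64_t x, int64_t y) {
  return (x + y - 1) / y;
}
}  // namespace

// Each OpenMP thread claims one contiguous chunk of [begin, end);
// grain_size caps the number of threads so each chunk is at least grain_size wide.
template <class F>
void invoke_parallel_sketch(int64_t begin, int64_t end, int64_t grain_size, const F& f) {
#pragma omp parallel
  {
    int64_t num_threads = omp_get_num_threads();
    if (grain_size > 0) {
      num_threads = std::min(num_threads, divup(end - begin, grain_size));
    }
    const int64_t tid = omp_get_thread_num();
    const int64_t chunk_size = divup(end - begin, num_threads);
    const int64_t begin_tid = begin + tid * chunk_size;
    if (begin_tid < end) {
      // The real code wraps this in a try/catch and a (shimmed) ThreadIdGuard;
      // both are elided here so the chunk math stays visible.
      f(begin_tid, std::min(end, begin_tid + chunk_size));
    }
  }
}
```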
Stack from ghstack (oldest at bottom):