diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index f927403cbde9..2767669bce24 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -73,7 +73,7 @@ set(tvm_ffi_extra_objs_sources "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc" - "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_context.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/env_c_api.cc" "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/testing.cc" ) @@ -249,6 +249,7 @@ endif() install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/include/tvm/ffi/ DESTINATION include/tvm/ffi/) install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/dlpack/include/ DESTINATION include/) +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tvm_ffi_python_helpers.h DESTINATION include/) install(TARGETS tvm_ffi_shared DESTINATION lib) # ship additional dSYM files for debugging symbols on if available if (APPLE) diff --git a/ffi/docs/get_started/quick_start.md b/ffi/docs/get_started/quick_start.md index c7cb007c7815..4861aa87b253 100644 --- a/ffi/docs/get_started/quick_start.md +++ b/ffi/docs/get_started/quick_start.md @@ -125,7 +125,7 @@ void AddOneCUDA(DLTensor* x, DLTensor* y) { // Get current CUDA stream from environment cudaStream_t stream = static_cast<cudaStream_t>( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // Launch kernel AddOneKernel<<>>( @@ -136,7 +136,7 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA); ``` **Key Points:** -- We use `TVMFFIEnvGetCurrentStream` to obtain the current stream from the environement +- We use `TVMFFIEnvGetStream` to obtain the current stream from the environment - When invoking ffi Function from python end with PyTorch tensor as argument, the stream will be populated with torch's current stream.
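The quick-start notes above only show the callee side of the renamed stream API (reading the stream inside the kernel launcher). As a rough caller-side illustration, the sketch below publishes a caller-owned CUDA stream before an FFI call and restores the previous one afterwards. It assumes only the `TVMFFIEnvSetStream`/`TVMFFIEnvGetStream` signatures added in `extra/c_env_api.h` by this patch; the function name `RunOnMyStream`, the `my_stream` argument, and the omitted FFI call are placeholders, not part of the patch, and return-code checking is left out for brevity.

```cpp
// Sketch only: caller-side use of the renamed stream API from c_env_api.h.
// `RunOnMyStream` and `my_stream` are illustrative names, not part of the patch.
#include <cuda_runtime.h>
#include <dlpack/dlpack.h>
#include <tvm/ffi/extra/c_env_api.h>

void RunOnMyStream(cudaStream_t my_stream, int device_id) {
  TVMFFIStreamHandle previous = nullptr;
  // Publish our stream so callees can pick it up via TVMFFIEnvGetStream.
  TVMFFIEnvSetStream(kDLCUDA, device_id, my_stream, &previous);
  // ... invoke an FFI function here; its kernels launch on `my_stream` ...
  // Restore whatever stream was current before (error handling omitted).
  TVMFFIEnvSetStream(kDLCUDA, device_id, previous, nullptr);
}
```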
diff --git a/ffi/examples/inline_module/main.py b/ffi/examples/inline_module/main.py index b55574ae7bab..5cfcd41bec12 100644 --- a/ffi/examples/inline_module/main.py +++ b/ffi/examples/inline_module/main.py @@ -63,7 +63,7 @@ def main(): // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors cudaStream_t stream = static_cast<cudaStream_t>( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); diff --git a/ffi/examples/quick_start/run_example.py b/ffi/examples/quick_start/run_example.py index 456e58ce91b9..a8f4fc00a600 100644 --- a/ffi/examples/quick_start/run_example.py +++ b/ffi/examples/quick_start/run_example.py @@ -64,7 +64,7 @@ def run_add_one_cuda(): with torch.cuda.stream(stream): # tvm-ffi automatically handles DLPack compatible tensors # it also handles interactions with torch runtime - # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetCurrentStream + # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetStream # when calling the function mod.add_one_cuda(x, y) stream.synchronize() diff --git a/ffi/examples/quick_start/src/add_one_cuda.cu b/ffi/examples/quick_start/src/add_one_cuda.cu index ead2ec89a95c..52f1e7482505 100644 --- a/ffi/examples/quick_start/src/add_one_cuda.cu +++ b/ffi/examples/quick_start/src/add_one_cuda.cu @@ -46,8 +46,8 @@ void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) { // Obtain the current stream from the environment // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors - cudaStream_t stream = static_cast<cudaStream_t>( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + cudaStream_t stream = + static_cast<cudaStream_t>(TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 5d67fcd22128..a53dac4d00af 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -27,6 +27,21 @@ #include #include +/* + * \brief C-style Allocator that allocates memory for a DLPack tensor. + * \param prototype The prototype DLTensor to offer details about device and shape. + * \param out The output DLManagedTensorVersioned. + * \param error_ctx The context to set the error. + * \param SetError The function to set the error. + * \return 0 on success, -1 on failure. + * call SetError(error_ctx, kind, message) to set the error kind and message. + * \note Error propagation via SetError. + */ +typedef int (*DLPackTensorAllocator)( // + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // + void (*SetError)(void* error_ctx, const char* kind, const char* message) // +); + // Macros to do weak linking #ifdef _MSC_VER #define TVM_FFI_WEAK __declspec(selectany) diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h index 5e20b7b51df2..59dc7739ea63 100644 --- a/ffi/include/tvm/ffi/container/tensor.h +++ b/ffi/include/tvm/ffi/container/tensor.h @@ -32,6 +32,7 @@ #include #include +#include #include namespace tvm { @@ -341,7 +342,60 @@ class Tensor : public ObjectRef { return Tensor(make_object>( alloc, shape, dtype, device, std::forward(extra_args)...)); } - + /*!
+ * \brief Create a Tensor from a DLPackTensorAllocator + * + * This function can be used together with TVMFFIEnvSetTensorAllocator + * in the extra/c_env_api.h to create Tensor from the thread-local + * environment allocator. + * + * \code + * + * ffi::Tensor tensor = ffi::Tensor::FromDLPackAlloc( + * TVMFFIEnvGetTensorAllocator(), shape, dtype, device + * ); + * \endcode + * + * \param allocator The DLPack allocator. + * \param shape The shape of the Tensor. + * \param dtype The data type of the Tensor. + * \param device The device of the Tensor. + * \return The created Tensor. + */ + static Tensor FromDLPackAlloc(DLPackTensorAllocator allocator, ffi::Shape shape, DLDataType dtype, + DLDevice device) { + if (allocator == nullptr) { + TVM_FFI_THROW(RuntimeError) + << "FromDLPackAlloc: allocator is nullptr, " + << "likely because TVMFFIEnvSetTensorAllocator has not been called."; + } + DLTensor prototype; + prototype.device = device; + prototype.dtype = dtype; + prototype.shape = const_cast(shape.data()); + prototype.ndim = static_cast(shape.size()); + prototype.strides = nullptr; + prototype.byte_offset = 0; + prototype.data = nullptr; + DLManagedTensorVersioned* tensor = nullptr; + // error context to be used to propagate error + struct ErrorContext { + std::string kind; + std::string message; + static void SetError(void* error_ctx, const char* kind, const char* message) { + ErrorContext* error_context = static_cast(error_ctx); + error_context->kind = kind; + error_context->message = message; + } + }; + ErrorContext error_context; + int ret = (*allocator)(&prototype, &tensor, &error_context, ErrorContext::SetError); + if (ret != 0) { + throw ffi::Error(error_context.kind, error_context.message, + TVMFFITraceback(__FILE__, __LINE__, __func__, 0)); + } + return Tensor(make_object>(tensor)); + } /*! * \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API. * \param tensor The input DLPack managed tensor. diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h index bd0d188155fe..3c49d79d3071 100644 --- a/ffi/include/tvm/ffi/extra/c_env_api.h +++ b/ffi/include/tvm/ffi/extra/c_env_api.h @@ -46,12 +46,11 @@ typedef void* TVMFFIStreamHandle; * \param device_id The id of the device. * \param stream The stream to set. * \param opt_out_original_stream Output original stream if the address is not nullptr. - * \note The stream is a weak reference that is cached/owned by the module. * \return 0 when success, nonzero when failure happens */ -TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle* opt_out_original_stream); +TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle* opt_out_original_stream); /*! * \brief FFI function to get the current stream for a device @@ -60,7 +59,29 @@ TVM_FFI_DLL int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id * \param device_id The id of the device. * \return The current stream of the device. */ -TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id); +TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id); + +/*! + * \brief FFI function to set the current DLPack allocator in thread-local(TLS) context + * + * \param allocator The allocator to set. + * \param write_to_global_context Whether to also set the allocator to the global context. 
+ * \param opt_out_original_allocator Output original TLS allocator if the address is not nullptr. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvSetTensorAllocator(DLPackTensorAllocator allocator, + int write_to_global_context, + DLPackTensorAllocator* opt_out_original_allocator); + +/*! + * \brief FFI function get the current DLPack allocator stored in context. + * + * This function first queries the global context, and if not found, + * queries the thread-local context. + * + * \return The current DLPack allocator. + */ +TVM_FFI_DLL DLPackTensorAllocator TVMFFIEnvGetTensorAllocator(); /*! * \brief Check if there are any signals raised in the surrounding env. diff --git a/ffi/licenses/LICENSE.pytorch.txt b/ffi/licenses/LICENSE.pytorch.txt new file mode 100644 index 000000000000..966a609b61e5 --- /dev/null +++ b/ffi/licenses/LICENSE.pytorch.txt @@ -0,0 +1,84 @@ +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + +All contributions by Tri Dao: +Copyright (c) 2024 Tri Dao. +All rights reserved. + +All contributions by Arm: +Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/ffi/licenses/NOTICE.pytorch.txt b/ffi/licenses/NOTICE.pytorch.txt new file mode 100644 index 000000000000..6effb8b5d707 --- /dev/null +++ b/ffi/licenses/NOTICE.pytorch.txt @@ -0,0 +1,456 @@ +======================================================================= +Software under third_party +======================================================================= +Software libraries under third_party are provided as github submodule +links, and their content is not part of the Caffe2 codebase. Their +licences can be found under the respective software repositories. + +======================================================================= +Earlier BSD License +======================================================================= +Early development of Caffe2 in 2015 and early 2016 is licensed under the +BSD license. The license is attached below: + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +Caffe's BSD License +======================================================================= +Some parts of the caffe2 code is derived from the original Caffe code, which is +created by Yangqing Jia and is now a BSD-licensed open-source project. The Caffe +license is as follows: + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +======================================================================= +Caffe2's Apache License +======================================================================= + +This repo contains Caffe2 code, which was previously licensed under +Apache License Version 2.0: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +======================================================================= +Cephes's 3-Clause BSD License +======================================================================= + +Code derived from implementations in the Cephes Math Library should mention +its derivation and reference the following license: + + 3-Clause BSD License for the Cephes Math Library + Copyright (c) 2018, Steven Moshier + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +SciPy's 3-Clause BSD License +======================================================================= + +Code derived from implementations in SciPy should mention its derivation +and reference the following license: + + Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +======================================================================= +Boost's 1.0 Software License +======================================================================= + +Code derived from implementations in Boost 1.0 should mention its +derivation and reference the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + + Permission is hereby granted, free of charge, to any person or organization + obtaining a copy of the software and accompanying documentation covered by + this license (the "Software") to use, reproduce, display, distribute, + execute, and transmit the Software, and to prepare derivative works of the + Software, and to permit third-parties to whom the Software is furnished to + do so, all subject to the following: + + The copyright notices in the Software and this entire statement, including + the above license grant, this restriction and the following disclaimer, + must be included in all copies of the Software, in whole or in part, and + all derivative works of the Software, unless such copies or derivative + works are solely in the form of machine-executable object code generated by + a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT + SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +======================================================================= +PILLOW-SIMD Software License +======================================================================= + +Code derived from implementations in PILLOW-SIMD should mention its derivation +and reference the following license: + + The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh + + Pillow is the friendly PIL fork. 
It is + + Copyright © 2010-2022 by Alex Clark and contributors + + Like PIL, Pillow is licensed under the open source HPND License: + + By obtaining, using, and/or copying this software and/or its associated + documentation, you agree that you have read, understood, and will comply + with the following terms and conditions: + + Permission to use, copy, modify, and distribute this software and its + associated documentation for any purpose and without fee is hereby granted, + provided that the above copyright notice appears in all copies, and that + both that copyright notice and this permission notice appear in supporting + documentation, and that the name of Secret Labs AB or the author not be + used in advertising or publicity pertaining to distribution of the software + without specific, written prior permission. + + SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS + SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. + IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, + INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + PERFORMANCE OF THIS SOFTWARE. diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml index 0988a78d6308..11e65a9065d2 100644 --- a/ffi/pyproject.toml +++ b/ffi/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "apache-tvm-ffi" -version = "0.1.0a9" +version = "0.1.0a11" description = "tvm ffi" authors = [{ name = "TVM FFI team" }] diff --git a/ffi/python/tvm_ffi/__init__.py b/ffi/python/tvm_ffi/__init__.py index b0ff88c6c8e1..c23e8b59fee7 100644 --- a/ffi/python/tvm_ffi/__init__.py +++ b/ffi/python/tvm_ffi/__init__.py @@ -39,6 +39,8 @@ from . import access_path from . import testing +# optional module to speedup dlpack conversion +from . import _optional_torch_c_dlpack __all__ = [ "dtype", diff --git a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py new file mode 100644 index 000000000000..f4af39302521 --- /dev/null +++ b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py @@ -0,0 +1,403 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Optional module to support faster DLPack conversion. + +This is an optional module to support faster DLPack conversion for torch. +Some of the changes are merged but not yet released, so it is used +as a stop gap to support faster DLPack conversion. + +This file contains source code from PyTorch: +License: licenses/LICENSE.pytorch.txt + +This module only serves as temp measure and will +likely be phased away and deleted after changes landed and released in pytorch. 
+ +This module will load slowly at first time due to JITing, +subsequent calls will be much faster. +""" +import warnings +from . import libinfo + + +def load_torch_c_dlpack_extension(): + """Load the torch c dlpack extension.""" + cpp_source = """ +#include +#include +#include +#include + +using namespace std; +namespace at { +namespace { + +DLDataType getDLDataTypeForDLPackv1(const Tensor& t) { + DLDataType dtype; + dtype.lanes = 1; + dtype.bits = t.element_size() * 8; + switch (t.scalar_type()) { + case ScalarType::UInt1: + case ScalarType::UInt2: + case ScalarType::UInt3: + case ScalarType::UInt4: + case ScalarType::UInt5: + case ScalarType::UInt6: + case ScalarType::UInt7: + case ScalarType::Byte: + case ScalarType::UInt16: + case ScalarType::UInt32: + case ScalarType::UInt64: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case ScalarType::Int1: + case ScalarType::Int2: + case ScalarType::Int3: + case ScalarType::Int4: + case ScalarType::Int5: + case ScalarType::Int6: + case ScalarType::Int7: + case ScalarType::Char: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Double: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case ScalarType::Float: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case ScalarType::Int: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Long: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Short: + dtype.code = DLDataTypeCode::kDLInt; + break; + case ScalarType::Half: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case ScalarType::Bool: + dtype.code = DLDataTypeCode::kDLBool; + break; + case ScalarType::ComplexHalf: + case ScalarType::ComplexFloat: + case ScalarType::ComplexDouble: + dtype.code = DLDataTypeCode::kDLComplex; + break; + case ScalarType::BFloat16: + dtype.code = DLDataTypeCode::kDLBfloat; + break; + case ScalarType::Float8_e5m2: + dtype.code = DLDataTypeCode::kDLFloat8_e5m2; + break; + case ScalarType::Float8_e5m2fnuz: + dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz; + break; + case ScalarType::Float8_e4m3fn: + dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn; + break; + case ScalarType::Float8_e4m3fnuz: + dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz; + break; + case ScalarType::Float8_e8m0fnu: + dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu; + break; + case ScalarType::Float4_e2m1fn_x2: + dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn; + break; + default: + TORCH_CHECK(false, "Unsupported scalar type: "); + } + return dtype; +} + +DLDevice torchDeviceToDLDeviceForDLPackv1(at::Device device) { + DLDevice ctx; + + ctx.device_id = (device.is_cuda() || device.is_privateuseone()) + ? 
static_cast(static_cast(device.index())) + : 0; + + switch (device.type()) { + case DeviceType::CPU: + ctx.device_type = DLDeviceType::kDLCPU; + break; + case DeviceType::CUDA: +#ifdef USE_ROCM + ctx.device_type = DLDeviceType::kDLROCM; +#else + ctx.device_type = DLDeviceType::kDLCUDA; +#endif + break; + case DeviceType::OPENCL: + ctx.device_type = DLDeviceType::kDLOpenCL; + break; + case DeviceType::HIP: + ctx.device_type = DLDeviceType::kDLROCM; + break; + case DeviceType::XPU: + ctx.device_type = DLDeviceType::kDLOneAPI; + ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device); + break; + case DeviceType::MAIA: + ctx.device_type = DLDeviceType::kDLMAIA; + break; + case DeviceType::PrivateUse1: + ctx.device_type = DLDeviceType::kDLExtDev; + break; + case DeviceType::MPS: + ctx.device_type = DLDeviceType::kDLMetal; + break; + default: + TORCH_CHECK(false, "Cannot pack tensors on " + device.str()); + } + + return ctx; +} + +template +struct ATenDLMTensor { + Tensor handle; + T tensor{}; +}; + +template +void deleter(T* arg) { + delete static_cast*>(arg->manager_ctx); +} + +// Adds version information for DLManagedTensorVersioned. +// This is a no-op for the other types. +template +void fillVersion(T* tensor) {} + +template <> +void fillVersion( + DLManagedTensorVersioned* tensor) { + tensor->flags = 0; + tensor->version.major = DLPACK_MAJOR_VERSION; + tensor->version.minor = DLPACK_MINOR_VERSION; +} + +// This function returns a shared_ptr to memory managed DLpack tensor +// constructed out of ATen tensor +template +T* toDLPackImpl(const Tensor& src) { + auto view = src; + + bool need_normalize_strides = false; + int64_t expected_stride = 1; + for (int i = src.dim() - 1; i >= 0; i--) { + // detect if we do not meet continuous pattern + // and the size is 1, so there is opportunity to normalize + if (src.stride(i) != expected_stride && src.size(i) == 1) { + need_normalize_strides = true; + break; + } + expected_stride *= src.size(i); + } + + // less common case, try normalizing the strides + if (need_normalize_strides) { + // create a new tensor with possibly normalized strides + // gh-83069 + auto shape = src.sizes(); + auto strides = src.strides().vec(); + for (int i = 0; i < src.dim(); i++) { + if (shape[i] < 2) { + strides[i] = 1; + } + } + view = src.as_strided(shape, strides, src.storage_offset()); + } + + ATenDLMTensor* atDLMTensor(new ATenDLMTensor); + atDLMTensor->handle = view; + atDLMTensor->tensor.manager_ctx = atDLMTensor; + atDLMTensor->tensor.deleter = &deleter; + atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); + atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDeviceForDLPackv1(src.device()); + atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); + atDLMTensor->tensor.dl_tensor.dtype = getDLDataTypeForDLPackv1(src); + atDLMTensor->tensor.dl_tensor.shape = const_cast(view.sizes().data()); + atDLMTensor->tensor.dl_tensor.strides = const_cast(view.strides().data()); + atDLMTensor->tensor.dl_tensor.byte_offset = 0; + fillVersion(&atDLMTensor->tensor); + return &(atDLMTensor->tensor); +} + +static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) { + switch (type) { + case DLDeviceType::kDLCPU: + return at::Device(DeviceType::CPU); +#ifndef USE_ROCM + // if we are compiled under HIP, we cannot do cuda + case DLDeviceType::kDLCUDA: + return at::Device(DeviceType::CUDA, index); +#endif + case DLDeviceType::kDLOpenCL: + return at::Device(DeviceType::OPENCL, index); + case DLDeviceType::kDLROCM: 
+#ifdef USE_ROCM + // this looks funny, we need to return CUDA here to masquerade + return at::Device(DeviceType::CUDA, index); +#else + return at::Device(DeviceType::HIP, index); +#endif + case DLDeviceType::kDLOneAPI: + TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data."); + return at::detail::getXPUHooks().getDeviceFromPtr(data); + case DLDeviceType::kDLMAIA: + return at::Device(DeviceType::MAIA, index); + case DLDeviceType::kDLExtDev: + return at::Device(DeviceType::PrivateUse1, index); + case DLDeviceType::kDLMetal: + return at::Device(DeviceType::MPS, index); + default: + TORCH_CHECK( + false, "Unsupported device_type: ", std::to_string(type)); + } +} + + +// This function constructs a Tensor from a memory managed DLPack which +// may be represented as either: DLManagedTensor and DLManagedTensorVersioned. +template +at::Tensor fromDLPackImpl(T* src, std::function deleter) { + if (!deleter) { + deleter = [src](void* self [[maybe_unused]]) { + if (src->deleter) { + src->deleter(src); + } + }; + } + + DLTensor& dl_tensor = src->dl_tensor; + Device device = getATenDeviceForDLPackv1(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data); + ScalarType stype = toScalarType(dl_tensor.dtype); + + if (!dl_tensor.strides) { + return at::from_blob( + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), + std::move(deleter), + at::device(device).dtype(stype), + {device}); + } + return at::from_blob( + dl_tensor.data, + IntArrayRef(dl_tensor.shape, dl_tensor.ndim), + IntArrayRef(dl_tensor.strides, dl_tensor.ndim), + deleter, + at::device(device).dtype(stype), + {device}); +} + +} // namespace +} // namespace at + +int TorchDLPackPyObjectExporter(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) { + try { + py::handle handle(static_cast(py_obj)); + at::Tensor tensor = handle.cast(); + if (env_stream != nullptr && tensor.is_cuda()) { + *env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream(); + } + *out = at::toDLPackImpl(tensor); + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } +} + +int TorchDLPackPyObjectImporter(DLManagedTensorVersioned* src, void** py_obj_out) { + try { + at::Tensor tensor = at::fromDLPackImpl(src, nullptr); + *py_obj_out = THPVariable_Wrap(tensor); + return 0; + } catch (const std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return -1; + } +} + +int TorchDLPackTensorAllocator( + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, const char* message) +) { + try { + at::IntArrayRef shape(prototype->shape, prototype->shape + prototype->ndim); + at::TensorOptions options = at::TensorOptions() + .dtype(at::toScalarType(prototype->dtype)) + .device(at::getATenDeviceForDLPackv1(prototype->device.device_type, prototype->device.device_id)); + at::Tensor tensor = at::empty(shape, options); + *out = at::toDLPackImpl(tensor); + return 0; + } catch (const std::exception& e) { + SetError(error_ctx, "TorchDLPackTensorAllocator", e.what()); + return -1; + } +} + +int64_t TorchDLPackPyObjectExporterPtr() { + return reinterpret_cast(TorchDLPackPyObjectExporter); +} + +int64_t TorchDLPackPyObjectImporterPtr() { + return reinterpret_cast(TorchDLPackPyObjectImporter); +} + +int64_t TorchDLPackTensorAllocatorPtr() { + return reinterpret_cast(TorchDLPackTensorAllocator); +} + """ + try: + # optionally import torch + import torch + from torch.utils 
import cpp_extension + + mod = cpp_extension.load_inline( + name="to_dlpack", + cpp_sources=cpp_source, + functions=[ + "TorchDLPackPyObjectExporterPtr", + "TorchDLPackPyObjectImporterPtr", + "TorchDLPackTensorAllocatorPtr", + ], + extra_cflags=["-O3"], + extra_include_paths=libinfo.include_paths() + cpp_extension.include_paths("cuda"), + verbose=True, + ) + # set the dlpack related flags + torch.Tensor.__c_dlpack_exporter__ = mod.TorchDLPackPyObjectExporterPtr() + torch.Tensor.__c_dlpack_importer__ = mod.TorchDLPackPyObjectImporterPtr() + torch.Tensor.__c_dlpack_tensor_allocator__ = mod.TorchDLPackTensorAllocatorPtr() + return mod + except ImportError: + pass + except Exception as e: + warnings.warn( + f"Failed to load torch c dlpack extension: {e}," + "EnvTensorAllocator will not be enabled." + ) + return None + + +# keep alive +_mod = load_torch_c_dlpack_extension() diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index 08b01d424f1f..a1de1de1cd89 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -238,27 +238,39 @@ cdef extern from "tvm/ffi/extra/c_env_api.h": ctypedef void* TVMFFIStreamHandle int TVMFFIEnvRegisterCAPI(const char* name, void* ptr) nogil - void* TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) nogil - int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, + void* TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) nogil + int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, TVMFFIStreamHandle* opt_out_original_stream) nogil cdef extern from "tvm_ffi_python_helpers.h": # no need to expose fields of the call context + # setter data structure + ctypedef int (*DLPackPyObjectExporter)( + void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream + ) except -1 + + ctypedef int (*DLPackPyObjectImporter)( + DLManagedTensorVersioned* tensor, void** py_obj_out + ) except -1 + ctypedef int (*DLPackTensorAllocator)( + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, const char* message) + ) except -1 + ctypedef struct TVMFFIPyCallContext: int device_type int device_id TVMFFIStreamHandle stream - - # setter data structure - ctypedef int (*DLPackPyObjectCExporter)( - void* py_obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream - ) except -1 + DLPackPyObjectImporter c_dlpack_importer + DLPackTensorAllocator c_dlpack_tensor_allocator ctypedef struct TVMFFIPyArgSetter: int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out) except -1 - DLPackPyObjectCExporter dlpack_c_exporter + DLPackPyObjectExporter c_dlpack_exporter + DLPackPyObjectImporter c_dlpack_importer + DLPackTensorAllocator c_dlpack_tensor_allocator ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, TVMFFIPyArgSetter* out) except -1 # The main call function @@ -267,7 +279,9 @@ cdef extern from "tvm_ffi_python_helpers.h": void* chandle, PyObject* py_arg_tuple, TVMFFIAny* result, - int* c_api_ret_code + int* c_api_ret_code, + int release_gil, + DLPackPyObjectImporter* out_dlpack_importer ) except -1 int TVMFFIPyCallFieldSetter( diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index b77b19a2eabb..bd486c5f77f5 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -29,8 +29,9 @@ else: torch = None -_torch_dlpack_c_exporter_ptr = None - 
+cdef int _RELEASE_GIL_BY_DEFAULT = int( + os.environ.get("TVM_FFI_RELEASE_GIL_BY_DEFAULT", "1") +) cdef inline object make_ret_small_str(TVMFFIAny result): """convert small string to return value.""" @@ -46,13 +47,13 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result): return PyBytes_FromStringAndSize(bytes.data, bytes.size) -cdef inline object make_ret(TVMFFIAny result): +cdef inline object make_ret(TVMFFIAny result, DLPackPyObjectImporter c_dlpack_importer = NULL): """convert result to return value.""" cdef int32_t type_index type_index = result.type_index if type_index == kTVMFFITensor: # specially handle Tensor as it needs a special dltensor field - return make_tensor_from_any(result) + return make_tensor_from_any(result, c_dlpack_importer) elif type_index == kTVMFFIOpaquePyObject: return make_ret_opaque_object(result) elif type_index >= kTVMFFIStaticObjectBegin: @@ -120,13 +121,18 @@ cdef int TVMFFIPyArgSetterDLPackCExporter_( cdef TVMFFIObjectHandle temp_chandle cdef TVMFFIStreamHandle env_stream = NULL + if this.c_dlpack_importer != NULL: + ctx.c_dlpack_importer = this.c_dlpack_importer + if this.c_dlpack_tensor_allocator != NULL: + ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator + if ctx.device_id != -1: # already queried device, do not do it again, pass NULL to stream - if (this.dlpack_c_exporter)(arg, &temp_managed_tensor, NULL) != 0: + if (this.c_dlpack_exporter)(arg, &temp_managed_tensor, NULL) != 0: return -1 else: # query string on the envrionment stream - if (this.dlpack_c_exporter)(arg, &temp_managed_tensor, &env_stream) != 0: + if (this.c_dlpack_exporter)(arg, &temp_managed_tensor, &env_stream) != 0: return -1 # If device is not CPU, we should set the device type and id if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU: @@ -142,17 +148,32 @@ cdef int TVMFFIPyArgSetterDLPackCExporter_( return 0 -cdef int TVMFFIPyArgSetterTorch_( +cdef int TorchDLPackPyObjectImporterFallback_( + DLManagedTensorVersioned* dltensor, void** py_obj_out +) except -1: + # a bit convoluted but ok as a fallback + cdef TVMFFIObjectHandle temp_chandle + TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle) + tensor = make_tensor_from_chandle(temp_chandle) + torch_tensor = torch.from_dlpack(tensor) + Py_INCREF(torch_tensor) + py_obj_out[0] = (torch_tensor) + return 0 + + +cdef int TVMFFIPyArgSetterTorchFallback_( TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out ) except -1: """Current setter for torch.Tensor, go through python and not as fast as c exporter""" + # TODO(tqchen): remove this once torch always support fast DLPack importer cdef object arg = py_arg is_cuda = arg.is_cuda arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg)) out.type_index = kTVMFFITensor out.v_ptr = (arg).chandle temp_dltensor = TVMFFITensorGetDLTensorPtr((arg).chandle) + ctx.c_dlpack_importer = TorchDLPackPyObjectImporterFallback_ # record the stream and device for torch context if is_cuda and ctx.device_type != -1: ctx.device_type = temp_dltensor.device.device_type @@ -180,10 +201,10 @@ cdef int TVMFFIPyArgSetterDLPack_( if (temp_dltensor.device.device_type != kDLCPU and ctx.device_type != -1): # __tvm_ffi_env_stream__ returns the expected stream that should be set - # through TVMFFIEnvSetCurrentStream when calling a TVM FFI function + # through TVMFFIEnvSetStream when calling a TVM FFI function if hasattr(arg, "__tvm_ffi_env_stream__"): # Ideally projects should directly setup their stream context API - # write through by also calling 
TVMFFIEnvSetCurrentStream + # write through by also calling TVMFFIEnvSetStream # so we do not need this protocol to do exchange ctx.device_type = temp_dltensor.device.device_type ctx.device_id = temp_dltensor.device.device_id @@ -349,19 +370,21 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce if isinstance(arg, ObjectRValueRef): out.func = TVMFFIPyArgSetterObjectRValueRef_ return 0 - # external tensors - if hasattr(arg, "__dlpack_c_exporter__"): - out.func = TVMFFIPyArgSetterDLPackCExporter_ - temp_ptr = arg.__dlpack_c_exporter__ - out.dlpack_c_exporter = temp_ptr - return 0 - if torch is not None and isinstance(arg, torch.Tensor): - if _torch_dlpack_c_exporter_ptr is not None: - temp_ptr = _torch_dlpack_c_exporter_ptr + if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXPORTER", "0") != "1": + # external tensors + if hasattr(arg, "__c_dlpack_exporter__"): out.func = TVMFFIPyArgSetterDLPackCExporter_ - out.dlpack_c_exporter = temp_ptr - else: - out.func = TVMFFIPyArgSetterTorch_ + temp_ptr = arg.__c_dlpack_exporter__ + out.c_dlpack_exporter = temp_ptr + if hasattr(arg, "__c_dlpack_importer__"): + temp_ptr = arg.__c_dlpack_importer__ + out.c_dlpack_importer = temp_ptr + if hasattr(arg, "__c_dlpack_tensor_allocator__"): + temp_ptr = arg.__c_dlpack_tensor_allocator__ + out.c_dlpack_tensor_allocator = temp_ptr + return 0 + if torch is not None and isinstance(arg, torch.Tensor): + out.func = TVMFFIPyArgSetterTorchFallback_ return 0 if hasattr(arg, "__dlpack__"): out.func = TVMFFIPyArgSetterDLPack_ @@ -415,13 +438,16 @@ cdef inline int ConstructorCall(void* constructor_handle, # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone result.type_index = kTVMFFINone result.v_int64 = 0 - TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory_, constructor_handle, args, &result, &c_api_ret_code) + TVMFFIPyFuncCall( + TVMFFIPyArgSetterFactory_, constructor_handle, args, &result, &c_api_ret_code, + False, NULL + ) CHECK_CALL(c_api_ret_code) handle[0] = result.v_ptr return 0 -class Function(Object): +cdef class Function(Object): """Python class that wraps a function with tvm-ffi ABI. See Also @@ -429,9 +455,22 @@ class Function(Object): tvm_ffi.register_global_func: How to register global function. tvm_ffi.get_global_func: How to get global function. 
""" + cdef int c_release_gil + cdef dict __dict__ + + def __cinit__(self): + self.c_release_gil = _RELEASE_GIL_BY_DEFAULT + + property release_gil: + def __get__(self): + return self.c_release_gil != 0 + def __set__(self, value): + self.c_release_gil = value + def __call__(self, *args): cdef TVMFFIAny result cdef int c_api_ret_code + cdef DLPackPyObjectImporter c_dlpack_importer = NULL # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone result.type_index = kTVMFFINone result.v_int64 = 0 @@ -439,12 +478,14 @@ class Function(Object): TVMFFIPyArgSetterFactory_, (self).chandle, args, &result, - &c_api_ret_code + &c_api_ret_code, + self.release_gil, + &c_dlpack_importer ) # NOTE: logic is same as check_call # directly inline here to simplify traceback if c_api_ret_code == 0: - return make_ret(result) + return make_ret(result, c_dlpack_importer) elif c_api_ret_code == -2: raise_existing_error() raise move_from_last_error().py_error() diff --git a/ffi/python/tvm_ffi/cython/tensor.pxi b/ffi/python/tvm_ffi/cython/tensor.pxi index fca6cc0bbc08..2fd80bc1a6c8 100644 --- a/ffi/python/tvm_ffi/cython/tensor.pxi +++ b/ffi/python/tvm_ffi/cython/tensor.pxi @@ -51,9 +51,8 @@ cdef inline object _from_dlpack_intptr( cdef int c_api_ret_code cdef int c_req_alignment = 0 cdef int c_req_contiguous = 0 - with nogil: - c_api_ret_code = TVMFFITensorFromDLPack( - ptr, c_req_alignment, c_req_contiguous, &chandle) + c_api_ret_code = TVMFFITensorFromDLPack( + ptr, c_req_alignment, c_req_contiguous, &chandle) CHECK_CALL(c_api_ret_code) return make_tensor_from_chandle(chandle) @@ -68,9 +67,8 @@ cdef inline int _from_dlpack( cdef int c_req_contiguous = require_contiguous if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor): ptr = pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor) - with nogil: - c_api_ret_code = TVMFFITensorFromDLPack( - ptr, c_req_alignment, c_req_contiguous, out) + c_api_ret_code = TVMFFITensorFromDLPack( + ptr, c_req_alignment, c_req_contiguous, out) CHECK_CALL(c_api_ret_code) # set name and destructor to be empty pycapsule.PyCapsule_SetDestructor(dltensor, NULL) @@ -90,9 +88,8 @@ cdef inline int _from_dlpack_versioned( if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor_versioned): ptr = pycapsule.PyCapsule_GetPointer( dltensor, _c_str_dltensor_versioned) - with nogil: - c_api_ret_code = TVMFFITensorFromDLPackVersioned( - ptr, c_req_alignment, c_req_contiguous, out) + c_api_ret_code = TVMFFITensorFromDLPackVersioned( + ptr, c_req_alignment, c_req_contiguous, out) CHECK_CALL(c_api_ret_code) # set name and destructor to be empty pycapsule.PyCapsule_SetDestructor(dltensor, NULL) @@ -209,18 +206,14 @@ cdef class Tensor(Object): def _to_dlpack(self): cdef DLManagedTensor* dltensor cdef int c_api_ret_code - - with nogil: - c_api_ret_code = TVMFFITensorToDLPack(self.chandle, &dltensor) + c_api_ret_code = TVMFFITensorToDLPack(self.chandle, &dltensor) CHECK_CALL(c_api_ret_code) return pycapsule.PyCapsule_New(dltensor, _c_str_dltensor, _c_dlpack_deleter) def _to_dlpack_versioned(self): cdef DLManagedTensorVersioned* dltensor cdef int c_api_ret_code - - with nogil: - c_api_ret_code = TVMFFITensorToDLPackVersioned(self.chandle, &dltensor) + c_api_ret_code = TVMFFITensorToDLPackVersioned(self.chandle, &dltensor) CHECK_CALL(c_api_ret_code) return pycapsule.PyCapsule_New( dltensor, _c_str_dltensor_versioned, _c_dlpack_versioned_deleter) @@ -282,24 +275,24 @@ _set_class_tensor(Tensor) _register_object_by_index(kTVMFFITensor, Tensor) - -cdef int 
_dltensor_test_wrapper_dlpack_c_exporter( +cdef int _dltensor_test_wrapper_c_dlpack_exporter( void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream ) except -1: - cdef object ref_obj = (obj) - cdef DLTensorTestWrapper wrapper = ref_obj + cdef PyObject* py_obj = obj + cdef DLTensorTestWrapper wrapper = py_obj cdef TVMFFIStreamHandle current_stream - + cdef DLManagedTensorVersioned* temp_managed_tensor if env_stream != NULL: - env_stream[0] = TVMFFIEnvGetCurrentStream( + env_stream[0] = TVMFFIEnvGetStream( wrapper.tensor.cdltensor.device.device_type, wrapper.tensor.cdltensor.device.device_id ) + return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out) -def _dltensor_test_wrapper_dlpack_c_exporter_as_intptr(): - cdef DLPackPyObjectCExporter converter_func = _dltensor_test_wrapper_dlpack_c_exporter +def _dltensor_test_wrapper_c_dlpack_exporter_as_intptr(): + cdef DLPackPyObjectExporter converter_func = _dltensor_test_wrapper_c_dlpack_exporter cdef void* temp_ptr = converter_func cdef long long temp_int_ptr = temp_ptr return temp_int_ptr @@ -308,8 +301,10 @@ def _dltensor_test_wrapper_dlpack_c_exporter_as_intptr(): cdef class DLTensorTestWrapper: """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose. """ - __dlpack_c_exporter__ = _dltensor_test_wrapper_dlpack_c_exporter_as_intptr() + __c_dlpack_exporter__ = _dltensor_test_wrapper_c_dlpack_exporter_as_intptr() + cdef Tensor tensor + cdef dict __dict__ def __init__(self, tensor): self.tensor = tensor @@ -317,9 +312,8 @@ cdef class DLTensorTestWrapper: cdef TVMFFIStreamHandle stream cdef long long stream_as_int cdef int c_api_ret_code - with nogil: - stream = TVMFFIEnvGetCurrentStream( - self.tensor.cdltensor.device.device_type, self.tensor.cdltensor.device.device_id) + stream = TVMFFIEnvGetStream( + self.tensor.cdltensor.device.device_type, self.tensor.cdltensor.device.device_id) stream_as_int = stream return stream_as_int @@ -339,14 +333,30 @@ cdef inline object make_ret_dltensor(TVMFFIAny result): return tensor -cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle): +cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackPyObjectImporter c_dlpack_importer = NULL): # TODO: Implement cdef Tensor tensor + cdef void* py_obj + cdef DLManagedTensorVersioned* dlpack + + if c_dlpack_importer != NULL: + # try convert and import into the environment array if possible + if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0: + try: + # note that py_obj already holds an extra reference to the tensor + # so we need to decref it after the conversion + c_dlpack_importer(dlpack, &py_obj) + tensor = (py_obj) + Py_DECREF(tensor) + return tensor + except Exception: + pass + # default return the tensor tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR) (tensor).chandle = chandle (tensor).cdltensor = TVMFFITensorGetDLTensorPtr(chandle) return tensor -cdef inline object make_tensor_from_any(TVMFFIAny any): - return make_tensor_from_chandle(any.v_ptr) +cdef inline object make_tensor_from_any(TVMFFIAny any, DLPackPyObjectImporter c_dlpack_importer): + return make_tensor_from_chandle(any.v_ptr, c_dlpack_importer) diff --git a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h index 32ded385bae8..c7d847b85780 100644 --- a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h +++ b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h @@ -27,13 +27,40 @@ #include #include +#include #include +#include #include 
+//----------------------------------------------------------
+// Extra support for DLPack
+//----------------------------------------------------------
+/*!
+ * \brief C-style function pointer to quickly convert a PyObject Tensor to a DLManagedTensorVersioned.
+ * \param py_obj The Python object to convert; this should be a PyObject*.
+ * \param out The output DLManagedTensorVersioned.
+ * \param env_stream Outputs the current context stream of the device provided by the tensor.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
+ * \note We use void* to avoid dependency on Python.h so this specific type is
+ *       not dependent on Python.h and can be copied to dlpack.h
+ */
+typedef int (*DLPackPyObjectExporter)(void* py_obj, DLManagedTensorVersioned** out,
+                                      void** env_stream);
+/*!
+ * \brief C-style function pointer to quickly convert a DLManagedTensorVersioned to a PyObject Tensor.
+ * \param tensor The DLManagedTensorVersioned to convert.
+ * \param py_obj_out The output Python object.
+ * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
+ * \note We use void* to avoid dependency on Python.h so this specific type is
+ *       not dependent on Python.h and can be copied to dlpack.h
+ */
+typedef int (*DLPackPyObjectImporter)(DLManagedTensorVersioned* tensor, void** py_obj_out);
+
 ///--------------------------------------------------------------------------------
 /// We deliberately designed the data structure and function to be C-style
 // prefixed with TVMFFIPy so they can be easily invoked through Cython.
 ///--------------------------------------------------------------------------------
+
 /*!
  * \brief Context for each ffi call to track the stream, device and temporary arguments.
  */
@@ -54,20 +81,12 @@ struct TVMFFIPyCallContext {
   void** temp_py_objects = nullptr;
   /*! \brief the number of temporary arguments */
   int num_temp_py_objects = 0;
+  /*! \brief the DLPack importer, if any */
+  DLPackPyObjectImporter c_dlpack_importer{nullptr};
+  /*! \brief the DLPack allocator, if any */
+  DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
 };
 
-/*!
- * \brief C-style function pointer to speed convert a Tensor to a DLManagedTensorVersioned.
- * \param py_obj The Python object to convert, this should be PyObject*
- * \param out The output DLManagedTensorVersioned.
- * \param env_stream Outputs the current context stream of the device provided by the tensor.
- * \return 0 on success, -1 on failure. PyError should be set if -1 is returned.
- * \note We use void* to avoid dependency on Python.h so this specific type is
- *       not dependent on Python.h and can be copied to dlpack.h
- */
-typedef int (*DLPackPyObjectCExporter)(void* py_obj, DLManagedTensorVersioned** out,
-                                       void** env_stream);
-
 /*! \brief Argument setter for a given python argument. */
 struct TVMFFIPyArgSetter {
   /*!
@@ -83,7 +102,15 @@ struct TVMFFIPyArgSetter {
   /*!
    * \brief Optional DLPack exporter for for setters that leverages DLPack protocol.
    */
-  DLPackPyObjectCExporter dlpack_c_exporter{nullptr};
+  DLPackPyObjectExporter c_dlpack_exporter{nullptr};
+  /*!
+   * \brief Optional DLPack importer for setters that leverage the DLPack protocol.
+   */
+  DLPackPyObjectImporter c_dlpack_importer{nullptr};
+  /*!
+   * \brief Optional DLPack allocator for setters that leverage the DLPack protocol.
+   */
+  DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
   /*!
    * \brief Invoke the setter.
   * \param call_ctx The call context.
@@ -239,11 +266,14 @@ class TVMFFIPyCallManager { * \param py_arg_tuple The arguments to the function * \param result The result of the function * \param c_api_ret_code The return code of the C-call + * \param release_gil Whether to release the GIL + * \param optional_out_dlpack_importer The DLPack importer to be used for the result * \return 0 on when there is no python error, -1 on python error * \note When an error happens on FFI side, we should return 0 and set c_api_ret_code */ int Call(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, PyObject* py_arg_tuple, - TVMFFIAny* result, int* c_api_ret_code) { + TVMFFIAny* result, int* c_api_ret_code, bool release_gil, + DLPackPyObjectImporter* optional_out_dlpack_importer) { int64_t num_args = PyTuple_Size(py_arg_tuple); if (num_args == -1) return -1; try { @@ -256,27 +286,44 @@ class TVMFFIPyCallManager { if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; } TVMFFIStreamHandle prev_stream = nullptr; + DLPackTensorAllocator prev_tensor_allocator = nullptr; // setup stream context if needed if (ctx.device_type != -1) { c_api_ret_code[0] = - TVMFFIEnvSetCurrentStream(ctx.device_type, ctx.device_id, ctx.stream, &prev_stream); + TVMFFIEnvSetStream(ctx.device_type, ctx.device_id, ctx.stream, &prev_stream); // setting failed, directly return if (c_api_ret_code[0] != 0) return 0; } + if (ctx.c_dlpack_tensor_allocator != nullptr) { + c_api_ret_code[0] = + TVMFFIEnvSetTensorAllocator(ctx.c_dlpack_tensor_allocator, 0, &prev_tensor_allocator); + if (c_api_ret_code[0] != 0) return 0; + } // call the function - // release the GIL - Py_BEGIN_ALLOW_THREADS; - c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); - Py_END_ALLOW_THREADS; + if (release_gil) { + // release the GIL + Py_BEGIN_ALLOW_THREADS; + c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); + Py_END_ALLOW_THREADS; + } else { + c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); + } // restore the original stream if (ctx.device_type != -1 && prev_stream != ctx.stream) { // always try recover first, even if error happens - if (TVMFFIEnvSetCurrentStream(ctx.device_type, ctx.device_id, prev_stream, nullptr) != 0) { + if (TVMFFIEnvSetStream(ctx.device_type, ctx.device_id, prev_stream, nullptr) != 0) { // recover failed, set python error PyErr_SetString(PyExc_RuntimeError, "Failed to recover stream"); return -1; } } + if (prev_tensor_allocator != ctx.c_dlpack_tensor_allocator) { + c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(prev_tensor_allocator, 0, nullptr); + if (c_api_ret_code[0] != 0) return 0; + } + if (optional_out_dlpack_importer != nullptr && ctx.c_dlpack_importer != nullptr) { + *optional_out_dlpack_importer = ctx.c_dlpack_importer; + } return 0; } catch (const std::exception& ex) { // very rare, catch c++ exception and set python error @@ -376,12 +423,16 @@ class TVMFFIPyCallManager { * \param py_arg_tuple The arguments to the function * \param result The result of the function * \param c_api_ret_code The return code of the function + * \param release_gil Whether to release the GIL + * \param out_dlpack_exporter The DLPack exporter to be used for the result * \return 0 on success, nonzero on failure */ inline int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, - PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code) { + PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, + bool release_gil = true, + 
DLPackPyObjectImporter* out_dlpack_importer = nullptr) { return TVMFFIPyCallManager::ThreadLocal()->Call(setter_factory, func_handle, py_arg_tuple, result, - c_api_ret_code); + c_api_ret_code, release_gil, out_dlpack_importer); } /*! diff --git a/ffi/python/tvm_ffi/libinfo.py b/ffi/python/tvm_ffi/libinfo.py index b449bc1abcf5..b02897f27917 100644 --- a/ffi/python/tvm_ffi/libinfo.py +++ b/ffi/python/tvm_ffi/libinfo.py @@ -116,6 +116,18 @@ def find_include_path(): raise RuntimeError("Cannot find include path.") +def find_python_helper_include_path(): + """Find header files for C compilation.""" + candidates = [ + os.path.join(os.path.dirname(os.path.realpath(__file__)), "include"), + os.path.join(os.path.dirname(os.path.realpath(__file__)), "cython"), + ] + for candidate in candidates: + if os.path.isfile(os.path.join(candidate, "tvm_ffi_python_helpers.h")): + return candidate + raise RuntimeError("Cannot find python helper include path.") + + def find_dlpack_include_path(): """Find dlpack header files for C compilation.""" install_include_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "include") @@ -142,3 +154,14 @@ def find_cython_lib(): for path in glob.glob(os.path.join(candidate, f"core*.{suffixes}")): return os.path.abspath(path) raise RuntimeError("Cannot find tvm cython path.") + + +def include_paths(): + """Find all include paths needed for FFI related compilation.""" + include_path = find_include_path() + python_helper_include_path = find_python_helper_include_path() + dlpack_include_path = find_dlpack_include_path() + result = [include_path, dlpack_include_path] + if python_helper_include_path != include_path: + result.append(python_helper_include_path) + return result diff --git a/ffi/scripts/benchmark_dlpack.py b/ffi/scripts/benchmark_dlpack.py index 364afa1b5fdf..2ab85bf03559 100644 --- a/ffi/scripts/benchmark_dlpack.py +++ b/ffi/scripts/benchmark_dlpack.py @@ -436,9 +436,12 @@ def main(): ) bench_torch_get_current_stream(repeat, "python", torch_get_cuda_stream_native) print("---------------------------------------------------") - print("Benchmark tvm_ffi.print_helper_info") + print("Debug information") print("---------------------------------------------------") tvm_ffi.core._print_debug_info() + release_gil = tvm_ffi.get_global_func("testing.nop").release_gil + print(f"TVM_FFI_RELEASE_GIL_BY_DEFAULT={int(release_gil)}") + print("---------------------------------------------------") if __name__ == "__main__": diff --git a/ffi/src/ffi/extra/env_context.cc b/ffi/src/ffi/extra/env_context.cc new file mode 100644 index 000000000000..30f9270dabc7 --- /dev/null +++ b/ffi/src/ffi/extra/env_context.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+/*
+ * \file src/ffi/extra/env_context.cc
+ *
+ * \brief A minimalistic env context based on ffi values.
+ */
+
+#include
+#include
+
+#include
+
+namespace tvm {
+namespace ffi {
+
+class EnvContext {
+ public:
+  void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
+                 TVMFFIStreamHandle* out_original_stream) {
+    if (static_cast(device_type) >= stream_table_.size()) {
+      stream_table_.resize(device_type + 1);
+    }
+    if (static_cast(device_id) >= stream_table_[device_type].size()) {
+      stream_table_[device_type].resize(device_id + 1, nullptr);
+    }
+    if (out_original_stream != nullptr) {
+      *out_original_stream = stream_table_[device_type][device_id];
+    }
+    stream_table_[device_type][device_id] = stream;
+  }
+
+  TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) {
+    if (static_cast(device_type) < stream_table_.size() &&
+        static_cast(device_id) < stream_table_[device_type].size()) {
+      return stream_table_[device_type][device_id];
+    }
+    return nullptr;
+  }
+
+  DLPackTensorAllocator GetDLPackTensorAllocator() {
+    if (dlpack_allocator_ != nullptr) {
+      return dlpack_allocator_;
+    }
+    return GlobalTensorAllocator();
+  }
+
+  void SetDLPackTensorAllocator(DLPackTensorAllocator allocator, int write_to_global_context,
+                                DLPackTensorAllocator* opt_out_original_allocator) {
+    // report the previous allocator before overwriting it
+    if (opt_out_original_allocator != nullptr) {
+      *opt_out_original_allocator = dlpack_allocator_;
+    }
+    dlpack_allocator_ = allocator;
+    if (write_to_global_context != 0) {
+      GlobalTensorAllocator() = allocator;
+    }
+  }
+
+  static EnvContext* ThreadLocal() {
+    static thread_local EnvContext inst;
+    return &inst;
+  }
+
+ private:
+  // use static function to avoid static initialization order issue
+  static DLPackTensorAllocator& GlobalTensorAllocator() {  // NOLINT(*)
+    static DLPackTensorAllocator allocator = nullptr;
+    return allocator;
+  }
+  std::vector> stream_table_;
+  DLPackTensorAllocator dlpack_allocator_ = nullptr;
+};
+
+}  // namespace ffi
+}  // namespace tvm
+
+int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
+                       TVMFFIStreamHandle* out_original_stream) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::EnvContext::ThreadLocal()->SetStream(device_type, device_id, stream,
+                                                 out_original_stream);
+  TVM_FFI_SAFE_CALL_END();
+}
+
+TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) {
+  TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
+  return tvm::ffi::EnvContext::ThreadLocal()->GetStream(device_type, device_id);
+  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetStream);
+}
+
+int TVMFFIEnvSetTensorAllocator(DLPackTensorAllocator allocator, int write_to_global_context,
+                                DLPackTensorAllocator* opt_out_original_allocator) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::EnvContext::ThreadLocal()->SetDLPackTensorAllocator(allocator, write_to_global_context,
+                                                                opt_out_original_allocator);
+  TVM_FFI_SAFE_CALL_END();
+}
+
+DLPackTensorAllocator TVMFFIEnvGetTensorAllocator() {
+  TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
+  return tvm::ffi::EnvContext::ThreadLocal()->GetDLPackTensorAllocator();
+  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetTensorAllocator);
+}
diff --git a/ffi/src/ffi/extra/stream_context.cc b/ffi/src/ffi/extra/stream_context.cc
deleted file mode 100644
index 5a6afad4c1d8..000000000000
--- a/ffi/src/ffi/extra/stream_context.cc
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/* - * \file src/ffi/extra/stream_context.cc - * - * \brief A minimalistic stream context based on ffi values. - */ - -#include -#include - -#include - -namespace tvm { -namespace ffi { - -class StreamContext { - public: - void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - if (static_cast(device_type) >= stream_table_.size()) { - stream_table_.resize(device_type + 1); - } - if (static_cast(device_id) >= stream_table_[device_type].size()) { - stream_table_[device_type].resize(device_id + 1, nullptr); - } - if (out_original_stream != nullptr) { - *out_original_stream = stream_table_[device_type][device_id]; - } - stream_table_[device_type][device_id] = stream; - } - - TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) { - if (static_cast(device_type) < stream_table_.size() && - static_cast(device_id) < stream_table_[device_type].size()) { - return stream_table_[device_type][device_id]; - } - return nullptr; - } - - static StreamContext* ThreadLocal() { - static thread_local StreamContext inst; - return &inst; - } - - private: - std::vector> stream_table_; -}; - -} // namespace ffi -} // namespace tvm - -int TVMFFIEnvSetCurrentStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, - TVMFFIStreamHandle* out_original_stream) { - TVM_FFI_SAFE_CALL_BEGIN(); - tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream, - out_original_stream); - TVM_FFI_SAFE_CALL_END(); -} - -TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) { - TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::StreamContext::ThreadLocal()->GetStream(device_type, device_id); - TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetCurrentStream); -} diff --git a/ffi/tests/cpp/test_tensor.cc b/ffi/tests/cpp/test_tensor.cc index 3ad182d844f0..7c696a3429c1 100644 --- a/ffi/tests/cpp/test_tensor.cc +++ b/ffi/tests/cpp/test_tensor.cc @@ -32,6 +32,23 @@ inline Tensor Empty(Shape shape, DLDataType dtype, DLDevice device) { return Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device); } +int TestDLPackTensorAllocator(DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, + const char* message)) { + Shape shape(prototype->shape, prototype->shape + prototype->ndim); + Tensor nd = Empty(shape, prototype->dtype, prototype->device); + *out = nd.ToDLPackVersioned(); + return 0; +} + +int TestDLPackTensorAllocatorError(DLTensor* prototype, DLManagedTensorVersioned** out, + void* error_ctx, + void (*SetError)(void* error_ctx, const char* kind, + const char* message)) { + SetError(error_ctx, "RuntimeError", "TestDLPackTensorAllocatorError"); + return -1; +} + TEST(Tensor, Basic) { Tensor nd = Empty(Shape({1, 2, 3}), 
DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); Shape shape = nd.shape(); @@ -116,4 +133,32 @@ TEST(Tensor, DLPackVersioned) { } EXPECT_EQ(tensor.use_count(), 1); } + +TEST(Tensor, DLPackAlloc) { + // Test successful allocation + Tensor tensor = Tensor::FromDLPackAlloc(TestDLPackTensorAllocator, {1, 2, 3}, + DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); + EXPECT_EQ(tensor.use_count(), 1); + EXPECT_EQ(tensor.shape().size(), 3); + EXPECT_EQ(tensor.shape()[0], 1); + EXPECT_EQ(tensor.shape()[1], 2); + EXPECT_EQ(tensor.shape()[2], 3); + EXPECT_EQ(tensor.dtype().code, kDLFloat); + EXPECT_EQ(tensor.dtype().bits, 32); + EXPECT_EQ(tensor.dtype().lanes, 1); + EXPECT_EQ(tensor->device.device_type, kDLCPU); + EXPECT_EQ(tensor->device.device_id, 0); + EXPECT_NE(tensor->data, nullptr); +} + +TEST(Tensor, DLPackAllocError) { + // Test error handling in DLPackAlloc + EXPECT_THROW( + { + Tensor::FromDLPackAlloc(TestDLPackTensorAllocatorError, {1, 2, 3}, + DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0})); + }, + tvm::ffi::Error); +} + } // namespace diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py index 9a10476d8eff..89f00b1f36fd 100644 --- a/ffi/tests/python/test_load_inline.py +++ b/ffi/tests/python/test_load_inline.py @@ -186,7 +186,7 @@ def test_load_inline_cuda(): // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); @@ -202,6 +202,66 @@ def test_load_inline_cuda(): torch.testing.assert_close(x_cuda + 1, y_cuda) +@pytest.mark.skipif( + torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" +) +def test_load_inline_cuda_with_env_tensor_allocator(): + if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"): + pytest.skip("Torch does not support __c_dlpack_tensor_allocator__") + mod: Module = tvm_ffi.cpp.load_inline( + name="hello", + cpp_sources=r""" + #include + + tvm::ffi::Tensor return_add_one(DLTensor* x); + """, + cuda_sources=r""" + #include + + __global__ void AddOneKernel(float* x, float* y, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + y[idx] = x[idx] + 1; + } + } + namespace ffi = tvm::ffi; + + ffi::Tensor return_add_one(DLTensor* x) { + // implementation of a library function + TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; + DLDataType f32_dtype{kDLFloat, 32, 1}; + TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor"; + // allocate a new tensor with the env tensor allocator + // it will be redirected to torch.empty when calling the function + ffi::Tensor y = ffi::Tensor::FromDLPackAlloc( + TVMFFIEnvGetTensorAllocator(), ffi::Shape({x->shape[0]}), f32_dtype, x->device); + int64_t n = x->shape[0]; + int64_t nthread_per_block = 256; + int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block; + // Obtain the current stream from the environment + // it will be set to torch.cuda.current_stream() when calling the function + // with torch.Tensors + cudaStream_t stream = static_cast( + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); + // launch the kernel + AddOneKernel<<>>(static_cast(x->data), + static_cast(y->data), n); + return y; + } + """, + functions=["return_add_one"], + ) + + if torch is not None: + x_cuda = 
torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") + y_cuda = mod.return_add_one(x_cuda) + assert isinstance(y_cuda, torch.Tensor) + assert y_cuda.shape == (5,) + assert y_cuda.dtype == torch.float32 + torch.testing.assert_close(x_cuda + 1, y_cuda) + assert y_cuda.is_cuda + + @pytest.mark.skipif( torch is None or not torch.cuda.is_available(), reason="Requires torch and CUDA" ) @@ -248,7 +308,7 @@ def test_load_inline_both(): // it will be set to torch.cuda.current_stream() when calling the function // with torch.Tensors cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id)); + TVMFFIEnvGetStream(x->device.device_type, x->device.device_id)); // launch the kernel AddOneKernel<<>>(static_cast(x->data), static_cast(y->data), n); diff --git a/ffi/tests/python/test_tensor.py b/ffi/tests/python/test_tensor.py index aa2482f88852..5c7051279815 100644 --- a/ffi/tests/python/test_tensor.py +++ b/ffi/tests/python/test_tensor.py @@ -55,22 +55,14 @@ def test_shape_object(): assert isinstance(shape3, tvm_ffi.Shape) -@pytest.mark.skipif(torch is None, reason="Torch is not installed") +@pytest.mark.skipif(torch is None, reason="Fast torch dlpack importer is not enabled") def test_tensor_auto_dlpack(): - def check(x, y): - assert isinstance(y, tvm_ffi.Tensor) - assert y.shape == (128,) - assert y.dtype == tvm_ffi.dtype("int64") - assert y.device.dlpack_device_type() == tvm_ffi.DLDeviceType.kDLCPU - assert y.device.index == 0 - x2 = torch.from_dlpack(y) - np.testing.assert_equal(x2.numpy(), x.numpy()) - x = torch.arange(128) fecho = tvm_ffi.get_global_func("testing.echo") y = fecho(x) - check(x, y) - - # pass in list of tensors - y = fecho([x]) - check(x, y[0]) + assert isinstance(y, torch.Tensor) + assert y.data_ptr() == x.data_ptr() + assert y.dtype == x.dtype + assert y.shape == x.shape + assert y.device == x.device + np.testing.assert_equal(y.numpy(), x.numpy()) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index fe29cd59459b..ff804e83460c 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -147,7 +147,7 @@ def instantiate_attention_template(attrs): } CHECK(Attention::check_supported(p)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); kernel_fn<<>>(p); @@ -185,7 +185,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${query}->data), @@ -235,7 +235,7 @@ def instantiate_flash_attention_template(attrs): int v_batch_stride = v_row_stride * ${num_keys}; int o_batch_stride = o_row_stride * ${num_queries}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_forward( static_cast(${qkv}->data), @@ -291,7 +291,7 @@ def instantiate_flash_attention_var_len_template(attrs): int v_row_stride = v_head_stride * ${num_kv_heads}; int o_row_stride = o_head_stride 
* ${num_q_heads}; - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id)); flash_attn::flash_attention_var_len_forward( static_cast(${query}->data), diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index b0afdcdd6e84..e323e2a14937 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -424,7 +424,7 @@ def instantiate_conv2d_template(attrs): TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); ${split_k_update} - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${data_arg}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${data_arg}->device.device_id)); status = conv2d_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index 453839cc8130..d8940230e0e3 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -345,7 +345,7 @@ def instantiate_gemm_template(attrs): status = gemm_op.initialize(arguments, workspace.get()); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${A_arg}->device.device_id)); status = gemm_op(stream); TVM_FFI_ICHECK(status == cutlass::Status::kSuccess); @@ -428,7 +428,7 @@ def emit_fp16A_intB_matmul(attrs): int k = ${B_arg}->shape[0]; cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id)); + TVMFFIEnvGetStream(kDLCUDA, ${A_arg}->device.device_id)); """, attrs, ) diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py index d2a031024475..b0f7dc7c14f7 100644 --- a/python/tvm/contrib/cutlass/layer_norm_operation.py +++ b/python/tvm/contrib/cutlass/layer_norm_operation.py @@ -39,7 +39,7 @@ def instantiate_layer_norm_template(attrs): cutlass::TensorRef _beta((data_type*)${beta}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${input}->device.device_id)); cutlass::layernorm(size, _output, _input, _gamma, _beta, stream); """ diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py b/python/tvm/contrib/cutlass/rms_norm_operation.py index 51c18d4ae47b..3d038ab21011 100644 --- a/python/tvm/contrib/cutlass/rms_norm_operation.py +++ b/python/tvm/contrib/cutlass/rms_norm_operation.py @@ -38,7 +38,7 @@ def instantiate_rms_norm_template(attrs): cutlass::TensorRef _weight((data_type*)${weight}->data, layout_channels); cutlass::TensorRef _output((data_type*)out0->data, layout_2D); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, ${input}->device.device_id)); cutlass::rmsnorm(size, _output, _input, _weight, stream, ${rms_eps}); """ diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc index 373e9aaac294..ae107c06773f 100644 --- a/src/contrib/msc/plugin/tvm_codegen.cc +++ 
b/src/contrib/msc/plugin/tvm_codegen.cc @@ -385,7 +385,7 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const ffi::String& d compute_args.push_back("meta_attr"); if (device == "cuda") { // TODO(tvm-team): update to support get stream from device id - stack_.assign("stream", "TVMFFIEnvGetCurrentStream(kDLCUDA, 0)", "auto"); + stack_.assign("stream", "TVMFFIEnvGetStream(kDLCUDA, 0)", "auto"); compute_args.push_back("stream"); } CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args); diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 13f958744e61..88a0dc128df2 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -558,7 +558,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ cublasLtHandle_t ltHandle; CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, A->device.device_id)); + static_cast(TVMFFIEnvGetStream(kDLCUDA, A->device.device_id)); CallLtIgemm(args, ret, ltHandle, stream); CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index 98b05ba31995..33bdaaf0f7c0 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -91,7 +91,7 @@ class CublasJSONRuntime : public JSONRuntimeBase { CUDA_CALL(cudaGetDevice(&device_id)); } auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 0ba654c9ebc8..f5248fde7e00 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -44,8 +44,8 @@ typedef dmlc::ThreadLocalStore CuBlasThreadStore; CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal(DLDevice curr_device) { CuBlasThreadEntry* retval = CuBlasThreadStore::Get(); - cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id)); CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, stream)); return retval; } diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index fa046980e39a..48560f4306a6 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -164,8 +164,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { std::function op_exec = [=]() { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream)); auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) { diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index acedf7a9e2dd..f36a50a80a35 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -129,8 +129,8 @@ 
CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(Device curr_device, bool check_e ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED"; } - cudaStream_t stream = static_cast( - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id)); + cudaStream_t stream = + static_cast(TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id)); CUDNN_CALL(cudnnSetStream(res->handle, stream)); return res; } diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh index ffc05893cad6..0527829c528d 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh @@ -38,7 +38,7 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight, Tensor indptr, Tensor // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu index 2be8c09da2dc..5c73c0cb74bd 100644 --- a/src/runtime/contrib/cutlass/fp8_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_gemm.cu @@ -42,8 +42,7 @@ template void tvm_cutlass_fp8_gemm(Tensor x, Tensor weight, Tensor workspace, Tensor alpha, Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_GE(x->ndim, 2); CHECK_EQ(weight->ndim, 2); diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu index 48e68cb804f6..97f3e80e5bf0 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu @@ -46,8 +46,7 @@ void tvm_cutlass_fp8_group_gemm(Tensor x, Tensor weight, Tensor indptr, Tensor w Tensor alpha, Tensor out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id)); CHECK_EQ(x->ndim, 2); CHECK_EQ(weight->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh index e03366a03860..35f08efbc57c 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -40,7 +40,7 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. 
- cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_GE(a->ndim, 2); CHECK_EQ(scales_a->ndim, a->ndim); @@ -106,7 +106,7 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales Tensor out) { // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. // Recommened size is 4MB. - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 3); CHECK_EQ(scales_a->ndim, 3); diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu index 420f93d4f2f3..8ac0e0452d57 100644 --- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu @@ -38,8 +38,7 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales Tensor out) { // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. // Recommended size is 4MB. - cudaStream_t stream = - static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id)); CHECK_EQ(a->ndim, 2); CHECK_EQ(b->ndim, 3); CHECK_EQ(indptr->ndim, 1); diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc index 6e760b7f0625..f53f8f7c6a51 100644 --- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc @@ -89,7 +89,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase { ROCM_CALL(hipGetDevice(&device_id)); } auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(DLDevice{kDLROCM, device_id}); - hipStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + hipStream_t stream = static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) { ICHECK_LT(idx, node.GetInputs().size()); diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc index 1b61cbd38219..17ed9a0d936d 100644 --- a/src/runtime/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/contrib/hipblas/hipblas_utils.cc @@ -44,8 +44,7 @@ typedef dmlc::ThreadLocalStore HipBlasThreadStore; HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(DLDevice curr_device) { HipBlasThreadEntry* retval = HipBlasThreadStore::Get(); - TVMFFIStreamHandle stream = - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); + TVMFFIStreamHandle stream = TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id); CHECK_HIPBLAS_ERROR(hipblasSetStream(retval->handle, static_cast(stream))); return retval; } diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index e860ba8ea7f2..617ea5aaf027 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -56,8 +56,7 @@ typedef dmlc::ThreadLocalStore MIOpenThreadStore; MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal(Device curr_device) { // Need to update stream per fetch to avoid stream switching MIOpenThreadEntry* res = MIOpenThreadStore::Get(); - TVMFFIStreamHandle 
stream = - TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id); + TVMFFIStreamHandle stream = TVMFFIEnvGetStream(curr_device.device_type, curr_device.device_id); MIOPEN_CALL(miopenSetStream(res->handle, stream)); return res; } diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc index 8a837370fa34..07b190a2c0be 100644 --- a/src/runtime/contrib/msc/tensorrt_runtime.cc +++ b/src/runtime/contrib/msc/tensorrt_runtime.cc @@ -133,7 +133,7 @@ class MSCTensorRTRuntime : public JSONRuntimeBase { context.Set("datas", input_datas); (*pf)(context, "before_forward", graph_name_, tool_tag_); } - auto tvm_stream = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); + auto tvm_stream = TVMFFIEnvGetStream(kDLCUDA, device_id); #if TRT_VERSION_GE(6, 0, 1) ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr)) << "Running TensorRT failed."; diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index 1adf95f69320..7eede1b65485 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -94,7 +94,7 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource { auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); return thrust::cuda::par_nosync(memory_resouce).on(stream); } diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 623968fedeab..f8ec539cc0dc 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -301,7 +301,7 @@ class CUDATimerNode : public TimerNode { // cudaEventRecord do some stream synchronization? 
int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - stream_ = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id); + stream_ = TVMFFIEnvGetStream(kDLCUDA, device_id); CUDA_CALL(cudaEventRecord(start_, static_cast(stream_))); } virtual void Stop() { @@ -352,10 +352,10 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("runtime.GetCudaFreeMemory", GetCudaFreeMemory) .def("runtime.get_cuda_stream", []() { // TODO(tvm-team): remove once confirms all dep such as flashinfer - // migrated to TVMFFIEnvGetCurrentStream + // migrated to TVMFFIEnvGetStream int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - return static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + return static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); }); }); diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 9086903d0141..9673dfa169fd 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -199,7 +199,7 @@ class CUDAWrappedFunc { } } } - CUstream strm = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + CUstream strm = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr); diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc index 0c7f939181a2..d02f4efdb900 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/runtime/cuda/l2_cache_flush.cc @@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist."; int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - cudaStream_t stream = static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)); + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); L2Flush::ThreadLocal()->Flush(stream); }); }); diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc index e574ce14b004..96d370dfe2e5 100644 --- a/src/runtime/device_api.cc +++ b/src/runtime/device_api.cc @@ -165,12 +165,11 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; } void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {} void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { - TVM_FFI_CHECK_SAFE_CALL( - TVMFFIEnvSetCurrentStream(dev.device_type, dev.device_id, stream, nullptr)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr)); } TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { - return TVMFFIEnvGetCurrentStream(dev.device_type, dev.device_id); + return TVMFFIEnvGetStream(dev.device_type, dev.device_id); } void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) { diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 4b042d8d491d..2ea9727b8b53 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -264,7 +264,7 @@ class ROCMTimerNode : public TimerNode { virtual void Start() { int device_id; ROCM_CALL(hipGetDevice(&device_id)); - stream_ = TVMFFIEnvGetCurrentStream(kDLROCM, device_id); + stream_ = TVMFFIEnvGetStream(kDLROCM, device_id); ROCM_CALL(hipEventRecord(start_, static_cast(stream_))); } virtual void Stop() { @@ -302,7 +302,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ .def("runtime.get_rocm_stream", []() { int device_id; ROCM_CALL(hipGetDevice(&device_id)); - return static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + 
return static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); }); }); diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 3ef9bf47a9b1..f8f7ed673f07 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -172,7 +172,7 @@ class ROCMWrappedFunc { fcache_[device_id] = m_->GetFunc(device_id, func_name_); } - hipStream_t strm = static_cast(TVMFFIEnvGetCurrentStream(kDLROCM, device_id)); + hipStream_t strm = static_cast(TVMFFIEnvGetStream(kDLROCM, device_id)); ThreadWorkLoad wl = launch_param_config_.Extract(args); void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc index 252841528152..0e8cc2090784 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc @@ -118,14 +118,13 @@ class CUDACaptureStream { explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) { CUDA_CALL(cudaGetDevice(&device_id_)); TVM_FFI_CHECK_SAFE_CALL( - TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, capture_stream_, - reinterpret_cast(&prev_default_stream_))); + TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_, + reinterpret_cast(&prev_default_stream_))); CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal)); } ~CUDACaptureStream() noexcept(false) { cudaStreamEndCapture(capture_stream_, output_graph_); - TVM_FFI_CHECK_SAFE_CALL( - TVMFFIEnvSetCurrentStream(kDLCUDA, device_id_, prev_default_stream_, nullptr)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr)); } private: @@ -159,8 +158,8 @@ class CUDAGraphExtensionNode : public VMExtensionNode { const auto& [states, exec] = it->second; int device_id; CUDA_CALL(cudaGetDevice(&device_id)); - CUDA_CALL(cudaGraphLaunch( - exec, static_cast(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id)))); + CUDA_CALL( + cudaGraphLaunch(exec, static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)))); return states; }
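
The new `DLPackTensorAllocator` hook lets the calling environment decide how result tensors are allocated; the Python call manager installs one per call and `ffi::Tensor::FromDLPackAlloc` consumes it. Below is a minimal host-side sketch of such an allocator, modeled on `TestDLPackTensorAllocator` in `ffi/tests/cpp/test_tensor.cc` from this patch. The `CPUNDAlloc`/`GetDataSize` helpers mirror the test code, and the include paths are assumptions for illustration, not part of the patch.

```cpp
#include <tvm/ffi/container/tensor.h>   // tvm::ffi::Tensor, Shape
#include <tvm/ffi/extra/c_env_api.h>    // assumed header declaring TVMFFIEnvSetTensorAllocator

#include <cstdlib>

namespace {

using tvm::ffi::Shape;
using tvm::ffi::Tensor;

// Simple CPU backing storage, mirroring the CPUNDAlloc helper used in the tests.
struct CPUNDAlloc {
  void AllocData(DLTensor* tensor) { tensor->data = malloc(tvm::ffi::GetDataSize(*tensor)); }
  void FreeData(DLTensor* tensor) { free(tensor->data); }
};

// Matches the DLPackTensorAllocator signature added to c_api.h: fill `out` from the
// shape/dtype/device described by `prototype`, or report failure via SetError and return -1.
int ExampleCPUAllocator(DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
                        void (*SetError)(void* error_ctx, const char* kind, const char* message)) {
  if (prototype->device.device_type != kDLCPU) {
    SetError(error_ctx, "RuntimeError", "ExampleCPUAllocator only handles CPU tensors");
    return -1;
  }
  Shape shape(prototype->shape, prototype->shape + prototype->ndim);
  Tensor nd = Tensor::FromNDAlloc(CPUNDAlloc(), shape, prototype->dtype, prototype->device);
  *out = nd.ToDLPackVersioned();
  return 0;
}

}  // namespace

// Install it as the thread-local environment allocator; callees can then allocate
// outputs via Tensor::FromDLPackAlloc(TVMFFIEnvGetTensorAllocator(), ...).
void InstallExampleAllocator() {
  DLPackTensorAllocator prev = nullptr;
  TVMFFIEnvSetTensorAllocator(ExampleCPUAllocator, /*write_to_global_context=*/0, &prev);
}
```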
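The renamed `TVMFFIEnvSetStream`/`TVMFFIEnvGetStream` pair keeps a thread-local (device_type, device_id) to stream table, and the callers in this patch follow the same set-then-restore discipline (see `TVMFFIPyCallManager::Call` and `CUDACaptureStream`). A minimal sketch of that discipline, assuming the declarations come from `tvm/ffi/extra/c_env_api.h`:

```cpp
#include <tvm/ffi/extra/c_env_api.h>  // assumed header declaring TVMFFIEnvSetStream/TVMFFIEnvGetStream

// Run `work` with `stream` installed as the environment stream for the device,
// then restore whatever stream was active before (possibly nullptr).
template <typename Work>
int RunOnStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream, Work work) {
  TVMFFIStreamHandle prev_stream = nullptr;
  if (TVMFFIEnvSetStream(device_type, device_id, stream, &prev_stream) != 0) return -1;
  // Kernels inside `work` pick the stream up via TVMFFIEnvGetStream(device_type, device_id),
  // just as the CUDA/ROCm wrappers updated in this patch do.
  work();
  return TVMFFIEnvSetStream(device_type, device_id, prev_stream, nullptr);
}
```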
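The `DLPackPyObjectImporter` typedef is the fast path by which a framework can turn returned tensors back into its own objects; `make_tensor_from_chandle` invokes it when the call context recorded one, and the torch fallback in this patch goes through `torch.from_dlpack`. The sketch below is only a hypothetical stand-in that hands back a standard `dltensor_versioned` capsule instead of a real framework tensor, to show the signature and ownership convention (the importer consumes the DLPack tensor and returns a new Python reference).

```cpp
#include <Python.h>
#include <dlpack/dlpack.h>

// Free the tensor if the capsule is destroyed without ever being consumed.
static void DLPackCapsuleDeleter(PyObject* capsule) {
  if (PyCapsule_IsValid(capsule, "dltensor_versioned")) {
    auto* tensor = static_cast<DLManagedTensorVersioned*>(
        PyCapsule_GetPointer(capsule, "dltensor_versioned"));
    if (tensor != nullptr && tensor->deleter != nullptr) tensor->deleter(tensor);
  }
}

// Matches the DLPackPyObjectImporter signature: take ownership of `tensor`,
// store a new Python reference in `py_obj_out`, and return 0 on success.
int ExampleDLPackImporter(DLManagedTensorVersioned* tensor, void** py_obj_out) {
  PyObject* capsule = PyCapsule_New(tensor, "dltensor_versioned", DLPackCapsuleDeleter);
  if (capsule == nullptr) return -1;  // Python error is already set
  *py_obj_out = capsule;              // the caller receives the new reference
  return 0;
}
```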