Conversation

@tqchen
Member

@tqchen tqchen commented Sep 12, 2025

This PR adds support for four C functions to speed up DLPack exchange. As of now, DLPack exchange relies on Python functions such as tensor.__dlpack__().

While these work well for common cases, the general overhead of such an exchange is on the level of 0.2-0.3 us for a very well-optimized version, and can go up to 0.4-1 us for less optimized implementations.

For a function f(a, b, c) that takes three arguments, assuming we run a DLPack exchange for each argument, the overall conversion overhead usually reaches around 1 us and sometimes 3 us.

While such overhead is acceptable in many settings, in GPU applications the extra 1-3 us can still be significant.

This PR proposes four functions for fast exchange of DLPack tensors without going through the Python interpreter.

  • DLPackManagedTensorFromPyObjectNoSync for fast exchange of owned tensors into the consumer
  • DLPackDLTensorFromPyObjectNoSync for fast exchange of non-owned tensors into the consumer
  • DLPackManagedTensorToPyObjectNoSync for fast exchange of consumer tensors back to the producer (return values)
  • DLPackCurrentWorkStream to query the current work stream from the stream context

Our preliminary results show that these functions, when incorporated correctly via native extensions such as C/C++, can bring the exchange cost down to the level of 30-80 ns, about one order of magnitude speedup. That means the API overhead of a function like f(a, b, c) will be at the 0.2-0.4 us level (including exchange), which is close to the overhead of a native C++ extension call without any exchange.
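To make the intended consumer-side usage concrete, here is a minimal C sketch of the fast path with fallback. The struct layout and stand-in types below are simplified illustrations (the real definitions live in dlpack.h, and PyObjectStub stands in for PyObject); only the function-pointer names follow the proposal.

```c
#include <assert.h>
#include <stddef.h>

/* Simplified stand-ins for the real DLPack types in dlpack.h. */
typedef struct { void* data; } DLTensor;
typedef struct { DLTensor dl_tensor; } DLManagedTensorVersioned;
typedef void PyObjectStub;  /* placeholder for PyObject in this sketch */

/* Function table modeled on the proposed DLPackExchangeAPI entries. */
typedef struct {
    int (*managed_tensor_from_py_object_no_sync)(PyObjectStub* obj,
                                                 DLManagedTensorVersioned** out);
    int (*dltensor_from_py_object_no_sync)(PyObjectStub* obj, DLTensor** out);
} DLPackExchangeAPI;

/* Consumer fast path: prefer the optional non-owning view; fall back to the
 * owned exchange when the producer leaves the optional entry NULL. */
int consume_tensor(const DLPackExchangeAPI* api, PyObjectStub* obj,
                   DLTensor** view) {
    if (api->dltensor_from_py_object_no_sync != NULL) {
        return api->dltensor_from_py_object_no_sync(obj, view);
    }
    DLManagedTensorVersioned* owned = NULL;
    int rc = api->managed_tensor_from_py_object_no_sync(obj, &owned);
    if (rc == 0) *view = &owned->dl_tensor;  /* caller later runs the deleter */
    return rc;
}

/* Mock producers standing in for a real framework. */
static DLTensor g_view_tensor;
static DLManagedTensorVersioned g_owned_tensor;
static int mock_view(PyObjectStub* obj, DLTensor** out) {
    (void)obj; *out = &g_view_tensor; return 0;
}
static int mock_owned(PyObjectStub* obj, DLManagedTensorVersioned** out) {
    (void)obj; *out = &g_owned_tensor; return 0;
}
```

Because both entries return plain C error codes and touch no Python state, a call like this costs only a couple of indirect calls, which is where the tens-of-nanoseconds figure comes from.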

@tqchen
Member Author

tqchen commented Sep 12, 2025

RFC #175

@tqchen
Member Author

tqchen commented Sep 14, 2025

Updated to incorporate suggestions by @dalcinl.

This PR adds support for three C functions to speed up DLPack exchange.
As of now, DLPack exchange relies on Python functions such as tensor.__dlpack__().

While these work well for common cases, the general overhead of such an exchange is
on the level of 0.2-0.3 us for a very well-optimized version, and can go up to
0.4-1 us for less optimized implementations.

For a function f(a, b, c) that takes three arguments, assuming we run a DLPack
exchange for each argument, the overall conversion overhead usually reaches
around 1 us and sometimes 3 us.

While such overhead is acceptable in many settings, in GPU applications
the extra 1-3 us can still be significant.

This PR proposes three functions for fast exchange of DLPack tensors without
going through the Python interpreter.

- DLPackFromPyObject: exports a PyObject tensor to a DLManagedTensorVersioned
- DLPackToPyObject: converts a DLManagedTensorVersioned to a PyObject tensor
- DLPackTensorAllocator: exposes one package's tensor allocator to another package
  - This allows, for example, implementing libraries that allocate intermediate
    tensors using the caller's specified tensor allocator.

Our preliminary results show that these functions, when incorporated correctly
via native extensions such as C/C++, can bring the exchange cost down to the
level of 30-80 ns, about one order of magnitude speedup. That means functions
like f(a, b, c) can finish at the 0.2-0.4 us level, which is close to the
overhead of a native C++ extension call without any exchange.
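The caller-supplied-allocator idea can be sketched in C as follows. The types and the routine `make_intermediate` are simplified stand-ins for illustration (the real DLTensor/DLManagedTensorVersioned definitions live in dlpack.h); only the allocator-callback shape follows the proposal.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

/* Simplified stand-ins for the real DLPack types in dlpack.h. */
typedef struct { void* data; int ndim; } DLTensor;
typedef struct DLManagedTensorVersioned {
    DLTensor dl_tensor;
    void (*deleter)(struct DLManagedTensorVersioned* self);
} DLManagedTensorVersioned;

/* Allocator signature modeled on the proposed DLPackTensorAllocator. */
typedef int (*DLPackTensorAllocator)(
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message));

/* A library routine that needs an intermediate tensor: it allocates through
 * the caller-supplied allocator, so the result lives in the caller
 * framework's memory pool rather than the library's own. */
int make_intermediate(const DLTensor* input, DLPackTensorAllocator alloc,
                      void* error_ctx,
                      void (*SetError)(void*, const char*, const char*),
                      DLManagedTensorVersioned** out) {
    DLTensor proto = *input;  /* request the same shape/dtype as the input */
    if (alloc(&proto, out, error_ctx, SetError) != 0) return -1;
    /* ... a kernel would now write into (*out)->dl_tensor.data ... */
    return 0;
}

/* A malloc-based allocator a CPU framework might register. */
static void free_managed(DLManagedTensorVersioned* self) { free(self); }
static int malloc_allocator(DLTensor* prototype, DLManagedTensorVersioned** out,
                            void* error_ctx,
                            void (*SetError)(void*, const char*, const char*)) {
    DLManagedTensorVersioned* m = malloc(sizeof(*m));
    if (m == NULL) {
        SetError(error_ctx, "MemoryError", "allocation failed");
        return -1;
    }
    m->dl_tensor = *prototype;
    m->deleter = free_managed;
    *out = m;
    return 0;
}
static void ignore_error(void* ctx, const char* kind, const char* msg) {
    (void)ctx; (void)kind; (void)msg;
}
```

The design choice is that ownership flows back to the caller: the allocator fills in the deleter, so whichever framework allocated the intermediate also frees it.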
@tqchen
Member Author

tqchen commented Sep 22, 2025

This PR is updated to reflect all the suggestions in #175

Naming (updated per suggestion from @oleksandr-pavlyk @kkraus14):

  • DLPackManagedTensorAllocator
  • DLPackManagedTensorToPyObject
  • DLPackManagedTensorFromPyObject

Clarified that the dunder should be attached to the class type (per @gbonik and @seberg).

@kkraus14 kkraus14 left a comment


Everything LGTM except for the ongoing discussions related to stream / synchronization handling

@tqchen
Member Author

tqchen commented Oct 6, 2025

Thanks everyone for the comments. We have updated the proposal to include a non-owned version, while explicitly stating that the intent is fast calling for framework, library, and DSL call use cases.

Summary of the current functions

  • DLPackManagedTensorFromPyObjectNoSync for fast exchange of owned tensors into the consumer
  • DLPackDLTensorFromPyObjectNoSync for fast exchange of non-owned tensors into the consumer
  • DLPackManagedTensorToPyObjectNoSync for fast exchange of consumer tensors back to the producer (return values)
  • DLPackCurrentWorkStream to query the current work stream from the stream context

I think it is getting to a mergeable state, but it would be good to also get final inputs, if any.

@yongwww

yongwww commented Oct 6, 2025

I’m in favor of this RFC. Adding C APIs for DLPack exchange and stream querying is a very compelling direction; achieving exchange latency down to tens of nanoseconds could remove a significant bottleneck. This RFC is a valuable addition to DLPack. I’m excited to see it land and get deployed in production!

Co-authored-by: Sebastian Berg <sebastianb@nvidia.com>
@tqchen
Member Author

tqchen commented Oct 8, 2025

Incorporated suggestions from @seberg into the PR. I think we are in good shape; going to merge this in two days if there are no further comments.

Thanks everyone for suggestions so far

Collaborator

@seberg seberg left a comment


I like this approach in general. I am still a bit unsure about the allocation function and the filling one, but if nobody else chimes in, then so be it...

One question about the stream sync; we should maybe clarify that a bit.

One curiosity: Should there be an entry (or function?) that allows discovering which device types are supported?


I think, in a sense, I still need to see how this actually looks for a real-world consumer/producer pair, but in the near future it is maybe still malleable enough.

typedef int (*DLPackManagedTensorAllocator)(
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message));
Collaborator


One small thing is, I am not sure how we will use kind. Let's say this typically gets converted to a Python exception; then we would have to somewhat agree on e.g. MemoryError or so to translate it.

Maybe this will just settle itself; in the end I bet this pretty much only ever raises MemoryError anyway.

(I think this function is important! I know I am slow to settle on something, but I haven't quite settled on loving this approach. But I suspect it's true that asking to allocate e.g. the torch tensor here -- requiring the GIL -- and then viewing that is also not really better.)

Member Author


Yes, MemoryError is likely the most common one.
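On the consumer side, the SetError callback from the allocator typedef would typically record the kind string so the caller can raise the matching host-language exception afterwards. A minimal C sketch (the ErrorCtx struct and helper names here are illustrative, not part of the proposal):

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

/* A consumer-side error slot threaded through as error_ctx. */
typedef struct {
    char kind[64];
    char message[256];
} ErrorCtx;

/* Matches the SetError callback shape in the proposed allocator typedef:
 * record the error so the caller can raise the right exception later. */
static void set_error(void* error_ctx, const char* kind, const char* message) {
    ErrorCtx* ctx = (ErrorCtx*)error_ctx;
    snprintf(ctx->kind, sizeof(ctx->kind), "%s", kind);
    snprintf(ctx->message, sizeof(ctx->message), "%s", message);
}

/* After a failed call, the kind string picks the exception type, e.g.
 * "MemoryError" would map onto PyExc_MemoryError when raising into Python. */
static int is_memory_error(const ErrorCtx* ctx) {
    return strcmp(ctx->kind, "MemoryError") == 0;
}
```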

DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync;
/*!
* \brief Producer function pointer for DLPackDLTensorFromPyObject
* This function can be NULL when the producer does not support this function.
Collaborator


I like this approach as such, but as I just mentioned elsewhere, I want to point out that if this is optional and commonly not supported, it means all consumers will need:

if (api->managed_tensor_to_py_object_no_sync == NULL) {
    // do complicated thing
}

which means that we can get a small speed improvement when it is available (not having to allocate the DLManagedTensor, etc.), but unfortunately we may need to support both paths as a consumer.

I suppose part of the solution here may be that you really want a C++ convenience layer that is easy to vendor...

This isn't a deal breaker! It made me wonder whether filling in a DLManagedTensor is much worse -- it would differ from the current design by not owning its own allocation.

* and the producer can simply return NULL when queried.
* The consumer does not have to do anything about stream sync or setting.
* So a CPU-only framework can just provide a dummy implementation that
* always sets out_current_stream[0] to NULL.
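The dummy CPU implementation described in that comment is a one-liner; here is a sketch, assuming a plain int device_type as a stand-in for the real DLDeviceType enum:

```c
#include <assert.h>
#include <stddef.h>

/* Signature modeled on the proposed DLPackCurrentWorkStream entry; the int
 * device_type stands in for the DLDeviceType enum from dlpack.h. */
typedef int (*DLPackCurrentWorkStream)(int device_type, int device_id,
                                       void** out_current_stream);

/* CPU-only frameworks have no stream concept: always report the NULL
 * (default) stream and succeed, so consumers perform no synchronization. */
static int cpu_current_work_stream(int device_type, int device_id,
                                   void** out_current_stream) {
    (void)device_type;
    (void)device_id;
    out_current_stream[0] = NULL;
    return 0;
}
```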
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot about the CPU part; allowing the function pointer to be NULL makes sense. But I am also happy to just implement an always-NULL return, or an always -1 with a "NumPy doesn't have streams, what's going on" error.

*
* \param device_type The device type.
* \param device_id The device id.
* \param out_current_stream The output current work stream.
Collaborator


I am a bit unsure whether this is specified as well as it needs to be. I.e., I think NULL would be the default stream.

The question is: is there any need (or not) for an "undefined or no synchronization" return value (such as -1)?
If not, we are all good, but if some producer might need this (for whatever reason), then we need to specify it here.

The alternative is that the producer just has to return the default stream (otherwise the consumer probably has to guess a stream anyway in the kernel use case!).

Member Author


I don't think there is a need in this particular context; returning the default stream is likely more well defined.

tqchen and others added 2 commits October 9, 2025 13:31
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
@tqchen
Member Author

tqchen commented Oct 9, 2025

Thanks @seberg. On an entry function to discover which devices are supported: as of now, no, but if there is a need we could add such an API, although I am not sure it is strictly needed.

*
* \sa DLPackExchangeAPI
*/
struct DLPackExchangeAPI* prev_version_api;
Contributor


In the future, a version v2 may have a DLPackExchangeAPI struct with more/different entries than v1; therefore the struct corresponding to v2 may not have the same layout (and thus type name) as the older v1.

I assume that the intention here is that DLPackExchangeAPI will only ever grow, if at all, at the end; additionally, the entries currently defined below will never change (neither their relative position nor the layout of each struct).
Is that correct? If not, this will require some nasty and inconvenient fixes.

Long story short: this is a great idea, but the implementation as it is now is not flexible to future changes in the layout of the various structures. Just double-checking we are all on the same page here.

Member Author


This is a good point. I think we can say only the first two fields remain unchanged; I will do a patch to clarify this.

Member Author


Updated to move the stable part of the exchange API into a new struct, DLPackExchangeAPIHeader.
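The stable-header idea can be sketched as follows: only the leading header fields are guaranteed to keep their layout across versions, so a consumer can always read them and walk back to an older table it understands. Field names here are illustrative (the actual header layout is defined in the PR, not reproduced here):

```c
#include <assert.h>
#include <stddef.h>

/* Sketch of the stable-header idea: these leading fields keep their layout
 * across versions, so a consumer can read them from any version of the full
 * table the producer exposes. Field names are illustrative. */
typedef struct DLPackExchangeAPIHeader {
    int version;                            /* ABI version of this table */
    const struct DLPackExchangeAPIHeader*
        prev_version_api;                   /* older table, or NULL */
} DLPackExchangeAPIHeader;

/* Walk the prev_version_api chain until a version the consumer understands. */
static const DLPackExchangeAPIHeader*
find_api_version(const DLPackExchangeAPIHeader* api, int wanted) {
    while (api != NULL && api->version != wanted) {
        api = api->prev_version_api;
    }
    return api;  /* NULL when no compatible version is exposed */
}
```

Because consumers only ever dereference through the header type, later versions can reshape the rest of the struct without the strict-aliasing hazard raised above.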

Contributor


If things ever change after the first two fields, then the C code users write to deal with older versions will violate the C language strict aliasing rules, entering undefined behavior territory.

Contributor


Oh, I see you changed things to use a header.

Member Author


@dalcinl if it works for you, it would be great if you could explicitly approve, thanks.

@tqchen
Member Author

tqchen commented Oct 10, 2025

Thanks everyone for the valuable feedback so far; planning to merge in 24 hours.

Co-authored-by: Lisandro Dalcin <dalcinl@gmail.com>
@tqchen tqchen merged commit 1117366 into dmlc:main Oct 11, 2025
3 checks passed
@tqchen
Member Author

tqchen commented Oct 11, 2025

Thanks everyone, this is merged

tqchen pushed a commit to apache/tvm-ffi that referenced this pull request Oct 11, 2025
## Summary of Changes

This PR introduces a unified `DLPackExchangeAPI` struct as described in
proposal [175](dmlc/dlpack#175). This new
convention replaces the previous mechanism of separate function
pointers, and aligns with the latest DLPack standard as shown in PR
[174](dmlc/dlpack#174).

The new `DLPackExchangeAPI` struct also includes a
`current_work_stream` function pointer that allows more robust and
integrated querying of the current device stream (e.g., CUDA stream)
during DLPack tensor exchanges. All conversions from/to DLPack have
been updated to `_no_sync` variants, meaning consumers should use
`current_work_stream` to handle stream synchronization explicitly. A
non-owning DLTensor conversion is also included to avoid unnecessary
reference counting.

Following this change, the Python FFI for PyTorch has been updated to
expose the new `DLPackExchangeAPI` struct via
`__c_dlpack_exchange_api__` on torch.Tensor.

The `3rdparty/dlpack` has been updated to incorporate the latest commit.
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Dec 4, 2025
…165483)

## Addressed Issue

Issue #162845

## Summary of Changes

This PR introduces a unified `DLPackExchangeAPI` struct as described in proposal [175](dmlc/dlpack#175). This new convention replaces the previous mechanism of separate function pointers, and aligns with the latest DLPack standard as shown in PR [174](dmlc/dlpack#174).

Specifically, the new `DLPackExchangeAPI` struct is exposed as `torch.Tensor.__c_dlpack_exchange_api__`, which stores and exposes the following function pointers:

* `managed_tensor_allocator`
* `managed_tensor_from_py_object_no_sync`
* `managed_tensor_to_py_object_no_sync`
* `dltensor_from_py_object_no_sync`
* `current_work_stream`

Within the new `DLPackExchangeAPI` struct, the new `current_work_stream` function pointer allows more robust and integrated querying of the current device stream (e.g., CUDA stream) during DLPack tensor exchanges. All conversions from/to DLPack have been updated to `_no_sync` variants, meaning you should use `current_work_stream` to explicitly handle stream synchronization. A non-owning DLTensor conversion, `dltensor_from_py_object_no_sync`, is also included to avoid unnecessary reference counting.

Following this change, the `dlpack.h` has been updated to the latest DLPack.

Unit tests are added using `torch.utils.cpp_extension.load_inline` to avoid GIL release issues
when calling `THPVariable_Wrap`.
Pull Request resolved: #165483
Approved by: https://github.com/tqchen, https://github.com/albanD
umechand-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 8, 2025
JacobSzwejbka pushed a commit to pytorch/pytorch that referenced this pull request Dec 8, 2025