[RFC] Support DLPACK C Functions for Speed Exchange and Stream Handling #174
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,5 @@ | ||
| /*! | ||
| * Copyright (c) 2017 by Contributors | ||
| * Copyright (c) 2017 - by Contributors | ||
| * \file dlpack.h | ||
| * \brief The common header of DLPack. | ||
| */ | ||
|
|
@@ -326,7 +326,7 @@ typedef struct DLManagedTensor { | |
| * | ||
| * \note This is the current standard DLPack exchange data structure. | ||
| */ | ||
| struct DLManagedTensorVersioned { | ||
| typedef struct DLManagedTensorVersioned { | ||
| /*! | ||
| * \brief The API and ABI version of the current managed Tensor | ||
| */ | ||
|
|
@@ -360,7 +360,266 @@ struct DLManagedTensorVersioned { | |
| uint64_t flags; | ||
| /*! \brief DLTensor which is being memory managed */ | ||
| DLTensor dl_tensor; | ||
| }; | ||
| } DLManagedTensorVersioned; | ||
|
|
||
| //---------------------------------------------------------------------- | ||
| // DLPack `__c_dlpack_exchange_api__` fast exchange protocol definitions | ||
| //---------------------------------------------------------------------- | ||
| /*! | ||
| * \brief Request a producer library to create a new tensor. | ||
| * | ||
| * Create a new `DLManagedTensorVersioned` within the context of the producer | ||
| * library. The allocation is defined via the prototype DLTensor. | ||
| * | ||
| * This function is exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * \param prototype The prototype DLTensor. Only the dtype, ndim, shape, | ||
| * and device fields are used. | ||
| * \param out The output DLManagedTensorVersioned. | ||
| * \param error_ctx Context for `SetError`. | ||
| * \param SetError The function to set the error. | ||
| * \return 0 on success, with `*out` set to the owning DLManagedTensorVersioned*; | ||
| * -1 on failure. SetError is called exactly when a failure is returned | ||
| * (the implementor must ensure this). | ||
| * \note - As a C function, this must not throw C++ exceptions. | ||
| * - Errors are propagated via SetError to avoid any direct dependency on | ||
| * the Python API. Because of this, `SetError` may have to ensure the GIL is | ||
| * held, since it will presumably set a Python error. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| typedef int (*DLPackManagedTensorAllocator)( // | ||
| DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // | ||
| void (*SetError)(void* error_ctx, const char* kind, const char* message) // | ||
| ); | ||
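To make the contract above concrete, here is a minimal sketch of what a CPU-only producer might register for this slot. It is illustrative only: it assumes the dlpack.h from this PR, follows the 0/-1 return convention of the typedef, and the `My*` names are hypothetical rather than part of the proposed header.

```cpp
// Hypothetical CPU-only producer implementation of DLPackManagedTensorAllocator.
#include <dlpack/dlpack.h>

#include <cstdlib>
#include <cstring>

namespace {

void MyDeleter(DLManagedTensorVersioned* self) {
  std::free(self->dl_tensor.data);
  std::free(self->dl_tensor.shape);
  std::free(self);
}

int MyDLPackManagedTensorAllocator(
    DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
    void (*SetError)(void* error_ctx, const char* kind, const char* message)) {
  if (prototype->device.device_type != kDLCPU) {
    SetError(error_ctx, "ValueError", "only CPU allocations are supported");
    return -1;
  }
  // Element count and byte size are derived from the prototype's shape/dtype.
  int64_t numel = 1;
  for (int i = 0; i < prototype->ndim; ++i) numel *= prototype->shape[i];
  size_t nbytes = static_cast<size_t>(numel) *
                  ((prototype->dtype.bits * prototype->dtype.lanes + 7) / 8);

  auto* managed = static_cast<DLManagedTensorVersioned*>(
      std::malloc(sizeof(DLManagedTensorVersioned)));
  auto* shape = static_cast<int64_t*>(std::malloc(sizeof(int64_t) * prototype->ndim));
  void* data = std::malloc(nbytes);
  if (managed == nullptr || shape == nullptr || data == nullptr) {
    std::free(managed);
    std::free(shape);
    std::free(data);
    SetError(error_ctx, "MemoryError", "failed to allocate tensor");
    return -1;  // SetError is called exactly on the failure path
  }
  std::memcpy(shape, prototype->shape, sizeof(int64_t) * prototype->ndim);

  managed->version = {DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION};
  managed->manager_ctx = nullptr;
  managed->deleter = MyDeleter;
  managed->flags = 0;
  managed->dl_tensor = {data, prototype->device, prototype->ndim,
                        prototype->dtype, shape,
                        /*strides=*/nullptr, /*byte_offset=*/0};
  *out = managed;
  return 0;
}

}  // namespace
```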
|
|
||
| /*! | ||
| * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned. | ||
| * | ||
| * This function does not perform any stream synchronization. The consumer should query | ||
| * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. | ||
| * | ||
| * This function is exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * \param py_object The Python object to convert. Must have the same type | ||
| * as the one the `DLPackExchangeAPI` was discovered from. | ||
| * \param out The output owning DLManagedTensorVersioned*. | ||
| * \return 0 on success, -1 on failure with a Python exception set. | ||
| * If the data cannot be described using DLPack, the exception should be | ||
| * a BufferError if possible. | ||
| * \note - As a C function, this must not throw C++ exceptions. | ||
| * | ||
| * \sa DLPackExchangeAPI, DLPackCurrentWorkStream | ||
| */ | ||
| typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // | ||
| void* py_object, // | ||
| DLManagedTensorVersioned** out // | ||
| ); | ||
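As a usage illustration (not part of the proposed header), a consumer-side wrapper around this entry point could look as follows; `api` is assumed to be a `DLPackExchangeAPI*` already obtained via `__c_dlpack_exchange_api__`, and the helper name is hypothetical.

```cpp
#include <dlpack/dlpack.h>

// Hypothetical consumer helper: import an owning tensor from a producer object.
DLManagedTensorVersioned* ImportOwned(const DLPackExchangeAPI* api, void* py_obj) {
  DLManagedTensorVersioned* managed = nullptr;
  if (api->managed_tensor_from_py_object_no_sync(py_obj, &managed) != 0) {
    return nullptr;  // the producer has already set a Python exception
  }
  // The consumer now owns `managed` and must eventually release it:
  //   if (managed->deleter != nullptr) managed->deleter(managed);
  return managed;
}
```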
|
|
||
| /*! | ||
| * \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor. | ||
| * | ||
| * This function provides a faster interface for temporary, non-owning exchange. | ||
| * The producer (implementor) still owns the memory of the data, strides, and shape. | ||
| * The liveness of the DLTensor and the data it views is only guaranteed until | ||
| * control is returned. | ||
| * | ||
| * This function currently assumes that the producer (implementor) can fill | ||
| * in the DLTensor shape and strides without the need for temporary allocations. | ||
| * | ||
| * This function does not perform any stream synchronization. The consumer should query | ||
| * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. | ||
| * | ||
| * This function is exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * \param py_object The Python object to convert. Must have the same type | ||
| * as the one the `DLPackExchangeAPI` was discovered from. | ||
| * \param out The output DLTensor, whose space is pre-allocated by the caller, e.g. on the stack. | ||
| * \return 0 on success, -1 on failure with a Python exception set. | ||
| * \note - As a C function, this must not throw C++ exceptions. | ||
| * | ||
| * \sa DLPackExchangeAPI, DLPackCurrentWorkStream | ||
| */ | ||
| typedef int (*DLPackDLTensorFromPyObjectNoSync)( // | ||
| void* py_object, // | ||
| DLTensor* out // | ||
| ); | ||
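For comparison with the owning path above, a consumer might use this faster entry point roughly as sketched below (hypothetical helper name; `api` obtained from `__c_dlpack_exchange_api__`). Note that this field may be NULL, as discussed further down.

```cpp
#include <dlpack/dlpack.h>

// Hypothetical consumer helper for the non-owning fast path.  The DLTensor is
// provided by the caller (e.g. on the stack) and is only valid until control
// returns, so no pointer into it may be retained.
int FillTemporaryView(const DLPackExchangeAPI* api, void* py_obj, DLTensor* view) {
  if (api->dltensor_from_py_object_no_sync == nullptr) {
    return 1;  // optional entry point not provided; fall back to the managed path
  }
  if (api->dltensor_from_py_object_no_sync(py_obj, view) != 0) {
    return -1;  // Python exception already set by the producer
  }
  return 0;  // `view` is valid for immediate use only
}
```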
|
|
||
| /*! | ||
| * \brief Obtain the current work stream of a device. | ||
| * | ||
| * Obtain the current work stream of a device from the producer framework. | ||
| * For example, it should map to torch.cuda.current_stream in PyTorch. | ||
| * | ||
| * When device_type is kDLCPU, the consumer does not have to query the stream | ||
| * and the producer can simply return NULL when queried. | ||
| * The consumer does not have to do anything about stream synchronization or | ||
| * setting, so a CPU-only framework can simply provide a dummy implementation | ||
| * that always sets out_current_stream[0] to NULL. | ||
|
Collaborator
I forgot about the CPU part, allowing the function pointer to be NULL.
||
| * | ||
| * \param device_type The device type. | ||
| * \param device_id The device id. | ||
| * \param out_current_stream The output current work stream. | ||
|
Collaborator
I am a bit unsure if this is specified as well as it needs to be. I.e. I think NULL would be the default stream. The question is: is there any need (or not) for an "undefined or no synchronization" return value (such as -1)? The alternative is that the producer just has to return the default stream (otherwise the consumer probably has to guess a stream anyway in the kernel use-case!).
Member
Author
I don't think there is a need in this particular context. Returning the default stream is likely more well defined.
||
| * | ||
| * \return 0 on success, -1 on failure with a Python exception set. | ||
| * \note - As a C function, this must not throw C++ exceptions. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| typedef int (*DLPackCurrentWorkStream)( // | ||
| DLDeviceType device_type, // | ||
| int32_t device_id, // | ||
| void** out_current_stream // | ||
| ); | ||
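The intended consumer-side use is roughly the following sketch (hypothetical helper name): query the producer's stream for the tensor's device and launch kernels on that stream, treating kDLCPU/NULL as "nothing to synchronize".

```cpp
#include <dlpack/dlpack.h>

// Hypothetical consumer helper: ask the producer for its current work stream.
// The returned handle is opaque to DLPack; on kDLCUDA a consumer would
// reinterpret it as a cudaStream_t before launching kernels.
int QueryWorkStream(const DLPackExchangeAPI* api, const DLTensor& tensor,
                    void** out_stream) {
  *out_stream = nullptr;
  if (tensor.device.device_type == kDLCPU) {
    return 0;  // CPU: no stream to query, NULL means "nothing to synchronize"
  }
  // For GPU devices this maps to the producer's notion of the current stream
  // (e.g. torch.cuda.current_stream in PyTorch).
  return api->current_work_stream(tensor.device.device_type,
                                  tensor.device.device_id, out_stream);
}
```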
|
|
||
| /*! | ||
| * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray. | ||
| * | ||
| * Convert an owning DLManagedTensorVersioned* to the Python tensor of the | ||
| * producer (implementor) library with the correct type. | ||
| * | ||
| * This function does not perform any stream synchronization. | ||
| * | ||
| * This function is exposed by the framework through the DLPackExchangeAPI. | ||
| * | ||
| * \param tensor The DLManagedTensorVersioned to convert; ownership of the | ||
| * tensor is stolen. | ||
| * \param out_py_object The output Python object. | ||
| * \return 0 on success, -1 on failure with a Python exception set. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| typedef int (*DLPackManagedTensorToPyObjectNoSync)( // | ||
| DLManagedTensorVersioned* tensor, // | ||
| void** out_py_object // | ||
| ); | ||
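A sketch of the corresponding consumer-side call (hypothetical helper name): hand a freshly produced managed tensor back to the producer as a Python object, keeping in mind that ownership is stolen on success.

```cpp
#include <dlpack/dlpack.h>

// Hypothetical consumer helper: wrap a DLManagedTensorVersioned into a producer
// Python tensor/ndarray.  Ownership of `result` is stolen on success.
void* WrapResult(const DLPackExchangeAPI* api, DLManagedTensorVersioned* result) {
  void* py_result = nullptr;
  if (api->managed_tensor_to_py_object_no_sync(result, &py_result) != 0) {
    return nullptr;  // Python exception already set by the producer
  }
  return py_result;  // PyObject* of the producer's tensor type
}
```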
|
|
||
| /*! | ||
| * \brief DLPackExchangeAPI stable header. | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| typedef struct DLPackExchangeAPIHeader { | ||
| /*! | ||
| * \brief The provided DLPack version. The consumer must check major version | ||
| * compatibility before using this struct. | ||
| */ | ||
| DLPackVersion version; | ||
| /*! | ||
| * \brief Optional pointer to an older DLPackExchangeAPI in the chain. | ||
| * | ||
| * It must be NULL if the framework does not support older versions. | ||
| * If the current major version is larger than the one supported by the | ||
| * consumer, the consumer may walk this to find an earlier supported version. | ||
| * | ||
| * \sa DLPackExchangeAPI | ||
| */ | ||
| struct DLPackExchangeAPIHeader* prev_api; | ||
| } DLPackExchangeAPIHeader; | ||
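A consumer-side version check might walk this chain roughly as follows (a sketch; the helper name is illustrative and the exact negotiation rules are up to the final spec).

```cpp
#include <dlpack/dlpack.h>

// Hypothetical consumer helper: walk prev_api until an ABI-compatible major
// version is found; the result can then be cast to the matching API struct.
const DLPackExchangeAPIHeader* FindCompatibleAPI(const DLPackExchangeAPIHeader* header) {
  while (header != nullptr) {
    if (header->version.major == DLPACK_MAJOR_VERSION) {
      return header;  // same major version as the consumer was built against
    }
    header = header->prev_api;  // try an older table exposed by the producer
  }
  return nullptr;  // no compatible table; fall back to __dlpack__()
}
```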
|
|
||
| /*! | ||
| * \brief Framework-specific function pointers table for DLPack exchange. | ||
| * | ||
| * In addition to `__dlpack__()`, we define a C function table shareable by | ||
| * Python implementations via `__c_dlpack_exchange_api__`. | ||
| * This attribute must be set on the type as a Python integer compatible | ||
| * with `PyLong_FromVoidPtr`/`PyLong_AsVoidPtr`. | ||
| * | ||
| * A consumer library may use a pattern such as: | ||
| * | ||
| * \code | ||
| * | ||
| * PyObject *api_obj = type(tensor_obj).__c_dlpack_exchange_api__; // as C-code | ||
| * MyDLPackExchangeAPI *api = PyLong_AsVoidPtr(api_obj); | ||
| * if (api == NULL && PyErr_Occurred()) { goto handle_error; } | ||
| * | ||
| * \endcode | ||
| * | ||
| * Note that this must be defined on the type. The consumer should look up the | ||
| * attribute on the type and may cache the result for each unique type. | ||
| * | ||
| * A framework can define and populate the API table as follows: | ||
| * \code | ||
| * struct MyDLPackExchangeAPI : public DLPackExchangeAPI { | ||
| * MyDLPackExchangeAPI() { | ||
| * header.version.major = DLPACK_MAJOR_VERSION; | ||
| * header.version.minor = DLPACK_MINOR_VERSION; | ||
| * header.prev_api = nullptr; | ||
| * | ||
| * managed_tensor_allocator = MyDLPackManagedTensorAllocator; | ||
| * managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync; | ||
| * managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync; | ||
| * dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync; | ||
| * current_work_stream = MyDLPackCurrentWorkStream; | ||
| * } | ||
| * | ||
| * static const DLPackExchangeAPI* Global() { | ||
| * static MyDLPackExchangeAPI inst; | ||
| * return &inst; | ||
| * } | ||
| * }; | ||
| * \endcode | ||
| * | ||
| * Guidelines for leveraging DLPackExchangeAPI: | ||
| * | ||
| * There are generally two kinds of consumer needs for DLPack exchange: | ||
| * - N0: library support, where consumer.kernel(x, y, z) wants to run a kernel | ||
| * on the data from x, y, z. The consumer is also expected to run the kernel in the | ||
| * same stream context as the producer. For example, when x, y, z are torch.Tensor, | ||
| * the consumer should query exchange_api->current_work_stream to get the | ||
| * current stream and launch the kernel on that stream. | ||
| * This setup avoids synchronization at kernel launch and maximizes compatibility | ||
| * with CUDA graph capture in the producer. | ||
| * This is the desirable behavior for library extension support for frameworks like PyTorch. | ||
| * - N1: data ingestion and retention. | ||
| * | ||
| * Note that the obj.__dlpack__() API already provides useful ways to handle N1. | ||
| * The primary focus of the current DLPackExchangeAPI is to enable the faster exchange N0, | ||
| * with the support of the function pointer current_work_stream. | ||
| * | ||
| * Array/Tensor libraries should statically create and initialize this structure, | ||
| * then expose a pointer to the DLPackExchangeAPI as an integer value on the | ||
| * Tensor/Array type. | ||
| * The DLPackExchangeAPI* must stay alive throughout the lifetime of the process. | ||
| * | ||
| * One simple way to do so is to create a static instance of DLPackExchangeAPI | ||
| * within the framework and return a pointer to it. The code above shows an | ||
| * example of doing so in C++. It should also be reasonably easy to do the same | ||
| * in other languages. | ||
| */ | ||
| typedef struct DLPackExchangeAPI { | ||
| /*! | ||
| * \brief The header that remains stable across versions. | ||
| */ | ||
| DLPackExchangeAPIHeader header; | ||
| /*! | ||
| * \brief Producer function pointer for DLPackManagedTensorAllocator | ||
| * This function must not be NULL. | ||
| * \sa DLPackManagedTensorAllocator | ||
| */ | ||
| DLPackManagedTensorAllocator managed_tensor_allocator; | ||
| /*! | ||
| * \brief Producer function pointer for DLPackManagedTensorFromPyObjectNoSync. | ||
| * This function must not be NULL. | ||
| * \sa DLPackManagedTensorFromPyObjectNoSync | ||
| */ | ||
| DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; | ||
| /*! | ||
| * \brief Producer function pointer for DLPackManagedTensorToPyObjectNoSync. | ||
| * This function must not be NULL. | ||
| * \sa DLPackManagedTensorToPyObjectNoSync | ||
| */ | ||
| DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; | ||
| /*! | ||
| * \brief Producer function pointer for DLPackDLTensorFromPyObjectNoSync. | ||
| * This function can be NULL when the producer does not support it. | ||
|
Collaborator
I like this approach as such, but as I just mentioned elsewhere, I want to point out that if this is optional and commonly not supported, it means all consumers will need a fallback path anyway, which unfortunately means that we only get a small speed improvement when it is available (not having to allocate the managed tensor). I suppose part of the solution here may be that you really want a C++ convenience layer that is easy to vendor. This isn't a deal breaker! It made me wonder if filling in a […]
||
| * \sa DLPackDLTensorFromPyObjectNoSync | ||
| */ | ||
| DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; | ||
| /*! | ||
| * \brief Producer function pointer for DLPackCurrentWorkStream | ||
| * This function must not be NULL. | ||
| * \sa DLPackCurrentWorkStream | ||
| */ | ||
| DLPackCurrentWorkStream current_work_stream; | ||
| } DLPackExchangeAPI; | ||
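Putting the pieces together, an end-to-end consumer might look roughly like the sketch below. It is illustrative only: the attribute lookup follows the doc comment above, error handling is abbreviated, and the fallback mirrors the discussion about dltensor_from_py_object_no_sync being optional.

```cpp
#include <Python.h>

#include <dlpack/dlpack.h>

// Hypothetical end-to-end consumer: look up the exchange table on the type,
// check the version, obtain a DLTensor view, and query the work stream.
int ConsumeTensor(PyObject* obj) {
  // 1. The attribute lives on the type; a real consumer may cache it per type.
  PyObject* api_obj = PyObject_GetAttrString(
      reinterpret_cast<PyObject*>(Py_TYPE(obj)), "__c_dlpack_exchange_api__");
  if (api_obj == nullptr) return -1;
  auto* api = static_cast<DLPackExchangeAPI*>(PyLong_AsVoidPtr(api_obj));
  Py_DECREF(api_obj);
  if (api == nullptr) return -1;  // conversion failed or a NULL table was exposed

  // 2. Major version check; a real consumer could walk header.prev_api here.
  if (api->header.version.major != DLPACK_MAJOR_VERSION) return -1;

  // 3. Prefer the non-owning fast path, fall back to the owning path.
  DLTensor view;
  DLManagedTensorVersioned* owned = nullptr;
  if (api->dltensor_from_py_object_no_sync != nullptr) {
    if (api->dltensor_from_py_object_no_sync(obj, &view) != 0) return -1;
  } else {
    if (api->managed_tensor_from_py_object_no_sync(obj, &owned) != 0) return -1;
    view = owned->dl_tensor;
  }

  // 4. Launch work on the producer's current stream (NULL for kDLCPU).
  void* stream = nullptr;
  int ret = api->current_work_stream(view.device.device_type,
                                     view.device.device_id, &stream);
  if (ret == 0) {
    // ... launch the kernel on `stream`, reading view.data / shape / strides ...
  }
  if (owned != nullptr && owned->deleter != nullptr) owned->deleter(owned);
  return ret;
}
```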
|
|
||
| #ifdef __cplusplus | ||
| } // DLPACK_EXTERN_C | ||
|
|
||
Collaborator
One small thing is, I am not sure how we will use kind. Let's say this typically gets converted to a Python exception; then we would have to somewhat agree on e.g. MemoryError or so to translate that. Maybe this will just settle itself; in the end I bet this pretty much only ever raises MemoryError anyway. (I think this function is important! I know I am slow to settle on something, but I haven't quite settled on loving this approach. But I suspect it's true that asking to allocate e.g. the torch tensor here -- requiring the GIL -- and then viewing that is also not really better.)
Member
Author
Yes, MemoryError is likely the most common one.
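As a rough illustration of how `kind` could be used on the consumer side (the string-to-exception mapping here is an assumption, not something the RFC specifies), a consumer-provided SetError callback might translate it into a Python exception like this:

```cpp
#include <Python.h>

#include <cstring>

// Hypothetical consumer-provided SetError callback passed to
// managed_tensor_allocator: translate the producer-reported `kind` string
// into a Python exception.  The mapping below is illustrative only.
static void MySetError(void* error_ctx, const char* kind, const char* message) {
  (void)error_ctx;  // unused here; could carry extra consumer state
  PyGILState_STATE gil = PyGILState_Ensure();  // SetError may need to hold the GIL
  PyObject* exc_type = PyExc_RuntimeError;
  if (kind != nullptr && std::strcmp(kind, "MemoryError") == 0) {
    exc_type = PyExc_MemoryError;  // likely the most common case, per the thread above
  }
  PyErr_SetString(exc_type, message != nullptr ? message : "DLPack producer error");
  PyGILState_Release(gil);
}
```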