Commit 6f808ba

[DLPack] Bump DLPack to v1.2 and implement C functions exchange API (#75650)

1 parent 8e37ed6 commit 6f808ba

File tree

10 files changed: +263 −14 lines changed

ci/dcu_test.sh

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ function hybrid_paddlex() {
 function main(){
     cd ${PADDLE_ROOT}/build
     pip install hypothesis
+    /opt/py310/bin/pip install -r ${PADDLE_ROOT}/python/unittest_py/requirements.txt
     /opt/py310/bin/pip install safetensors
     if ls ${PADDLE_ROOT}/build/python/dist/*whl >/dev/null 2>&1; then
         pip install ${PADDLE_ROOT}/build/python/dist/*whl

paddle/fluid/framework/dlpack_tensor.cc

Lines changed: 18 additions & 2 deletions

@@ -265,7 +265,7 @@ ::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype) {
       framework::TransToProtoVarType(dtype));
 }
 
-phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
+phi::Place DLDeviceToPlace(const ::DLDevice &dl_device) {
   phi::Place place;
   if (dl_device.device_type == kDLCPU) {
     place = phi::CPUPlace();
@@ -279,7 +279,7 @@ phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
   return place;
 }
 
-DLDevice PlaceToDLDevice(const phi::Place &place) {
+::DLDevice PlaceToDLDevice(const phi::Place &place) {
   return phi::VisitPlace(place, internal::DLDeviceVisitor());
 }
 
@@ -358,6 +358,22 @@ DLManagedTensorVersioned *ToDLPackVersioned(const phi::DenseTensor &src,
   return ToDLPackImpl<DLManagedTensorVersioned>(src, flags);
 }
 
+void ToDLPackNonOwningImpl(const phi::DenseTensor &tensor,
+                           ::DLTensor &out) {  // NOLINT
+  // Fill in the pre-allocated DLTensor struct with direct pointers.
+  // This is a non-owning conversion: the caller owns the tensor and
+  // must keep it alive for the duration of the DLTensor's usage.
+  out.data = const_cast<void *>(tensor.data());
+  out.device = PlaceToDLDevice(tensor.place());
+  out.ndim = static_cast<int32_t>(tensor.dims().size());
+  out.dtype = PhiDataTypeToDLDataType(tensor.dtype());
+  // dims() and strides() expose stable storage inside the tensor,
+  // which remains valid as long as the tensor is alive.
+  out.shape = const_cast<int64_t *>(tensor.dims().Get());
+  out.strides = const_cast<int64_t *>(tensor.strides().Get());
+  out.byte_offset = 0;
+}
+
 template <typename T>
 phi::DenseTensor FromDLPackImpl(T *src, Deleter deleter) {
   std::vector<int64_t> shape_vec;

paddle/fluid/framework/dlpack_tensor.h

Lines changed: 13 additions & 9 deletions

@@ -29,15 +29,19 @@ and paddle/phi/api/lib/tensor_utils.cc
 */
 using Deleter = std::function<void(void*)>;
 
-phi::Place DLDeviceToPlace(const DLDevice& device);
-DLDevice PlaceToDLDevice(const phi::Place& place);
-
-TEST_API DLManagedTensor* ToDLPack(const phi::DenseTensor& src,
-                                   uint64_t flags = 0);
-DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
-                                            uint64_t flags = 0);
-TEST_API phi::DenseTensor FromDLPack(DLManagedTensor* src);
-phi::DenseTensor FromDLPackVersioned(DLManagedTensorVersioned* src);
+::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype);
+phi::DataType DLDataTypeToPhiDataType(::DLDataType type);
+phi::Place DLDeviceToPlace(const ::DLDevice& device);
+::DLDevice PlaceToDLDevice(const phi::Place& place);
+
+TEST_API ::DLManagedTensor* ToDLPack(const phi::DenseTensor& src,
+                                     uint64_t flags = 0);
+::DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
+                                              uint64_t flags = 0);
+void ToDLPackNonOwningImpl(const phi::DenseTensor& tensor,
+                           ::DLTensor& out);  // NOLINT
+TEST_API phi::DenseTensor FromDLPack(::DLManagedTensor* src);
+phi::DenseTensor FromDLPackVersioned(::DLManagedTensorVersioned* src);
 
 // A traits to support both DLManagedTensor and DLManagedTensorVersioned
 template <typename T>

paddle/fluid/pybind/pybind.cc

Lines changed: 106 additions & 0 deletions

@@ -763,6 +763,108 @@ class PyLayerBlockContextManager {
   PyLayerBlockContextManager() = default;
 };
 
+int DLPackDLTensorFromPyObjectNoSync(void *py_obj, DLTensor *out) {
+  try {
+    // Use handle (non-owning) to avoid unnecessary refcount operations
+    py::handle handle(static_cast<PyObject *>(py_obj));
+    paddle::Tensor tensor = handle.cast<paddle::Tensor>();
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    paddle::framework::ToDLPackNonOwningImpl(*dense_tensor, *out);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackManagedTensorFromPyObjectNoSync(void *py_obj,
+                                          DLManagedTensorVersioned **out) {
+  try {
+    py::handle handle(static_cast<PyObject *>(py_obj));
+    paddle::Tensor tensor = handle.cast<paddle::Tensor>();
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned *src,
+                                        void **py_obj_out) {
+  try {
+    phi::DenseTensor dense_tensor = paddle::framework::FromDLPackVersioned(src);
+    paddle::Tensor tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
+    egr::EagerUtils::autograd_meta(&tensor)->SetPersistable(false);
+    *py_obj_out = ToPyObject(tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int DLPackManagedTensorAllocator(::DLTensor *prototype,
+                                 ::DLManagedTensorVersioned **out,
+                                 void *error_ctx,
+                                 void (*SetError)(void *error_ctx,
+                                                  const char *kind,
+                                                  const char *message)) {
+  try {
+    phi::IntArray shape(prototype->shape, prototype->ndim);
+    phi::Place place(paddle::framework::DLDeviceToPlace(prototype->device));
+    phi::DataType dtype =
+        paddle::framework::DLDataTypeToPhiDataType(prototype->dtype);
+    paddle::Tensor tensor = paddle::empty(shape, dtype, place);
+    std::shared_ptr<phi::DenseTensor> dense_tensor =
+        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
+    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
+    return 0;
+  } catch (const std::exception &e) {
+    SetError(error_ctx, "DLPackManagedTensorAllocator", e.what());
+    return -1;
+  }
+}
+
+int DLPackCurrentWorkStream(DLDeviceType device_type,
+                            int32_t device_id,
+                            void **out_stream) {
+  try {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
+    defined(PADDLE_WITH_CUSTOM_DEVICE)
+    if (device_type == kDLCUDA || device_type == kDLROCM) {
+      *out_stream = platform::get_current_stream(device_id)->raw_stream();
+    }
+#endif
+    return 0;
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+struct PaddleDLPackExchangeAPI : public ::DLPackExchangeAPI {
+  PaddleDLPackExchangeAPI() {
+    header.version.major = DLPACK_MAJOR_VERSION;
+    header.version.minor = DLPACK_MINOR_VERSION;
+    header.prev_api = nullptr;
+    managed_tensor_allocator = DLPackManagedTensorAllocator;
+    managed_tensor_from_py_object_no_sync =
+        DLPackManagedTensorFromPyObjectNoSync;
+    managed_tensor_to_py_object_no_sync = DLPackManagedTensorToPyObjectNoSync;
+    dltensor_from_py_object_no_sync = DLPackDLTensorFromPyObjectNoSync;
+    current_work_stream = DLPackCurrentWorkStream;
+  }
+
+  static const DLPackExchangeAPI *Instance() {
+    static PaddleDLPackExchangeAPI inst;
+    return &inst;
+  }
+};
+
 // NOTE: use to load file by Mmap
 enum MMapLoadModes {
   ALLOCATOR_MAPPED_SHARED = 1,
@@ -1773,6 +1875,10 @@ PYBIND11_MODULE(libpaddle, m) {
                             dl_device.device_id);
   });
 
+  m.def("dlpack_exchange_api_ptr", []() -> int64_t {
+    return reinterpret_cast<int64_t>(PaddleDLPackExchangeAPI::Instance());
+  });
+
   m.def("from_dlpack", [](py::object data) {
     if (PyCapsule_IsValid(data.ptr(),
                           DLPackTraits<DLManagedTensorVersioned>::capsule)) {
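
Because `PaddleDLPackExchangeAPI::Instance()` returns the address of a function-local static, the integer exposed by `dlpack_exchange_api_ptr` is a stable, process-wide handle to the exchange-API struct. A minimal sketch of what this looks like from Python, assuming a Paddle build that includes this commit:

    from paddle.base import core

    # Address of the static PaddleDLPackExchangeAPI instance, exported as a
    # plain int64 that can be handed across a C FFI boundary.
    api_ptr = core.dlpack_exchange_api_ptr()
    assert api_ptr != 0
    # Repeated calls return the same address: the struct is a singleton.
    assert api_ptr == core.dlpack_exchange_api_ptr()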

python/paddle/base/dygraph/tensor_patch_methods.py

Lines changed: 1 addition & 0 deletions

@@ -1586,6 +1586,7 @@ def __tvm_ffi_env_stream__(self) -> int:
         ("__dlpack_device__", __dlpack_device__),
         ("get_device", get_device),
         ("__tvm_ffi_env_stream__", __tvm_ffi_env_stream__),
+        ("__c_dlpack_exchange_api__", core.dlpack_exchange_api_ptr()),
     ):
         setattr(core.eager.Tensor, method_name, method)
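
Note that, unlike its neighbors in this tuple, `__c_dlpack_exchange_api__` is bound to a plain integer (the struct address) rather than a callable. A small sanity sketch, again assuming a build with this commit:

    import paddle
    from paddle.base import core

    x = paddle.zeros([2])
    # The attribute lives on the class object, so every tensor sees the
    # same address that the binding exports.
    assert x.__c_dlpack_exchange_api__ == core.dlpack_exchange_api_ptr()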

python/paddle/utils/dlpack.py

Lines changed: 2 additions & 1 deletion

@@ -75,6 +75,7 @@ class DLDeviceType(enum.IntEnum):
     kDLWebGPU = (15,)
     kDLHexagon = (16,)
     kDLMAIA = (17,)
+    kDLTrn = (18,)
 
 
 def to_dlpack(x: Tensor) -> CapsuleType:
@@ -215,7 +216,7 @@ def from_dlpack(
 
     if hasattr(dlpack, "__dlpack__"):
         kwargs = {}
-        kwargs["max_version"] = (1, 1)
+        kwargs["max_version"] = (1, 2)
        if copy is not None:
            kwargs["copy"] = copy
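
With this bump, `from_dlpack` now tells producers it accepts DLPack up to v1.2 when it calls their `__dlpack__`. A quick round-trip sketch using Paddle's own implementation:

    import paddle
    from paddle.utils.dlpack import from_dlpack, to_dlpack

    x = paddle.to_tensor([1.0, 2.0, 3.0])
    capsule = to_dlpack(x)    # a DLPack capsule wrapping x's storage
    y = from_dlpack(capsule)  # zero-copy import; y shares x's memory

    # Objects implementing __dlpack__ (paddle.Tensor included) can be passed
    # directly; that path now requests max_version=(1, 2) from the producer.
    z = from_dlpack(paddle.ones([2, 2]))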

python/unittest_py/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -20,3 +20,4 @@ xdoctest==1.3.0
 ubelt==1.3.3 # just for xdoctest
 mypy==1.17.1
 soundfile
+apache-tvm-ffi==0.1.0b16

test/dygraph_to_static/test_tensor_attr_consistency.py

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@
         '__dlpack__',
         "__dlpack_device__",
         "__tvm_ffi_env_stream__",
+        "__c_dlpack_exchange_api__",
     ]
 )
 STATIC_ONLY_TENSOR_ATTRS_ALLOW_LIST = OrderedSet(

test/legacy_test/test_tvm_ffi.py

Lines changed: 119 additions & 1 deletion

@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+import platform
 import unittest
+from typing import TYPE_CHECKING
+
+import numpy as np
+import tvm_ffi.cpp
 
 import paddle
 
+if TYPE_CHECKING:
+    from tvm_ffi import Module
+
 
-class TestTVMFFI(unittest.TestCase):
+class TestTVMFFIEnvStream(unittest.TestCase):
     def test_tvm_ffi_env_stream_for_gpu_tensor(self):
         if not paddle.is_compiled_with_cuda():
             return
@@ -34,5 +44,113 @@ def test_tvm_ffi_env_stream_for_cpu_tensor(self):
         tensor.__tvm_ffi_env_stream__()
 
 
+class TestCDLPackExchangeAPI(unittest.TestCase):
+    def test_c_dlpack_exchange_api_cpu(self):
+        cpp_source = r"""
+        void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+          for (int i = 0; i < x->shape[0]; ++i) {
+            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+          }
+        }
+        """
+
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='mod', cpp_sources=cpp_source, functions='add_one_cpu'
+        )
+
+        x = paddle.full((3,), 1.0, dtype='float32').cpu()
+        y = paddle.zeros((3,), dtype='float32').cpu()
+        mod.add_one_cpu(x, y)
+        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+    def test_c_dlpack_exchange_api_gpu(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+        if paddle.is_compiled_with_rocm():
+            # Skip on DCU because CUDA_HOME is not available
+            return
+        if platform.system() == "Windows":
+            # Temporarily skip this test on Windows because of a compile bug
+            # in TVM FFI
+            return
+        cpp_sources = r"""
+        void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y);
+        """
+        cuda_sources = r"""
+        __global__ void AddOneKernel(float* x, float* y, int n) {
+          int idx = blockIdx.x * blockDim.x + threadIdx.x;
+          if (idx < n) {
+            y[idx] = x[idx] + 1;
+          }
+        }
+
+        void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+          // implementation of a library function
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+
+          int64_t n = x->shape[0];
+          int64_t nthread_per_block = 256;
+          int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
+          // Obtain the current stream from the environment by calling TVMFFIEnvGetStream
+          cudaStream_t stream = static_cast<cudaStream_t>(
+              TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
+          // launch the kernel
+          AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(
+              static_cast<float*>(x->data), static_cast<float*>(y->data), n);
+        }
+        """
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='mod',
+            cpp_sources=cpp_sources,
+            cuda_sources=cuda_sources,
+            functions=['add_one_cuda'],
+        )
+
+        x = paddle.full((3,), 1.0, dtype='float32').cuda()
+        y = paddle.zeros((3,), dtype='float32').cuda()
+        mod.add_one_cuda(x, y)
+        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+    def test_c_dlpack_exchange_api_alloc_tensor(self):
+        if platform.system() == "Windows":
+            # Temporarily skip this test on Windows: returning an owned tensor
+            # created by TVMFFIEnvGetTensorAllocator causes a double-free error
+            return
+        cpp_source = r"""
+        inline tvm::ffi::Tensor alloc_tensor(tvm::ffi::Shape shape, DLDataType dtype, DLDevice device) {
+          return tvm::ffi::Tensor::FromDLPackAlloc(TVMFFIEnvGetTensorAllocator(), shape, dtype, device);
+        }
+
+        tvm::ffi::Tensor add_one_cpu(tvm::ffi::TensorView x) {
+          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+          DLDataType f32_dtype{kDLFloat, 32, 1};
+          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+          tvm::ffi::Shape x_shape(x->shape, x->shape + x->ndim);
+          tvm::ffi::Tensor y = alloc_tensor(x_shape, f32_dtype, x->device);
+          for (int i = 0; i < x->shape[0]; ++i) {
+            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+          }
+          return y;
+        }
+        """
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='mod', cpp_sources=cpp_source, functions=['add_one_cpu']
+        )
+        x = paddle.full((3,), 1.0, dtype='float32').cpu()
+        y = mod.add_one_cpu(x)
+        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+
 if __name__ == '__main__':
     unittest.main()
