1 change: 1 addition & 0 deletions ci/dcu_test.sh
@@ -75,6 +75,7 @@ function hybrid_paddlex() {
function main(){
cd ${PADDLE_ROOT}/build
pip install hypothesis
/opt/py310/bin/pip install -r ${PADDLE_ROOT}/python/unittest_py/requirements.txt
/opt/py310/bin/pip install safetensors
if ls ${PADDLE_ROOT}/build/python/dist/*whl >/dev/null 2>&1; then
pip install ${PADDLE_ROOT}/build/python/dist/*whl
20 changes: 18 additions & 2 deletions paddle/fluid/framework/dlpack_tensor.cc
@@ -265,7 +265,7 @@ ::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype) {
framework::TransToProtoVarType(dtype));
}

-phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
+phi::Place DLDeviceToPlace(const ::DLDevice &dl_device) {
phi::Place place;
if (dl_device.device_type == kDLCPU) {
place = phi::CPUPlace();
@@ -279,7 +279,7 @@ phi::Place DLDeviceToPlace(const DLDevice &dl_device) {
return place;
}

-DLDevice PlaceToDLDevice(const phi::Place &place) {
+::DLDevice PlaceToDLDevice(const phi::Place &place) {
return phi::VisitPlace(place, internal::DLDeviceVisitor());
}

@@ -358,6 +358,22 @@ DLManagedTensorVersioned *ToDLPackVersioned(const phi::DenseTensor &src,
return ToDLPackImpl<DLManagedTensorVersioned>(src, flags);
}

void ToDLPackNonOwningImpl(const phi::DenseTensor &tensor,
                           ::DLTensor &out) {  // NOLINT
  // Fill in the pre-allocated DLTensor struct with direct pointers.
  // This is a non-owning conversion: the caller owns the tensor and must
  // keep it alive for the duration of DLTensor usage.
  out.data = const_cast<void *>(tensor.data());
  out.device = PlaceToDLDevice(tensor.place());
  out.ndim = static_cast<int32_t>(tensor.dims().size());
  out.dtype = PhiDataTypeToDLDataType(tensor.dtype());
  // dims() and strides() return pointers to stable metadata storage,
  // which remains valid as long as the tensor is alive.
  out.shape = const_cast<int64_t *>(tensor.dims().Get());
  out.strides = const_cast<int64_t *>(tensor.strides().Get());
  out.byte_offset = 0;
}

template <typename T>
phi::DenseTensor FromDLPackImpl(T *src, Deleter deleter) {
std::vector<int64_t> shape_vec;
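For context: the non-owning path is meant for callers that already hold a live tensor and only need a transient DLTensor view. A minimal caller-side sketch (the wrapper function and its name are hypothetical, not part of this diff), assuming the tensor outlives the view:

#include "paddle/fluid/framework/dlpack_tensor.h"

// Hypothetical caller: fill a stack-allocated DLTensor as a zero-copy view.
// `src` owns the memory; the view is valid only while `src` stays alive.
void PassToCApi(const phi::DenseTensor &src) {
  ::DLTensor view;
  paddle::framework::ToDLPackNonOwningImpl(src, view);
  // view.data / view.shape / view.strides point into src's storage;
  // hand `view` to a C API here, before `src` goes out of scope.
}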
22 changes: 13 additions & 9 deletions paddle/fluid/framework/dlpack_tensor.h
@@ -29,15 +29,19 @@ and paddle/phi/api/lib/tensor_utils.cc
*/
using Deleter = std::function<void(void*)>;

-phi::Place DLDeviceToPlace(const DLDevice& device);
-DLDevice PlaceToDLDevice(const phi::Place& place);

-TEST_API DLManagedTensor* ToDLPack(const phi::DenseTensor& src,
-                                   uint64_t flags = 0);
-DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
-                                            uint64_t flags = 0);
-TEST_API phi::DenseTensor FromDLPack(DLManagedTensor* src);
-phi::DenseTensor FromDLPackVersioned(DLManagedTensorVersioned* src);
+::DLDataType PhiDataTypeToDLDataType(phi::DataType dtype);
+phi::DataType DLDataTypeToPhiDataType(::DLDataType type);
+phi::Place DLDeviceToPlace(const ::DLDevice& device);
+::DLDevice PlaceToDLDevice(const phi::Place& place);

+TEST_API ::DLManagedTensor* ToDLPack(const phi::DenseTensor& src,
+                                     uint64_t flags = 0);
+::DLManagedTensorVersioned* ToDLPackVersioned(const phi::DenseTensor& src,
+                                              uint64_t flags = 0);
+void ToDLPackNonOwningImpl(const phi::DenseTensor& tensor,
+                           ::DLTensor& out);  // NOLINT
+TEST_API phi::DenseTensor FromDLPack(::DLManagedTensor* src);
+phi::DenseTensor FromDLPackVersioned(::DLManagedTensorVersioned* src);

// A traits class to support both DLManagedTensor and DLManagedTensorVersioned
template <typename T>
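The body of the traits template is collapsed in this view. Judging from the DLPackTraits<DLManagedTensorVersioned>::capsule usage in pybind.cc below and the capsule names fixed by the DLPack spec ("dltensor" / "dltensor_versioned"), a plausible sketch of its shape (an assumption, not the verbatim source) is:

// Sketch only: map each managed-tensor type to its PyCapsule names.
template <typename T>
struct DLPackTraits;

template <>
struct DLPackTraits<::DLManagedTensor> {
  static constexpr const char *capsule = "dltensor";
  static constexpr const char *used = "used_dltensor";
};

template <>
struct DLPackTraits<::DLManagedTensorVersioned> {
  static constexpr const char *capsule = "dltensor_versioned";
  static constexpr const char *used = "used_dltensor_versioned";
};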
106 changes: 106 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -763,6 +763,108 @@ class PyLayerBlockContextManager {
PyLayerBlockContextManager() = default;
};

int DLPackDLTensorFromPyObjectNoSync(void *py_obj, DLTensor *out) {
  try {
    // Use handle (non-owning) to avoid unnecessary refcount operations
    py::handle handle(static_cast<PyObject *>(py_obj));
    paddle::Tensor tensor = handle.cast<paddle::Tensor>();
    std::shared_ptr<phi::DenseTensor> dense_tensor =
        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
    paddle::framework::ToDLPackNonOwningImpl(*dense_tensor, *out);
    return 0;
  } catch (const std::exception &e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return -1;
Comment on lines +775 to +777

[Contributor] Is it enough to handle only std::exception here, without catching other exception types? Something like this?

try {
    // code
} catch (const std::exception& e) {
    // handle standard exceptions
} catch (...) {
    // handle all other exceptions
}

[@SigureMo, Member, Author, Oct 14, 2025] This is the protocol template; it should not be changed here.

dmlc/dlpack#175 (comment)
  }
}

int DLPackManagedTensorFromPyObjectNoSync(void *py_obj,
                                          DLManagedTensorVersioned **out) {
  try {
    py::handle handle(static_cast<PyObject *>(py_obj));
    paddle::Tensor tensor = handle.cast<paddle::Tensor>();
    std::shared_ptr<phi::DenseTensor> dense_tensor =
        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
    return 0;
  } catch (const std::exception &e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return -1;
  }
}

int DLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned *src,
                                        void **py_obj_out) {
  try {
    phi::DenseTensor dense_tensor = paddle::framework::FromDLPackVersioned(src);
    paddle::Tensor tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
    egr::EagerUtils::autograd_meta(&tensor)->SetPersistable(false);
    *py_obj_out = ToPyObject(tensor);
    return 0;
  } catch (const std::exception &e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return -1;
  }
}

int DLPackManagedTensorAllocator(::DLTensor *prototype,
                                 ::DLManagedTensorVersioned **out,
                                 void *error_ctx,
                                 void (*SetError)(void *error_ctx,
                                                  const char *kind,
                                                  const char *message)) {
  try {
    phi::IntArray shape(prototype->shape, prototype->ndim);
    phi::Place place(paddle::framework::DLDeviceToPlace(prototype->device));
    phi::DataType dtype =
        paddle::framework::DLDataTypeToPhiDataType(prototype->dtype);
    paddle::Tensor tensor = paddle::empty(shape, dtype, place);
    std::shared_ptr<phi::DenseTensor> dense_tensor =
        std::static_pointer_cast<phi::DenseTensor>(tensor.impl());
    *out = paddle::framework::ToDLPackVersioned(*dense_tensor);
    return 0;
  } catch (const std::exception &e) {
    SetError(error_ctx, "DLPackManagedTensorAllocator", e.what());
    return -1;
  }
}

int DLPackCurrentWorkStream(DLDeviceType device_type,
                            int32_t device_id,
                            void **out_stream) {
  try {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
    defined(PADDLE_WITH_CUSTOM_DEVICE)
    if (device_type == kDLCUDA || device_type == kDLROCM) {
      *out_stream = platform::get_current_stream(device_id)->raw_stream();
    }
#endif
    return 0;
  } catch (const std::exception &e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return -1;
  }
}

struct PaddleDLPackExchangeAPI : public ::DLPackExchangeAPI {
  PaddleDLPackExchangeAPI() {
    header.version.major = DLPACK_MAJOR_VERSION;
    header.version.minor = DLPACK_MINOR_VERSION;
    header.prev_api = nullptr;
    managed_tensor_allocator = DLPackManagedTensorAllocator;
    managed_tensor_from_py_object_no_sync =
        DLPackManagedTensorFromPyObjectNoSync;
    managed_tensor_to_py_object_no_sync = DLPackManagedTensorToPyObjectNoSync;
    dltensor_from_py_object_no_sync = DLPackDLTensorFromPyObjectNoSync;
    current_work_stream = DLPackCurrentWorkStream;
  }

  static const DLPackExchangeAPI *Instance() {
    static PaddleDLPackExchangeAPI inst;
    return &inst;
  }
};

// NOTE: use to load file by Mmap
enum MMapLoadModes {
ALLOCATOR_MAPPED_SHARED = 1,
@@ -1773,6 +1875,10 @@ PYBIND11_MODULE(libpaddle, m) {
dl_device.device_id);
});

m.def("dlpack_exchange_api_ptr", []() -> int64_t {
return reinterpret_cast<int64_t>(PaddleDLPackExchangeAPI::Instance());
});

m.def("from_dlpack", [](py::object data) {
if (PyCapsule_IsValid(data.ptr(),
DLPackTraits<DLManagedTensorVersioned>::capsule)) {
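Taken together, these pybind additions implement the producer side of the proposed DLPack C exchange API (dmlc/dlpack#175): the function table is built once as a process-wide singleton, and its address is handed to Python as a plain integer. A consumer can then go from a Paddle tensor to a DLTensor with no capsules and no Python-level copies. A consumer-side sketch (hypothetical code, assuming the DLPackExchangeAPI struct from that proposal is available via dlpack.h):

#include <Python.h>
#include <dlpack/dlpack.h>  // assumed to define DLPackExchangeAPI (dmlc/dlpack#175)

// Hypothetical consumer: read the exchange-API pointer published on the
// tensor and use its function table to view the tensor's data in place.
int ViewAsDLTensor(PyObject *paddle_tensor, DLTensor *out) {
  PyObject *addr =
      PyObject_GetAttrString(paddle_tensor, "__c_dlpack_exchange_api__");
  if (addr == nullptr) return -1;  // not a producer of this protocol
  long long raw = PyLong_AsLongLong(addr);
  Py_DECREF(addr);
  if (raw == -1 && PyErr_Occurred()) return -1;
  auto *api = reinterpret_cast<const DLPackExchangeAPI *>(raw);
  // Non-owning view: keep `paddle_tensor` alive while `out` is in use.
  return api->dltensor_from_py_object_no_sync(paddle_tensor, out);
}

The same table also publishes an allocator (managed_tensor_allocator) and a stream accessor (current_work_stream), which is what lets the TVM FFI tests below allocate Paddle tensors and launch kernels on Paddle's current CUDA stream.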
1 change: 1 addition & 0 deletions python/paddle/base/dygraph/tensor_patch_methods.py
@@ -1586,6 +1586,7 @@ def __tvm_ffi_env_stream__(self) -> int:
("__dlpack_device__", __dlpack_device__),
("get_device", get_device),
("__tvm_ffi_env_stream__", __tvm_ffi_env_stream__),
("__c_dlpack_exchange_api__", core.dlpack_exchange_api_ptr()),
):
setattr(core.eager.Tensor, method_name, method)

3 changes: 2 additions & 1 deletion python/paddle/utils/dlpack.py
@@ -75,6 +75,7 @@ class DLDeviceType(enum.IntEnum):
    kDLWebGPU = (15,)
    kDLHexagon = (16,)
    kDLMAIA = (17,)
    kDLTrn = (18,)


def to_dlpack(x: Tensor) -> CapsuleType:
@@ -215,7 +216,7 @@ def from_dlpack(

if hasattr(dlpack, "__dlpack__"):
kwargs = {}
kwargs["max_version"] = (1, 1)
kwargs["max_version"] = (1, 2)
if copy is not None:
kwargs["copy"] = copy

1 change: 1 addition & 0 deletions python/unittest_py/requirements.txt
@@ -20,3 +20,4 @@ xdoctest==1.3.0
ubelt==1.3.3 # just for xdoctest
mypy==1.17.1
soundfile
apache-tvm-ffi==0.1.0b16
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -81,6 +81,7 @@
        '__dlpack__',
        "__dlpack_device__",
        "__tvm_ffi_env_stream__",
        "__c_dlpack_exchange_api__",
    ]
)
STATIC_ONLY_TENSOR_ATTRS_ALLOW_LIST = OrderedSet(
120 changes: 119 additions & 1 deletion test/legacy_test/test_tvm_ffi.py
@@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import platform
import unittest
from typing import TYPE_CHECKING

import numpy as np
import tvm_ffi.cpp

import paddle

if TYPE_CHECKING:
    from tvm_ffi import Module


-class TestTVMFFI(unittest.TestCase):
+class TestTVMFFIEnvStream(unittest.TestCase):
    def test_tvm_ffi_env_stream_for_gpu_tensor(self):
        if not paddle.is_compiled_with_cuda():
            return
@@ -34,5 +44,113 @@ def test_tvm_ffi_env_stream_for_cpu_tensor(self):
        tensor.__tvm_ffi_env_stream__()


class TestCDLPackExchangeAPI(unittest.TestCase):
    def test_c_dlpack_exchange_api_cpu(self):
        cpp_source = r"""
void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // implementation of a library function
  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
  DLDataType f32_dtype{kDLFloat, 32, 1};
  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
  TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
  TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
  for (int i = 0; i < x->shape[0]; ++i) {
    static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
  }
}
"""

        mod: Module = tvm_ffi.cpp.load_inline(
            name='mod', cpp_sources=cpp_source, functions='add_one_cpu'
        )

        x = paddle.full((3,), 1.0, dtype='float32').cpu()
        y = paddle.zeros((3,), dtype='float32').cpu()
        mod.add_one_cpu(x, y)
        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])

    def test_c_dlpack_exchange_api_gpu(self):
        if not paddle.is_compiled_with_cuda():
            return
        if paddle.is_compiled_with_rocm():
            # Skip on DCU because CUDA_HOME is not available
            return
        if platform.system() == "Windows":
            # Temporarily skip this test case on Windows because of a
            # compile bug in TVM FFI
            return
        cpp_sources = r"""
void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y);
"""
        cuda_sources = r"""
__global__ void AddOneKernel(float* x, float* y, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    y[idx] = x[idx] + 1;
  }
}

void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
  // implementation of a library function
  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
  DLDataType f32_dtype{kDLFloat, 32, 1};
  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
  TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
  TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";

  int64_t n = x->shape[0];
  int64_t nthread_per_block = 256;
  int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
  // Obtain the current stream from the environment by calling TVMFFIEnvGetStream
  cudaStream_t stream = static_cast<cudaStream_t>(
      TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
  // launch the kernel
  AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(
      static_cast<float*>(x->data), static_cast<float*>(y->data), n);
}
"""
        mod: Module = tvm_ffi.cpp.load_inline(
            name='mod',
            cpp_sources=cpp_sources,
            cuda_sources=cuda_sources,
            functions=['add_one_cuda'],
        )

        x = paddle.full((3,), 1.0, dtype='float32').cuda()
        y = paddle.zeros((3,), dtype='float32').cuda()
        mod.add_one_cuda(x, y)
        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])

    def test_c_dlpack_exchange_api_alloc_tensor(self):
        if platform.system() == "Windows":
            # Temporarily skip this test case on Windows because returning
            # an owned tensor created by TVMFFIEnvGetTensorAllocator causes
            # a double-free error
            return
        cpp_source = r"""
inline tvm::ffi::Tensor alloc_tensor(tvm::ffi::Shape shape, DLDataType dtype, DLDevice device) {
  return tvm::ffi::Tensor::FromDLPackAlloc(TVMFFIEnvGetTensorAllocator(), shape, dtype, device);
}

tvm::ffi::Tensor add_one_cpu(tvm::ffi::TensorView x) {
  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
  DLDataType f32_dtype{kDLFloat, 32, 1};
  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
  tvm::ffi::Shape x_shape(x->shape, x->shape + x->ndim);
  tvm::ffi::Tensor y = alloc_tensor(x_shape, f32_dtype, x->device);
  for (int i = 0; i < x->shape[0]; ++i) {
    static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
  }
  return y;
}
"""
        mod: Module = tvm_ffi.cpp.load_inline(
            name='mod', cpp_sources=cpp_source, functions=['add_one_cpu']
        )
        x = paddle.full((3,), 1.0, dtype='float32').cpu()
        y = mod.add_one_cpu(x)
        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])


if __name__ == '__main__':
    unittest.main()