Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to_dlpack #50138

Merged
merged 5 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions paddle/fluid/framework/dlpack_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,58 @@ struct DLDeviceVisitor
};
} // namespace internal

struct PaddleDLMTensor {
phi::DenseTensor handle;
DLManagedTensor tensor;
};

void deleter(DLManagedTensor *arg) {
delete[] arg->dl_tensor.shape;
delete[] arg->dl_tensor.strides;
delete static_cast<PaddleDLMTensor *>(arg->manager_ctx);
}

DLManagedTensor *toDLPack(const phi::DenseTensor &src) {
PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor);
pdDLMTensor->handle = const_cast<phi::DenseTensor &>(src);
pdDLMTensor->tensor.manager_ctx = pdDLMTensor;
pdDLMTensor->tensor.deleter = &deleter;
pdDLMTensor->tensor.dl_tensor.data = const_cast<void *>(src.data());

// init ndim
using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim); // int
pdDLMTensor->tensor.dl_tensor.ndim = static_cast<DimType>(src.dims().size());
DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim;

// init shape
auto shape = new int64_t[ndim];
for (DimType i = 0; i < ndim; ++i) {
shape[i] = src.dims()[i];
}
pdDLMTensor->tensor.dl_tensor.shape = shape;

// init stride
auto strides = new int64_t[ndim];
for (DimType i = 0; i < ndim; ++i) {
strides[i] = 1;
}
for (DimType i = ndim - 2; i >= 0; --i) {
strides[i] = shape[i + 1] * strides[i + 1];
}
pdDLMTensor->tensor.dl_tensor.strides = strides;

// init device, DLDevice type with device_type and device_id
auto place = src.place();
pdDLMTensor->tensor.dl_tensor.device =
paddle::platform::VisitPlace(place, internal::DLDeviceVisitor());

pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex(
framework::TransToProtoVarType(src.dtype()));

pdDLMTensor->tensor.dl_tensor.byte_offset = 0;
return &(pdDLMTensor->tensor);
}

DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
// init data, data buffer
t_.data = const_cast<void *>(tensor.data());
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/dlpack_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,7 @@ class DLPackTensor {
ShapeType shape_[DDim::kMaxRank];
};

DLManagedTensor* toDLPack(const phi::DenseTensor& src);

} // namespace framework
} // namespace paddle
21 changes: 7 additions & 14 deletions paddle/fluid/pybind/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,22 +473,15 @@ void BindTensor(pybind11::module &m) { // NOLINT
)DOC")
.def("_to_dlpack",
[](phi::DenseTensor &self) {
DLPackTensor dlpack_tensor(self, 1);
DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
auto capsule = py::capsule(
DLManagedTensor *dmt = framework::toDLPack(self);
auto capsule = pybind11::capsule(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pr 内容栏补充一下 问题原因,修复策略吧

static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
if (ptr) {
auto dltensor = new DLManagedTensor;
try {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "used_dltensor"));
return;
} catch (...) {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
}
dltensor->deleter(dltensor);
if (!PyCapsule_IsValid(ptr, "dltensor")) {
return;
}
DLManagedTensor *dmt = static_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
dmt->deleter(dmt);
});
return capsule;
})
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/tests/test_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ def test_dlpack_deletion(self):
dlpack = paddle.utils.dlpack.to_dlpack(a)
b = paddle.utils.dlpack.from_dlpack(dlpack)

def test_to_dlpack_for_loop(self):
# See Paddle issue 50120
for i in range(10):
x = paddle.rand([3, 5])
dlpack = paddle.utils.dlpack.to_dlpack(x)


class TestRaiseError(unittest.TestCase):
def test_from_dlpack_raise_type_error(self):
Expand Down