Skip to content

Commit

Permalink
Fix to_dlpack (#50138)
Browse files Browse the repository at this point in the history
* fix to_dlpack for loop

* fix reference count
  • Loading branch information
DesmonDay authored Feb 6, 2023
1 parent 244e754 commit 35ce2bd
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 14 deletions.
52 changes: 52 additions & 0 deletions paddle/fluid/framework/dlpack_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,58 @@ struct DLDeviceVisitor
};
} // namespace internal

struct PaddleDLMTensor {
phi::DenseTensor handle;
DLManagedTensor tensor;
};

void deleter(DLManagedTensor *arg) {
delete[] arg->dl_tensor.shape;
delete[] arg->dl_tensor.strides;
delete static_cast<PaddleDLMTensor *>(arg->manager_ctx);
}

DLManagedTensor *toDLPack(const phi::DenseTensor &src) {
PaddleDLMTensor *pdDLMTensor(new PaddleDLMTensor);
pdDLMTensor->handle = const_cast<phi::DenseTensor &>(src);
pdDLMTensor->tensor.manager_ctx = pdDLMTensor;
pdDLMTensor->tensor.deleter = &deleter;
pdDLMTensor->tensor.dl_tensor.data = const_cast<void *>(src.data());

// init ndim
using DimType = decltype(pdDLMTensor->tensor.dl_tensor.ndim); // int
pdDLMTensor->tensor.dl_tensor.ndim = static_cast<DimType>(src.dims().size());
DimType ndim = pdDLMTensor->tensor.dl_tensor.ndim;

// init shape
auto shape = new int64_t[ndim];
for (DimType i = 0; i < ndim; ++i) {
shape[i] = src.dims()[i];
}
pdDLMTensor->tensor.dl_tensor.shape = shape;

// init stride
auto strides = new int64_t[ndim];
for (DimType i = 0; i < ndim; ++i) {
strides[i] = 1;
}
for (DimType i = ndim - 2; i >= 0; --i) {
strides[i] = shape[i + 1] * strides[i + 1];
}
pdDLMTensor->tensor.dl_tensor.strides = strides;

// init device, DLDevice type with device_type and device_id
auto place = src.place();
pdDLMTensor->tensor.dl_tensor.device =
paddle::platform::VisitPlace(place, internal::DLDeviceVisitor());

pdDLMTensor->tensor.dl_tensor.dtype = internal::GetDLDataTypeFromTypeIndex(
framework::TransToProtoVarType(src.dtype()));

pdDLMTensor->tensor.dl_tensor.byte_offset = 0;
return &(pdDLMTensor->tensor);
}

DLPackTensor::DLPackTensor(const phi::DenseTensor &tensor, LaneType lanes) {
// init data, data buffer
t_.data = const_cast<void *>(tensor.data());
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/dlpack_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,7 @@ class DLPackTensor {
ShapeType shape_[DDim::kMaxRank];
};

DLManagedTensor* toDLPack(const phi::DenseTensor& src);

} // namespace framework
} // namespace paddle
21 changes: 7 additions & 14 deletions paddle/fluid/pybind/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,22 +473,15 @@ void BindTensor(pybind11::module &m) { // NOLINT
)DOC")
.def("_to_dlpack",
[](phi::DenseTensor &self) {
DLPackTensor dlpack_tensor(self, 1);
DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
auto capsule = py::capsule(
DLManagedTensor *dmt = framework::toDLPack(self);
auto capsule = pybind11::capsule(
static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
if (ptr) {
auto dltensor = new DLManagedTensor;
try {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "used_dltensor"));
return;
} catch (...) {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
}
dltensor->deleter(dltensor);
if (!PyCapsule_IsValid(ptr, "dltensor")) {
return;
}
DLManagedTensor *dmt = static_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
dmt->deleter(dmt);
});
return capsule;
})
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/tests/test_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ def test_dlpack_deletion(self):
dlpack = paddle.utils.dlpack.to_dlpack(a)
b = paddle.utils.dlpack.from_dlpack(dlpack)

def test_to_dlpack_for_loop(self):
# See Paddle issue 50120
for i in range(10):
x = paddle.rand([3, 5])
dlpack = paddle.utils.dlpack.to_dlpack(x)


class TestRaiseError(unittest.TestCase):
def test_from_dlpack_raise_type_error(self):
Expand Down

0 comments on commit 35ce2bd

Please sign in to comment.