RFC: DLpack support for interoperability with other GPU frameworks #180

Merged Apr 15, 2020 (20 commits).

File: rfcs/20191016-dlpack-support.md (123 additions, 0 deletions)
# dlpack support for interoperability with other GPU frameworks

| Status        | Accepted |
| :------------ | :------- |
| **RFC #**     | [180](https://github.com/tensorflow/community/pull/180) |
| **Author(s)** | eoldridge@nvidia.com, wmjlyjemaine@gmail.com, zhoujinjing09@gmail.com |
| **Sponsor**   | apassos@google.com, sanjoy@google.com |
| **Updated**   | 2019-11-26 |

## Objective

This document proposes the adoption of DLPack (https://github.com/dmlc/dlpack) as a way of passing tensor data to other frameworks without leaving the GPU and without making a copy, per [24453](https://github.com/tensorflow/tensorflow/issues/24453). DLPack is a community effort to define a common tensor data structure that can be shared by different frameworks. It is currently supported by cuPy, cuDF, DGL, TGL, PyTorch, and MxNet.

DLPack interoperability would allow fast on-GPU communication between TensorFlow and these frameworks, opening up the wide range of use cases outlined below. It would also enable \_\_cuda_array_interface\_\_ interoperability indirectly through cuPy/cuDF, which support both methods, providing a way to transfer data to Numba, PyArrow, and other frameworks that have adopted that interface. [A similar request has been made to support \_\_cuda_array_interface\_\_ directly](https://github.com/tensorflow/tensorflow/issues/29039), and ideally both would be supported.

A solution has already been developed by @VoVAllen and @jermainewang (coauthors of this RFC) as an external Python package. This RFC proposes integrating the concepts from that package into TensorFlow core, reviewed and enhanced by the TF team, so that DLPack support is native.

## Motivation

DLPack is a community effort to define a common tensor data structure that can be shared by different frameworks, allowing data to be exchanged quickly, often with zero or minimal copying. One of the main bottlenecks in achieving GPU performance across different frameworks is I/O and data formatting. Transferring data between GPU and CPU, or between formats, is costly to the point where many operations become faster to run on the CPU simply because of the cost of moving or transforming the data. Even where mechanisms exist to copy data without leaving the GPU, memory constraints limit their applicability because two copies of the data are required. By implementing DLPack within TensorFlow, data could be transferred directly between frameworks, enabling a range of applications that weren't previously possible.

Existing applications that take advantage of DLPack include:
- Inline on-GPU preprocessing of tabular data using cuDF to prepare it for deep learning models (continuous normalization, categorical encoding, etc.), improving preprocessing performance by 10x over pandas and CPU
- A dataloader for datasets larger than CPU memory that iterates over parquet files and batch-loads tensors, providing a significant speedup over traditional dataloaders for tabular data
- [End-to-end acceleration of training on GPU](https://medium.com/rapids-ai/accelerating-deep-learning-recommender-systems-by-15x-using-rapids-fastai-and-pytorch-b50b4d8568d1)
- Use of TensorFlow in conjunction with [tvm](https://github.com/dmlc/tvm); [TF custom op implementation of TVM](https://github.com/tobegit3hub/tftvm)
- Use of TensorFlow in conjunction with [dgl](https://github.com/dmlc/dgl)
- Zero-copy transfer of data in [DALI](https://github.com/NVIDIA/DALI), reducing memory requirements
- [thinc.ai](https://thinc.ai/docs/usage-frameworks) framework interoperability

Beyond these specific applications, TensorFlow's adoption of DLPack would further incentivize other frameworks considering adoption, since all three major DL frameworks would then support it. Finally, it would make it easier to develop applications that operate upstream and downstream of deep learning frameworks, since a single framework-agnostic method could be used with all of them.

## User Benefit

Users who wish to use other GPU-accelerated frameworks such as cuDF and cuPy would be able to do so without expensive copy operations. By doing dataloading, feature engineering, and preprocessing directly on the GPU, we see 10-15x speedups over traditional CPU-based workflows for preparing data, and the results would be immediately available to TensorFlow.

More generally, users would be able to develop preprocessing or other GPU-based functionality once and integrate it with all DL frameworks, simplifying development of solutions that sit upstream or downstream of deep learning models.

A blog post or release-notes headline could read: "TensorFlow now supports DLPack, enabling interoperability with other GPU-powered frameworks like cuPy, cuDF, DGL, TGL, PyTorch, and MxNet."

## Design Proposal

A working version of DLPack integration has been released as a package by coauthors @jermainewang and @VoVAllen:
https://github.com/VoVAllen/tf-dlpack/issues/3

This proposal would leverage that solution and integrate it into TF so that the operations could be performed natively.

### User experience

We plan to release a Python package `tfdlpack` containing two APIs:
```
to_dlpack: Given a TensorFlow tensor, return a DLPack tensor wrapped in a Python capsule.
from_dlpack: Given a DLPack-compatible Python capsule, return a TensorFlow tensor.
```

Example of converting a TensorFlow tensor to a Torch tensor via DLPack using the package:
```python
import numpy as np
import tensorflow as tf
import torch.utils.dlpack as thdlpack
import tfdlpack

t1 = tf.constant([1, 2, 3], dtype=np.float32)
dlpack = tfdlpack.to_dlpack(t1) # tf tensor -> dlpack
t2 = thdlpack.from_dlpack(dlpack) # dlpack -> th tensor
print(t2)
dlpack = thdlpack.to_dlpack(t2) # th tensor -> dlpack
t3 = tfdlpack.from_dlpack(dlpack) # dlpack -> tf tensor
print(t3)
```
You will find that t1, t2, and t3 all have the same values, shape, and device context.
Package dependency: tensorflow>=2.0

Proposed code for converting a TensorFlow tensor to a Torch tensor via DLPack natively:
```python
import numpy as np
import tensorflow as tf
import tensorflow.experimental.dlpack as tfdlpack
import torch.utils.dlpack as thdlpack


t1 = tf.constant([1, 2, 3], dtype=np.float32)
dlpack = tfdlpack.to_dlpack(t1) # tf tensor -> dlpack
t2 = thdlpack.from_dlpack(dlpack) # dlpack -> th tensor
print(t2)
dlpack = thdlpack.to_dlpack(t2) # th tensor -> dlpack
t3 = tfdlpack.from_dlpack(dlpack) # dlpack -> tf tensor
print(t3)
```

Potential technical problems for this API:

1. Memory usability on async devices (`to_dlpack`)

   As mentioned by @alextp:
   > TF does not use cudamalloc to allocate memory but its own allocator whose internal state is stored on the CPU and matches the head of TF's compute stream, so we need to sync TF's stream before the memory is usable from dlpack and similarly sync other cuda streams before memory is made usable by TF tensors (and similarly we need to sync the streams when trying to free the buffers).

   We therefore manually sync the device when exporting a TF tensor to DLPack. The sync is performed in the `TFE_TensorHandleDevicePointer` API, which returns a pointer to the underlying memory.

2. Memory management (avoiding leaks) (`to_dlpack`/`from_dlpack`)

   By the design of DLPack, the framework that constructs a tensor from a DLPack capsule is responsible for calling the capsule's deleter, which usually dereferences the underlying buffer, when the constructed tensor is destroyed.
   For `from_dlpack`, a deleter function is registered when constructing the TF tensor and is called upon its destruction.
   For `to_dlpack`, the DLPack structure holds a reference (via `TensorReference`) to the underlying buffer and unrefs it in the DLPack deleter function.

Proposed API implementation details:
- to_dlpack
  - Implement `TFE_HandleToDLPack`, which converts TF's eager tensor handle to a DLPack tensor pointer (`DLManagedTensor*`) and wraps it into a PyCapsule to match the Python interface in the FFI binding file. For the liveness of the underlying memory, a `TensorReference` maintains the reference count on the underlying `TensorBuffer`: it is increased when the DLPack tensor is created and decreased in the DLPack tensor's deleter (see the sketch below).
- from_dlpack
  - Implement `TFE_HandleFromDLPack`, which converts a DLPack tensor pointer (`DLManagedTensor*`) to TF's eager tensor handle. `TFE_TensorHandleDevicePointer` is used to get the data pointer of the underlying buffer and to synchronize the relevant device, ensuring the memory is ready.
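
To make the reference-counting scheme above concrete, here is a rough C++ sketch of the `to_dlpack` direction. It is a sketch only: `TfDlpackContext`, `DlpackDeleter`, and `ToDlpack` are illustrative names rather than the proposed API, and the dtype/device conversion is omitted.

```cpp
// Rough sketch only: helper/struct names are illustrative, and dtype/device
// conversion is omitted.
#include <cstdint>
#include <vector>

#include "dlpack/dlpack.h"                               // DLManagedTensor, DLTensor
#include "tensorflow/core/framework/tensor.h"            // tensorflow::Tensor
#include "tensorflow/core/framework/tensor_reference.h"  // tensorflow::TensorReference

// Keeps the TF buffer alive for as long as the exported DLPack tensor lives.
struct TfDlpackContext {
  tensorflow::TensorReference ref;  // ref-counts the underlying TensorBuffer
  std::vector<int64_t> shape;       // DLPack needs a stable shape array
  DLManagedTensor dl_managed;
};

// Deleter the consuming framework calls when it is done with the tensor.
void DlpackDeleter(DLManagedTensor* arg) {
  auto* ctx = static_cast<TfDlpackContext*>(arg->manager_ctx);
  ctx->ref.Unref();  // drop the reference taken at export time
  delete ctx;
}

// Export: take a reference on the TF tensor, fill in the DLTensor metadata,
// and hand out a DLManagedTensor whose deleter releases that reference.
// `device_ptr` would come from TFE_TensorHandleDevicePointer, which also syncs
// the producing stream before returning.
DLManagedTensor* ToDlpack(const tensorflow::Tensor& t, void* device_ptr) {
  auto* ctx = new TfDlpackContext{tensorflow::TensorReference(t)};
  for (int i = 0; i < t.dims(); ++i) ctx->shape.push_back(t.dim_size(i));

  DLTensor& dl = ctx->dl_managed.dl_tensor;
  dl.data = device_ptr;
  dl.ndim = t.dims();
  dl.shape = ctx->shape.data();
  dl.strides = nullptr;  // compact row-major layout
  dl.byte_offset = 0;
  // dtype and device/context fields are filled from the TF dtype and device
  // in the real implementation; omitted here.

  ctx->dl_managed.manager_ctx = ctx;
  ctx->dl_managed.deleter = &DlpackDeleter;
  return &ctx->dl_managed;
}
```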


## Questions and Discussion Topics

https://github.com/tensorflow/tensorflow/issues/29039#issuecomment-527520270 outlines the key issues that need to be addressed, namely that a sync is required to ensure the tensor data is valid. Supporting [\_\_cuda_array_interface\_\_](https://github.com/tensorflow/tensorflow/issues/29039) is another option; cuPy and cuDF have opted to support both, and ideally TensorFlow would as well.

## Reference

### tfdlpack package implementation detail

The first design consideration is that we want to avoid any modification to the main TensorFlow library, to get around the potentially long delay of the PR, code review, and release cycle of the main TensorFlow package. Inspired by the solution in https://github.com/tobegit3hub/tftvm, we decided to implement the functionality as two custom tensor ops: to_dlpack and from_dlpack.

We also want this feature to be easy to plug into other projects. For example, any project that relies on this feature should be able to run without compiling against TensorFlow's header files: an extra dependency usually means extra effort, and such maintenance is repetitive and should be handled by the feature developers (i.e., us) alone. To this end, we release it as a Python package. The question, then, is how to invoke the two custom tensor ops from Python. The challenge is that TensorFlow's custom op interface has limited support for argument and return types, while to_dlpack and from_dlpack need an argument/return type of a DLPack object. We work around this by encoding the address of the DLPack object as an integer, so it can be accepted/returned by the custom op interface. We then decode it in Python or C depending on whether we return it (to_dlpack) or consume it (from_dlpack).
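
As an illustration of that workaround (not the package's actual code), the custom op below returns the `DLManagedTensor*` address as an `int64` scalar that Python-side code can wrap into a PyCapsule. `ToDlpackAddress`, `ToDlpackAddressOp`, and `MakeDlpackTensor` are hypothetical names introduced for this sketch.

```cpp
// Illustrative only: op, kernel, and helper names are hypothetical, and only a
// single dtype is handled.
#include "dlpack/dlpack.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;  // keeps the sketch short

// Stand-in for the TF-tensor -> DLManagedTensor conversion sketched earlier.
DLManagedTensor* MakeDlpackTensor(const Tensor& t);

// The custom-op interface cannot return an opaque pointer, so the op returns
// the DLManagedTensor* address encoded as an int64 scalar; Python-side code
// then wraps that address into a PyCapsule.
REGISTER_OP("ToDlpackAddress")
    .Input("input: float")
    .Output("address: int64");

class ToDlpackAddressOp : public OpKernel {
 public:
  explicit ToDlpackAddressOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    DLManagedTensor* dlm = MakeDlpackTensor(input);

    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    // Encode the pointer as an integer so it fits the op's return type.
    out->scalar<int64>()() = reinterpret_cast<int64>(dlm);
  }
};

// The address output lives in host memory even for GPU-resident inputs.
REGISTER_KERNEL_BUILDER(
    Name("ToDlpackAddress").Device(DEVICE_GPU).HostMemory("address"),
    ToDlpackAddressOp);
```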
**Contributor commented:**
If this lives within TF, I'd like to see an assessment of whether we can use DT_VARIANT to hold instances of DLManagedTensor.

**@VoVAllen commented (Dec 18, 2019):**
On the Python side, the dlpack object is expected to be represented as a PyCapsule holding the address of a DLManagedTensor. The framework casts the address back to a DLManagedTensor pointer for further use. Therefore, if I understand DT_VARIANT correctly, it may not help in this case.

@sanjoy Could you provide an example usage of DT_VARIANT? We'd like to investigate the proposal further.

**Contributor replied:**
Did you look at the Variant class I linked above? You should be able to look at its uses to find some examples.


Finally, to achieve maximal efficiency, we want the conversion to happen without a memory copy.

For to_dlpack, the returned DLPack tensor shares the memory address of the input TensorFlow tensor and holds a reference to it. Upon destruction of the DLPack tensor, it dereferences the TensorFlow tensor so that it can be collected by TensorFlow's memory management (inspired by PyTorch's DLPack implementation).
For from_dlpack, it first creates an allocator object (subclassing TensorFlow's allocator interface) that holds a reference to the DLPack tensor. The AllocateRaw function directly returns the memory it holds without creating any new buffer. Upon destruction, the DeallocateRaw function simply calls the deleter of the DLPack tensor (inspired by TensorFlow's immutable_constant_op).
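
A hedged sketch of what such an allocator might look like, assuming a `DLManagedTensor*` received from the producing framework; the class name is made up for this sketch, and the lifetime of the allocator object itself is not shown.

```cpp
// Illustrative only: class name is made up, and the lifetime of the allocator
// object itself is not shown.
#include "dlpack/dlpack.h"
#include "tensorflow/core/framework/allocator.h"

// An "allocator" that never allocates: AllocateRaw hands back the memory the
// DLPack tensor already owns, and DeallocateRaw invokes the DLPack deleter so
// the producing framework can release (typically unref) its buffer.
class DlpackAllocator : public tensorflow::Allocator {
 public:
  explicit DlpackAllocator(DLManagedTensor* dlm) : dlm_(dlm) {}

  std::string Name() override { return "dlpack_allocator"; }

  void* AllocateRaw(size_t /*alignment*/, size_t /*num_bytes*/) override {
    // Return the existing DLPack buffer instead of allocating a new one.
    return static_cast<char*>(dlm_->dl_tensor.data) + dlm_->dl_tensor.byte_offset;
  }

  void DeallocateRaw(void* /*ptr*/) override {
    // "Freeing" means telling the producer we are done with the buffer.
    if (dlm_ != nullptr && dlm_->deleter != nullptr) dlm_->deleter(dlm_);
    dlm_ = nullptr;
  }

 private:
  DLManagedTensor* dlm_;
};
```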
**Contributor commented:**
I believe the deallocate call will have to do a host-device sync as well, since the dlpack tensor could have users enqueued in arbitrary streams, and freeing it without waiting for those kernels to finish will cause data races.

**Contributor replied:**
Basically, DeallocateRaw won't delete the tensor; it just dereferences the buffer. The data-race / premature-free issue is handled by the original framework that produced the tensor.

Detail: https://github.com/VoVAllen/tf-dlpack/blob/master/src/to_dlpack_kernel.cc#L66

**Contributor replied:**
That can work, but we need to be clear on the contract that TF will unref the dlpack tensor as soon as all uses have been enqueued, and won't wait for the kernels to actually finish. As long as this is part of dlpack's contract, all is good.

**@byronyi commented (Dec 23, 2019):**
Echo the concern here. Take the following example:

Imagine another framework, say X, that mirrors the current TF design.

  1. TF_ToDLPackOp increments the TF TensorBuffer ref count, and the tensors produced by the upstream TF_Op are ready for use in stream context A;
  2. X_FromDLPackOp executes in stream context B;
  3. A downstream op in X consumes this DLPack tensor in stream context B, and you call TF_DeallocateRaw, which immediately decrements the ref count;
  4. TF reuses the TensorBuffer and starts writing new data onto it in stream context A.

How do you plan to sync stream A in TF and stream B in X?

**Contributor replied:**
> Echo the concern here. [...] How do you plan to sync stream A in TF and stream B in X?

@byronyi Minjie's comment (#180 (comment)) partially addressed your concern. You are right about the situation: cross-stream ordering is not guaranteed by dlpack so far. One solution is to sync the stream that produces the tensor, ensuring the memory is ready before the downstream framework executes; you can see this in mxnet.
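
A minimal sketch of that "sync the producing stream before handing off" approach, assuming the exporting framework can identify the CUDA stream its tensor was written on; the function name is illustrative.

```cpp
// Minimal sketch: assumes the producer can hand us the CUDA stream its tensor
// was written on; function name is illustrative.
#include <cuda_runtime.h>

#include "dlpack/dlpack.h"

// Block until every kernel that writes the tensor on the producer's stream has
// finished, so a consumer may safely read the buffer from its own streams.
DLManagedTensor* ExportWithSync(DLManagedTensor* dlm, cudaStream_t producer_stream) {
  if (cudaStreamSynchronize(producer_stream) != cudaSuccess) {
    return nullptr;  // a real implementation would surface a framework error
  }
  return dlm;
}
```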

**Contributor replied:**
If you're in TensorFlow, you don't need to go through this indirection. You could subclass tensorflow::TensorBuffer to implement the behavior you want.

**Contributor replied:**
The main reason for using an allocator was the alignment issue. However, I later found that the alignment requirement is always 64 bytes. Will this change in the future?
Also, I agree TensorBuffer is simpler; do you think it's fine to always assume a 64-byte alignment requirement in the dlpack operations?

**Contributor replied:**
I don't see how alignment is related to using TensorBuffer or not. IIUC, in from_dlpack you roughly do this:

    if (IsAligned(64, dlm_tensor->dl_tensor.data)) {
      // forward dlpack data without memcpy
    } else {
      // allocate output buffer and memcpy
    }

This seems correct to me, and we'll continue doing this when we use a TensorBuffer subclass.

**Contributor replied:**
You are right. I realized the alignment is always 64 bytes after I finished my first version, which used an allocator.