
Add support for Apache Arrow #5667

Closed
wants to merge 20 commits into master from add-arrow-support

Conversation

zhangzhang10
Contributor

zhangzhang10 commented May 14, 2020

This feature allows creating a DMatrix from the Arrow memory format. It works
with all file formats that PyArrow supports, and it provides better
performance than, for example, XGBoost's CSV interface or Pandas'
read_parquet interface.

Performance of DMatrix creation:

CSV input file:

|                            | higgs    | nyc-taxi  | mortgage (small) | mortgage (large) |
|----------------------------|----------|-----------|------------------|------------------|
| num_rows                   | 550,000  | 1,458,644 | 36,190,382       | 151,675,455      |
| num_cols                   | 30       | 15        | 46               | 46               |
| xgboost csv data interface | 1.29 sec | 1.37 sec  | 100.1 sec        | 382.9 sec        |
| PyArrow                    | 0.09 sec | 0.08 sec  | 12.8 sec         | 62.2 sec         |

Parquet input file:

|          | mortgage (small) | mortgage (large) |
|----------|------------------|------------------|
| num_rows | 36,190,382       | 151,675,455      |
| num_cols | 46               | 46               |
| Pandas   | 44.5 sec         | 185.5 sec        |
| PyArrow  | 5.2 sec          | 10.6 sec         |

Changes:

  • Python: Create DMatrix from pyarrow.Table
  • C: Add XGDMatrixCreateFromArrowTable()
  • C++: Add class ArrowAdapter and class ArrowAdapterBatch
  • CMake: Add FindArrow.cmake, FindArrowPython.cmake

Build:

  • Set the env variable ARROW_HOME to the path where the Arrow libs, PyArrow
    libs, and Arrow C++ headers are found. Alternatively, install the arrow-python, pyarrow, and gxx_linux-64 packages in a Conda environment.
  • Configure the build by passing -DUSE_ARROW=ON to the cmake command.

Usage:

```python
import xgboost as xgb
import pyarrow.csv as pc
import pyarrow.parquet as pq

# CSV input
table = pc.read_csv('/path/to/csv/input')
dmat = xgb.DMatrix(table)

# Parquet input
table = pq.read_table('/path/to/parquet/input')
dmat = xgb.DMatrix(table)
```

@trivialfis
Member

We have built-in support for cuDF, which uses Arrow. Could you use a similar code path? I'm happy to answer questions.

@zhangzhang10
Contributor Author

> We have built-in support for cuDF, which uses Arrow. Could you use a similar code path? I'm happy to answer questions.

@trivialfis The support for cudf only uses the same physical memory layout defined by Arrow. It doesn't directly work with Arrow tables or other Arrow constructs. Did I miss anything? If an in-memory Arrow table is already available, then its content can just be copied into a DMatrix. It shouldn't be necessary to use a sequence of JSON objects (__array_interface__) to transfer the data. I'm not sure how the cudf code path can be used in this case. Thanks.

@trivialfis
Member

trivialfis commented May 18, 2020

@zhangzhang10 We first export a handle to the columnar memory block (in cudf's case, we chose it to be a list of __cuda_array_interface__) in Python. These interfaces specify all attributes of the columnar format, like strides, masks, types, etc. Then we parse the memory directly according to the __array_interface__. This way we don't need a hard dependency in C++.

Also, once an adapter is defined, there's no need for a new Push method, you can simply pass that adapter into the constructor of SimpleDMatrix and the rest of DMatrix construction should happen automatically.
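For illustration, a minimal sketch of that flow (the class names follow this PR and the SimpleDMatrix template instantiation quoted later in this thread, but the exact signatures here are assumptions, not the final API):

```cpp
// Hedged sketch: with an ArrowAdapter defined, DMatrix construction can
// reuse the generic adapter path -- no Arrow-specific Push() method needed.
#include <limits>
#include <memory>

std::shared_ptr<arrow::Table> table = /* unwrapped from the Python object */;
data::ArrowAdapter adapter(table, nrow, ncol);  // adapter class from this PR
auto dmat = std::make_shared<data::SimpleDMatrix>(
    &adapter,
    /*missing=*/std::numeric_limits<float>::quiet_NaN(),
    /*nthread=*/0);
```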

```python
def _maybe_chunked_array(data, feature_names, feature_types):
    if PYARROW_INSTALLED and isinstance(data, ArrowChunkedArray):
        if not PANDAS_INSTALLED:
            raise Exception(('pandas must be available to use this method.'
```
Contributor

I suppose it would be better to say: 'Pandas must be available to use Apache Arrow as an input format'.

Contributor Author

Thanks. Will incorporate your input.

```cpp
std::shared_ptr<arrow::Table> table;
CHECK(arrow::py::unwrap_table(data, &table).ok());
data::ArrowAdapter adapter(table, nrow, ncol);
*out = new std::shared_ptr<DMatrix>(
```
Contributor

Where is the matching delete for this new?

Contributor Author

It is in XGDMatrixFree(), which is supposed to be called by client code.
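To make the ownership contract concrete, a hedged sketch of the client side (XGDMatrixFree is xgboost's existing C API; the Arrow create call's parameters are assumed from the snippet above):

```cpp
// Illustrative only: the create call allocates a std::shared_ptr<DMatrix>
// with `new`, and the client releases it with XGDMatrixFree, which runs
// the matching delete.
DMatrixHandle handle = nullptr;
if (XGDMatrixCreateFromArrowTable(data, nrow, ncol, &handle) == 0) {
  // ... train or predict with the handle ...
  XGDMatrixFree(handle);
}
```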

Contributor

Is it aligned with how the other DMatrix types are handled?

Comment on lines +673 to +889
```cpp
// Set number of threads but keep old value so we can reset it after
const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread = nthreadmax;
int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread);
```
Contributor

Why do you need a different number of threads for this piece of code than for the rest of the library?

Contributor

Also, we can set the number of threads for just this region with the following option:

```cpp
#pragma omp parallel for schedule(static) num_threads(nthread)
...
```

It looks simpler.
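Spelled out (a generic sketch with a hypothetical PushRows function, not code from this PR), the suggestion scopes the thread count to a single parallel region instead of mutating the global OpenMP setting:

```cpp
#include <cstdint>

void PushRows(std::int64_t nrows, int nthread) {
  // num_threads() applies to this region only, so there is no need to save
  // and restore the global value with omp_set_num_threads().
  #pragma omp parallel for schedule(static) num_threads(nthread)
  for (std::int64_t i = 0; i < nrows; ++i) {
    // push row i of the batch ... (placeholder)
  }
}
```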

Contributor Author

Here I'm just following the convention used by the SparsePage::Push template function, defined above this particular specialization.

Comment on lines +200 to +230
```cpp
#if defined(XGBOOST_BUILD_ARROW_SUPPORT)
template SimpleDMatrix::SimpleDMatrix(ArrowAdapter* adapter, float missing,
                                      int nthread);
#endif
```
Contributor

Overall, it would be good to add tests for the Apache Arrow support on the C++/Python side.

Contributor Author

Will add tests.

@codecov-commenter

codecov-commenter commented Jun 9, 2020

Codecov Report

Merging #5667 into master will increase coverage by 0.00%.
The diff coverage is 80.00%.


```diff
@@           Coverage Diff           @@
##           master    #5667   +/-   ##
=======================================
  Coverage   78.49%   78.50%           
=======================================
  Files          12       12           
  Lines        3013     3028   +15     
=======================================
+ Hits         2365     2377   +12     
- Misses        648      651    +3     
```
| Impacted Files | Coverage Δ |
|---|---|
| python-package/xgboost/compat.py | 54.61% <50.00%> (-0.23%) ⬇️ |
| python-package/xgboost/data.py | 59.17% <100.00%> (+0.86%) ⬆️ |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@zhangzhang10
Contributor Author

@hcho3, @CodingCat, @trivialfis, @SmirnovEgorRu, could you please review the recent changes to PR 5667? This PR introduces support for Apache Arrow; specifically, it allows creating a DMatrix directly from Arrow tables. To move forward, I am about to introduce dependencies on the Arrow libs and PyArrow. For example, I am thinking about adding a new dockerfile in ci_build that installs these dependencies. Is this the right approach? I'd appreciate your feedback. Thanks!

@hcho3
Collaborator

hcho3 commented Jun 18, 2020

@zhangzhang10 If I build XGBoost with Arrow support, can I use the built binary on another machine without Arrow installed? If so, we can enable Arrow support in the default builds and distribute it via Pip. As a maintainer of CI, I'd prefer not to add more special cases if we can avoid it.

Currently, we build CUDA support into the XGBoost binary already, and the binary can be used on machines without NVIDIA GPUs. (Requesting a GPU algorithm would produce a run-time error on these machines.)

@hcho3
Collaborator

hcho3 commented Jun 18, 2020

From #5754 (comment):

Two conditions should be met before any integration code with third-party tools gets hosted in the main XGBoost repository:

  • Integration brings benefits to a substantial portion of XGBoost users.
  • One or more contributors are available to maintain the integration code in the long term.

Integration with Arrow satisfies the first criterion: many users stand to benefit from fast ingestion of tabular data. It would be even better if XGBoost can support it out of the box: pip install xgboost and get Arrow support right away.

On the other hand, the Arrow support did introduce lots of new C++ code, and I'm concerned about maintenance burden. @zhangzhang10 Would you be willing to take care of this code in the medium and long term? Or find people who would be? Mainly, I want to avoid a situation in the future where the Arrow ingestion feature breaks and no one is around to diagnose or fix the bug.

@SmirnovEgorRu
Contributor

> Would you be willing to take care of this code in the medium and long term? Or find people who would be?

@hcho3, since it's part of Intel's efforts to improve XGBoost, I can be responsible for this.

@hcho3
Collaborator

hcho3 commented Jun 19, 2020

@SmirnovEgorRu Ah, that's good news. Thanks for letting us know.

@zhangzhang10
Contributor Author

> Integration with Arrow satisfies the first criterion: many users stand to benefit from fast ingestion of tabular data. It would be even better if XGBoost can support it out of the box: pip install xgboost and get Arrow support right away.

@hcho3 Thanks for the comments. I will take a look at how CUDA support is being integrated in XGBoost binary, and then see if we can do it similarly for Arrow support.

@trivialfis
Member

After looking into Arrow a little bit more, I think the array interface can be obtained by:

```python
import pyarrow as pa
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3]})
table: pa.lib.Table = pa.Table.from_pandas(df)
interface = table.column('a').chunks[0].__array__().__array_interface__
print(interface)
# {'data': (21839296, True), 'strides': None, 'descr': [('', '<i8')], 'typestr': '<i8', 'shape': (3,), 'version': 3}
```

Having this, we can parse the memory buffer directly. Just not sure whether calling __array__() generates a data copy.

@zhangzhang10
Contributor Author

> After looking into Arrow a little bit more, I think the array interface can be obtained by:
>
> ```python
> import pyarrow as pa
> import pandas as pd
>
> df = pd.DataFrame({"a": [1, 2, 3]})
> table: pa.lib.Table = pa.Table.from_pandas(df)
> interface = table.column('a').chunks[0].__array__().__array_interface__
> print(interface)
> # {'data': (21839296, True), 'strides': None, 'descr': [('', '<i8')], 'typestr': '<i8', 'shape': (3,), 'version': 3}
> ```
>
> Having this, we can parse the memory buffer directly. Just not sure whether calling __array__() generates a data copy.

A column of an Arrow table is of type pyarrow.lib.ChunkedArray; it may consist of multiple chunks (buffers). We could do

```python
table.column('a').__array__().__array_interface__
```

But this would likely lead to data copying, because multiple buffers need to be combined.

Or, we could iterate over the chunks and call __array__() on each of them:

```python
table.column('a').chunks[0].__array__().__array_interface__
table.column('a').chunks[1].__array__().__array_interface__
table.column('a').chunks[2].__array__().__array_interface__
...
```

To make it more complicated, columns in a table can have different numbers of chunks. So we would end up parsing many buffers and keeping track of their relative positions ourselves, whereas the Arrow lib already provides this facility at no cost (zero copy).

@trivialfis
Member

Hi, we are having some headaches supporting more and more data types as input (currently 15 of them!), along with the corresponding support for metadata like labels (which is handled differently from the data X; for example, you can't push a CSR matrix as labels). I would like to create something really simple for data dispatching and make it work uniformly. So is it okay if I push to your branch in the coming days?

@zhangzhang10
Contributor Author

> Hi, we are having some headaches supporting more and more data types as input (currently 15 of them!), along with the corresponding support for metadata like labels (which is handled differently from the data X; for example, you can't push a CSR matrix as labels). I would like to create something really simple for data dispatching and make it work uniformly. So is it okay if I push to your branch in the coming days?

@trivialfis Yes, please feel free to push to my branch. Thank you so much for your help! I really appreciate it.

@trivialfis
Member

Let me try working on this during the weekend.

@trivialfis
Member

The specification has grown a lot since the last time I looked into it ...

This feature allows creating DMatrix from the Arrow memory format. It works
with all file formats that PyArrow supports, and it provides better
performance than, for example, XGBoost's CSV interface or Pandas'
read_parquet interface.
- Python: Create DMatrix from pyarrow.Table
- C: Add XGDMatrixCreateFromArrowTable()
- C++: Add class ArrowAdapter and class ArrowAdapterBatch
- CMake: Add FindArrow.cmake

Build:
- Set env variable ARROW_ROOT to the path where Arrow libs, PyArrow
libs, and Arrow C++ headers are found.
- Configure the build by passing '-DUSE_ARROW=ON' to cmake command.

Usage:
	import xgboost as xgb
	import pyarrow.csv as pc
	import pyarrow.parquet as pq

	# CSV input
	table = pc.read_csv('/path/to/csv/input')
	dmat = xgb.DMatrix(table)

	# Parquet input
	table = pq.read_table('/path/to/parquet/input')
	dmat = xgb.DMatrix(table)
zhangzhang10 force-pushed the add-arrow-support branch 3 times, most recently from d960152 to f2208da on August 4, 2020 04:18
- Modify run_test.sh cmake command
- Modify setup.sh to include pyarrow 0.17
@trivialfis
Member

trivialfis commented Aug 5, 2020

Note to myself, current compilation is failing on this branch:

In file included from /home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include/arrow/python/platform.h:24,
                 from /home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include/arrow/python/pyarrow.h:20,
                 from /home/fis/Workspace/XGBoost/xgboost/include/xgboost/c_api.h:21,
                 from /home/fis/Workspace/XGBoost/xgboost/src/c_api/c_api_error.cc:7:
/usr/include/python3.7m/datetime.h:204:25: warning: ‘PyDateTimeAPI’ defined but not used [-Wunused-variable]
 static PyDateTime_CAPI *PyDateTimeAPI = NULL;
                         ^~~~~~~~~~~~~
[12/233] Building CXX object src/CMakeFiles/objxgboost.dir/data/simple_dmatrix.cc.o
In file included from /home/fis/Workspace/XGBoost/xgboost/src/data/simple_dmatrix.cc:18:
/home/fis/Workspace/XGBoost/xgboost/src/data/adapter.h: In constructor ‘xgboost::data::ArrowAdapterBatch::ArrowAdapterBatch(const RecordBatches&, const TableColumn&, xgboost::bst_row_t, xgboost::bst_feature_t)’:
/home/fis/Workspace/XGBoost/xgboost/src/data/adapter.h:339:26: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<std::shared_ptr<arrow::Array> >::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
       for (auto i = 0; i < arrs.size(); ++i) {
                        ~~^~~~~~~~~~~~~
/home/fis/Workspace/XGBoost/xgboost/src/data/adapter.h: In instantiation of ‘static std::shared_ptr<arrow::NumericArray<T> > xgboost::data::ArrowAdapterBatch::CastArray(const std::shared_ptr<arrow::Array>&, typename T::c_type, bool) [with T = arrow::FloatType; typename T::c_type = float]’:
/home/fis/Workspace/XGBoost/xgboost/src/data/adapter.h:340:65:   required from here
/home/fis/Workspace/XGBoost/xgboost/src/data/adapter.h:392:28: warning: comparison of integer expressions of different signedness: ‘size_t’ {aka ‘long unsigned int’} and ‘int64_t’ {aka ‘long int’} [-Wsign-compare]
       for (size_t j = 0; j < arr->length(); ++j) {
                          ~~^~~~~~~~~~~~~~~
In file included from /home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include/arrow/python/platform.h:24,
                 from /home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include/arrow/python/pyarrow.h:20,
                 from /home/fis/Workspace/XGBoost/xgboost/include/xgboost/c_api.h:21,
                 from /home/fis/Workspace/XGBoost/xgboost/src/data/simple_dmatrix.cc:13:
/usr/include/python3.7m/datetime.h: At global scope:
/usr/include/python3.7m/datetime.h:204:25: warning: ‘PyDateTimeAPI’ defined but not used [-Wunused-variable]
 static PyDateTime_CAPI *PyDateTimeAPI = NULL;
                         ^~~~~~~~~~~~~
[14/233] Building CXX object src/CMakeFiles/objxgboost.dir/predictor/cpu_predictor.cc.o
In file included from /home/fis/Workspace/XGBoost/xgboost/src/predictor/cpu_predictor.cc:19:
/home/fis/Workspace/XGBoost/xgboost/src/predictor/../data/adapter.h: In constructor ‘xgboost::data::ArrowAdapterBatch::ArrowAdapterBatch(const RecordBatches&, const TableColumn&, xgboost::bst_row_t, xgboost::bst_feature_t)’:
/home/fis/Workspace/XGBoost/xgboost/src/predictor/../data/adapter.h:339:26: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<std::shared_ptr<arrow::Array> >::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
       for (auto i = 0; i < arrs.size(); ++i) {
                        ~~^~~~~~~~~~~~~
/home/fis/Workspace/XGBoost/xgboost/src/predictor/../data/adapter.h: In instantiation of ‘static std::shared_ptr<arrow::NumericArray<T> > xgboost::data::ArrowAdapterBatch::CastArray(const std::shared_ptr<arrow::Array>&, typename T::c_type, bool) [with T = arrow::FloatType; typename T::c_type = float]’:
/home/fis/Workspace/XGBoost/xgboost/src/predictor/../data/adapter.h:340:65:   required from here
/home/fis/Workspace/XGBoost/xgboost/src/predictor/../data/adapter.h:392:28: warning: comparison of integer expressions of different signedness: ‘size_t’ {aka ‘long unsigned int’} and ‘int64_t’ {aka ‘long int’} [-Wsign-compare]
       for (size_t j = 0; j < arr->length(); ++j) {
                          ~~^~~~~~~~~~~~~~~
[16/233] Building CXX object src/CMakeFiles/objxgboost.dir/data/sparse_page_dmatrix.cc.o
In file included from /home/fis/Workspace/XGBoost/xgboost/src/data/./sparse_page_source.h:25,
                 from /home/fis/Workspace/XGBoost/xgboost/src/data/./ellpack_page_source.h:14,
                 from /home/fis/Workspace/XGBoost/xgboost/src/data/./sparse_page_dmatrix.h:17,
                 from /home/fis/Workspace/XGBoost/xgboost/src/data/sparse_page_dmatrix.cc:11:
/home/fis/Workspace/XGBoost/xgboost/src/data/./adapter.h: In constructor ‘xgboost::data::ArrowAdapterBatch::ArrowAdapterBatch(const RecordBatches&, const TableColumn&, xgboost::bst_row_t, xgboost::bst_feature_t)’:
/home/fis/Workspace/XGBoost/xgboost/src/data/./adapter.h:339:26: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<std::shared_ptr<arrow::Array> >::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
       for (auto i = 0; i < arrs.size(); ++i) {
                        ~~^~~~~~~~~~~~~
/home/fis/Workspace/XGBoost/xgboost/src/data/./adapter.h: In instantiation of ‘static std::shared_ptr<arrow::NumericArray<T> > xgboost::data::ArrowAdapterBatch::CastArray(const std::shared_ptr<arrow::Array>&, typename T::c_type, bool) [with T = arrow::FloatType; typename T::c_type = float]’:
/home/fis/Workspace/XGBoost/xgboost/src/data/./adapter.h:340:65:   required from here
/home/fis/Workspace/XGBoost/xgboost/src/data/./adapter.h:392:28: warning: comparison of integer expressions of different signedness: ‘size_t’ {aka ‘long unsigned int’} and ‘int64_t’ {aka ‘long int’} [-Wsign-compare]
       for (size_t j = 0; j < arr->length(); ++j) {
                          ~~^~~~~~~~~~~~~~~
[21/233] Building CXX object src/CMakeFiles/objxgboost.dir/data/data.cc.o
In file included from /home/fis/Workspace/XGBoost/xgboost/src/data/data.cc:22:
/home/fis/Workspace/XGBoost/xgboost/src/data/../data/adapter.h: In constructor ‘xgboost::data::ArrowAdapterBatch::ArrowAdapterBatch(const RecordBatches&, const TableColumn&, xgboost::bst_row_t, xgboost::bst_feature_t)’:
/home/fis/Workspace/XGBoost/xgboost/src/data/../data/adapter.h:339:26: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<std::shared_ptr<arrow::Array> >::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
       for (auto i = 0; i < arrs.size(); ++i) {
                        ~~^~~~~~~~~~~~~
/home/fis/Workspace/XGBoost/xgboost/src/data/data.cc: In member function ‘uint64_t xgboost::SparsePage::Push(const xgboost::data::ArrowAdapterBatch&, float, int)’:
/home/fis/Workspace/XGBoost/xgboost/src/data/data.cc:916:24: warning: comparison of integer expressions of different signedness: ‘int’ and ‘__gnu_cxx::__alloc_traits<std::allocator<long unsigned int>, long unsigned int>::value_type’ {aka ‘long unsigned int’} [-Wsign-compare]
     for (auto t = 0; t < rb_lengths[i]; ++t) {
In file included from /home/fis/Workspace/XGBoost/xgboost/src/data/data.cc:22:
/home/fis/Workspace/XGBoost/xgboost/src/data/../data/adapter.h: In instantiation of ‘static std::shared_ptr<arrow::NumericArray<T> > xgboost::data::ArrowAdapterBatch::CastArray(const std::shared_ptr<arrow::Array>&, typename T::c_type, bool) [with T = arrow::FloatType; typename T::c_type = float]’:
/home/fis/Workspace/XGBoost/xgboost/src/data/../data/adapter.h:340:65:   required from here
/home/fis/Workspace/XGBoost/xgboost/src/data/../data/adapter.h:392:28: warning: comparison of integer expressions of different signedness: ‘size_t’ {aka ‘long unsigned int’} and ‘int64_t’ {aka ‘long int’} [-Wsign-compare]
       for (size_t j = 0; j < arr->length(); ++j) {
                          ~~^~~~~~~~~~~~~~~
In file included from /home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include/arrow/python/platform.h:24,
                 from /home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include/arrow/python/pyarrow.h:20,
                 from /home/fis/Workspace/XGBoost/xgboost/include/xgboost/c_api.h:21,
                 from /home/fis/Workspace/XGBoost/xgboost/src/data/data.cc:10:
/usr/include/python3.7m/datetime.h: At global scope:
/usr/include/python3.7m/datetime.h:204:25: warning: ‘PyDateTimeAPI’ defined but not used [-Wunused-variable]
 static PyDateTime_CAPI *PyDateTimeAPI = NULL;
                         ^~~~~~~~~~~~~
[35/233] Building CUDA object src/CMakeFiles/objxgboost.dir/common/hist_util.cu.o
FAILED: src/CMakeFiles/objxgboost.dir/common/hist_util.cu.o 
/usr/local/cuda/bin/nvcc -ccbin=/usr/bin/g++-8 -DDMLC_CORE_USE_CMAKE -DDMLC_LOG_CUSTOMIZE=1 -DDMLC_USE_CXX11=1 -DXGBOOST_BUILD_ARROW_SUPPORT=1 -DXGBOOST_BUILTIN_PREFETCH_PRESENT=1 -DXGBOOST_MM_PREFETCH_PRESENT=1 -DXGBOOST_USE_CUDA=1 -DXGBOOST_USE_NCCL=1 -D_GLIBCXX_USE_CXX11_ABI=0 -D_MWAITXINTRIN_H_INCLUDED -D__USE_XOPEN2K8 -I/home/fis/Workspace/XGBoost/xgboost/cub -I/usr/local/include -I/home/fis/Workspace/XGBoost/xgboost/include -I/home/fis/Workspace/XGBoost/xgboost/dmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/include -I/home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include -I/usr/include/python3.7m -Idmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/../dmlc-core/include -g -Xcompiler=-fPIC   --expt-extended-lambda --expt-relaxed-constexpr -lineinfo --generate-code=arch=compute_61,code=sm_61 --generate-code=arch=compute_61,code=compute_61 -Xcompiler=-fopenmp -std=c++14 -x cu -c /home/fis/Workspace/XGBoost/xgboost/src/common/hist_util.cu -o src/CMakeFiles/objxgboost.dir/common/hist_util.cu.o && /usr/local/cuda/bin/nvcc -ccbin=/usr/bin/g++-8 -DDMLC_CORE_USE_CMAKE -DDMLC_LOG_CUSTOMIZE=1 -DDMLC_USE_CXX11=1 -DXGBOOST_BUILD_ARROW_SUPPORT=1 -DXGBOOST_BUILTIN_PREFETCH_PRESENT=1 -DXGBOOST_MM_PREFETCH_PRESENT=1 -DXGBOOST_USE_CUDA=1 -DXGBOOST_USE_NCCL=1 -D_GLIBCXX_USE_CXX11_ABI=0 -D_MWAITXINTRIN_H_INCLUDED -D__USE_XOPEN2K8 -I/home/fis/Workspace/XGBoost/xgboost/cub -I/usr/local/include -I/home/fis/Workspace/XGBoost/xgboost/include -I/home/fis/Workspace/XGBoost/xgboost/dmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/include -I/home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include -I/usr/include/python3.7m -Idmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/../dmlc-core/include -g -Xcompiler=-fPIC   --expt-extended-lambda --expt-relaxed-constexpr -lineinfo --generate-code=arch=compute_61,code=sm_61 --generate-code=arch=compute_61,code=compute_61 -Xcompiler=-fopenmp -std=c++14 -x cu -M /home/fis/Workspace/XGBoost/xgboost/src/common/hist_util.cu -MT src/CMakeFiles/objxgboost.dir/common/hist_util.cu.o -o src/CMakeFiles/objxgboost.dir/common/hist_util.cu.o.d
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/sort.h(886): error: identifier "xgboost::Entry::Entry" is undefined in device code

.......

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/cub/block/../iterator/../thread/thread_load.cuh(126): error: calling a __host__ function("xgboost::Entry::Entry") from a __device__ function("thrust::cuda_cub::__scan_by_key::ScanByKeyAgent< ::thrust::detail::normal_iterator< ::thrust::device_ptr< ::xgboost::Entry> > ,  ::thrust::detail::normal_iterator< ::thrust::device_ptr<float> > ,  ::thrust::detail::normal_iterator< ::thrust::device_ptr<float> > ,  ::,  ::thrust::plus<float> , int, float,  ::thrust::detail::integral_constant<bool, (bool)1> > ::impl::consume_tile<(bool)0,  ::thrust::cuda_cub::__scan_by_key::DoNothing<float> > ") is not allowed

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/cub/block/../iterator/../thread/thread_load.cuh(126): error: identifier "xgboost::Entry::Entry" is undefined in device code

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/scan_by_key.h(453): error: identifier "xgboost::Entry::Entry" is undefined in device code

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/scan_by_key.h(383): error: identifier "xgboost::Entry::Entry" is undefined in device code

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/scan_by_key.h(427): error: identifier "xgboost::Entry::Entry" is undefined in device code

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/cub/block/../iterator/../thread/thread_load.cuh(126): error: calling a __host__ function("xgboost::Entry::Entry") from a __device__ function("thrust::cuda_cub::__scan_by_key::ScanByKeyAgent< ::thrust::detail::normal_iterator< ::thrust::device_ptr< ::xgboost::Entry> > ,  ::thrust::detail::normal_iterator< ::thrust::device_ptr<float> > ,  ::thrust::detail::normal_iterator< ::thrust::device_ptr<float> > ,  ::,  ::thrust::plus<float> , int, float,  ::thrust::detail::integral_constant<bool, (bool)1> > ::impl::consume_tile<(bool)1,  ::thrust::cuda_cub::__scan_by_key::DoNothing<float> > ") is not allowed

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/cub/block/../iterator/../thread/thread_load.cuh(126): error: identifier "xgboost::Entry::Entry" is undefined in device code

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/scan_by_key.h(453): error: identifier "xgboost::Entry::Entry" is undefined in device code

22 errors detected in the compilation of "/tmp/tmpxft_0000157f_00000000-6_hist_util.cpp1.ii".
[37/233] Building CUDA object src/CMakeFiles/objxgboost.dir/data/simple_dmatrix.cu.o
FAILED: src/CMakeFiles/objxgboost.dir/data/simple_dmatrix.cu.o 
/usr/local/cuda/bin/nvcc -ccbin=/usr/bin/g++-8 -DDMLC_CORE_USE_CMAKE -DDMLC_LOG_CUSTOMIZE=1 -DDMLC_USE_CXX11=1 -DXGBOOST_BUILD_ARROW_SUPPORT=1 -DXGBOOST_BUILTIN_PREFETCH_PRESENT=1 -DXGBOOST_MM_PREFETCH_PRESENT=1 -DXGBOOST_USE_CUDA=1 -DXGBOOST_USE_NCCL=1 -D_GLIBCXX_USE_CXX11_ABI=0 -D_MWAITXINTRIN_H_INCLUDED -D__USE_XOPEN2K8 -I/home/fis/Workspace/XGBoost/xgboost/cub -I/usr/local/include -I/home/fis/Workspace/XGBoost/xgboost/include -I/home/fis/Workspace/XGBoost/xgboost/dmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/include -I/home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include -I/usr/include/python3.7m -Idmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/../dmlc-core/include -g -Xcompiler=-fPIC   --expt-extended-lambda --expt-relaxed-constexpr -lineinfo --generate-code=arch=compute_61,code=sm_61 --generate-code=arch=compute_61,code=compute_61 -Xcompiler=-fopenmp -std=c++14 -x cu -c /home/fis/Workspace/XGBoost/xgboost/src/data/simple_dmatrix.cu -o src/CMakeFiles/objxgboost.dir/data/simple_dmatrix.cu.o && /usr/local/cuda/bin/nvcc -ccbin=/usr/bin/g++-8 -DDMLC_CORE_USE_CMAKE -DDMLC_LOG_CUSTOMIZE=1 -DDMLC_USE_CXX11=1 -DXGBOOST_BUILD_ARROW_SUPPORT=1 -DXGBOOST_BUILTIN_PREFETCH_PRESENT=1 -DXGBOOST_MM_PREFETCH_PRESENT=1 -DXGBOOST_USE_CUDA=1 -DXGBOOST_USE_NCCL=1 -D_GLIBCXX_USE_CXX11_ABI=0 -D_MWAITXINTRIN_H_INCLUDED -D__USE_XOPEN2K8 -I/home/fis/Workspace/XGBoost/xgboost/cub -I/usr/local/include -I/home/fis/Workspace/XGBoost/xgboost/include -I/home/fis/Workspace/XGBoost/xgboost/dmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/include -I/home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include -I/usr/include/python3.7m -Idmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/../dmlc-core/include -g -Xcompiler=-fPIC   --expt-extended-lambda --expt-relaxed-constexpr -lineinfo --generate-code=arch=compute_61,code=sm_61 --generate-code=arch=compute_61,code=compute_61 -Xcompiler=-fopenmp -std=c++14 -x cu -M /home/fis/Workspace/XGBoost/xgboost/src/data/simple_dmatrix.cu -MT src/CMakeFiles/objxgboost.dir/data/simple_dmatrix.cu.o -o src/CMakeFiles/objxgboost.dir/data/simple_dmatrix.cu.o.d
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/copy_if.h(371): error: identifier "xgboost::Entry::Entry" is undefined in device code

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/copy_if.h(371): error: identifier "xgboost::Entry::Entry" is undefined in device code

....

8 errors detected in the compilation of "/tmp/tmpxft_00001679_00000000-6_simple_dmatrix.cpp1.ii".
[39/233] Building CUDA object src/CMakeFiles/objxgboost.dir/data/iterative_device_dmatrix.cu.o
FAILED: src/CMakeFiles/objxgboost.dir/data/iterative_device_dmatrix.cu.o 
/usr/local/cuda/bin/nvcc -ccbin=/usr/bin/g++-8 -DDMLC_CORE_USE_CMAKE -DDMLC_LOG_CUSTOMIZE=1 -DDMLC_USE_CXX11=1 -DXGBOOST_BUILD_ARROW_SUPPORT=1 -DXGBOOST_BUILTIN_PREFETCH_PRESENT=1 -DXGBOOST_MM_PREFETCH_PRESENT=1 -DXGBOOST_USE_CUDA=1 -DXGBOOST_USE_NCCL=1 -D_GLIBCXX_USE_CXX11_ABI=0 -D_MWAITXINTRIN_H_INCLUDED -D__USE_XOPEN2K8 -I/home/fis/Workspace/XGBoost/xgboost/cub -I/usr/local/include -I/home/fis/Workspace/XGBoost/xgboost/include -I/home/fis/Workspace/XGBoost/xgboost/dmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/include -I/home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include -I/usr/include/python3.7m -Idmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/../dmlc-core/include -g -Xcompiler=-fPIC   --expt-extended-lambda --expt-relaxed-constexpr -lineinfo --generate-code=arch=compute_61,code=sm_61 --generate-code=arch=compute_61,code=compute_61 -Xcompiler=-fopenmp -std=c++14 -x cu -c /home/fis/Workspace/XGBoost/xgboost/src/data/iterative_device_dmatrix.cu -o src/CMakeFiles/objxgboost.dir/data/iterative_device_dmatrix.cu.o && /usr/local/cuda/bin/nvcc -ccbin=/usr/bin/g++-8 -DDMLC_CORE_USE_CMAKE -DDMLC_LOG_CUSTOMIZE=1 -DDMLC_USE_CXX11=1 -DXGBOOST_BUILD_ARROW_SUPPORT=1 -DXGBOOST_BUILTIN_PREFETCH_PRESENT=1 -DXGBOOST_MM_PREFETCH_PRESENT=1 -DXGBOOST_USE_CUDA=1 -DXGBOOST_USE_NCCL=1 -D_GLIBCXX_USE_CXX11_ABI=0 -D_MWAITXINTRIN_H_INCLUDED -D__USE_XOPEN2K8 -I/home/fis/Workspace/XGBoost/xgboost/cub -I/usr/local/include -I/home/fis/Workspace/XGBoost/xgboost/include -I/home/fis/Workspace/XGBoost/xgboost/dmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/include -I/home/fis/Workspace/XGBoost/lib/python3.7/site-packages/pyarrow/include -I/usr/include/python3.7m -Idmlc-core/include -I/home/fis/Workspace/XGBoost/xgboost/rabit/../dmlc-core/include -g -Xcompiler=-fPIC   --expt-extended-lambda --expt-relaxed-constexpr -lineinfo --generate-code=arch=compute_61,code=sm_61 --generate-code=arch=compute_61,code=compute_61 -Xcompiler=-fopenmp -std=c++14 -x cu -M /home/fis/Workspace/XGBoost/xgboost/src/data/iterative_device_dmatrix.cu -MT src/CMakeFiles/objxgboost.dir/data/iterative_device_dmatrix.cu.o -o src/CMakeFiles/objxgboost.dir/data/iterative_device_dmatrix.cu.o.d
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/copy_if.h(371): error: identifier "xgboost::Entry::Entry" is undefined in device code

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/copy_if.h(371): error: identifier "xgboost::Entry::Entry" is undefined in device code

....

/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/copy_if.h(371): error: identifier "xgboost::Entry::Entry" is undefined in device code


24 errors detected in the compilation of "/tmp/tmpxft_00001630_00000000-6_iterative_device_dmatrix.cpp1.ii".
[40/233] Building CUDA object src/CMakeFiles/objxgboost.dir/common/host_device_vector.cu.o
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/allocator/allocator_traits.inl(97): warning: calling a __host__ function from a __host__ __device__ function is not allowed
          detected during:
            instantiation of "thrust::detail::disable_if<thrust::detail::allocator_traits_detail::has_member_construct1<Alloc, T>::value, void>::type thrust::detail::allocator_traits_detail::construct(Alloc &, T *) [with Alloc=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
(310): here
            instantiation of "void thrust::detail::allocator_traits<Alloc>::construct(thrust::detail::allocator_traits<Alloc>::allocator_type &, T *) [with Alloc=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/allocator/default_construct_range.inl(46): here
            instantiation of "void thrust::detail::allocator_traits_detail::construct1_via_allocator<Allocator>::operator()(T &) [with Allocator=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/function.h(51): here
            instantiation of "Result thrust::detail::wrapped_function<Function, Result>::operator()(Argument &) const [with Function=thrust::detail::allocator_traits_detail::construct1_via_allocator<dh::XGBDeviceAllocator<xgboost::Entry>>, Result=void, Argument=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/for_each.h(57): here
            instantiation of "void thrust::cuda_cub::for_each_f<Input, UnaryOp>::operator()(Size) [with Input=thrust::device_ptr<xgboost::Entry>, UnaryOp=thrust::detail::wrapped_function<thrust::detail::allocator_traits_detail::construct1_via_allocator<dh::XGBDeviceAllocator<xgboost::Entry>>, void>, Size=unsigned long]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/parallel_for.h(98): here
            [ 12 instantiation contexts not shown ]
            instantiation of "void thrust::detail::vector_base<T, Alloc>::append(thrust::detail::vector_base<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/vector_base.inl(324): here
            instantiation of "void thrust::detail::vector_base<T, Alloc>::resize(thrust::detail::vector_base<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/home/fis/Workspace/XGBoost/xgboost/src/common/host_device_vector.cu(252): here
            instantiation of "void xgboost::HostDeviceVectorImpl<T>::LazyResizeDevice(size_t) [with T=xgboost::Entry]" 
/home/fis/Workspace/XGBoost/xgboost/src/common/host_device_vector.cu(43): here
            instantiation of "xgboost::HostDeviceVectorImpl<T>::HostDeviceVectorImpl(const Initializer &, int) [with T=xgboost::Entry, Initializer=std::initializer_list<xgboost::Entry>]" 
/home/fis/Workspace/XGBoost/xgboost/src/common/host_device_vector.cu(275): here
            instantiation of "xgboost::HostDeviceVector<T>::HostDeviceVector(std::initializer_list<T>, int) [with T=xgboost::Entry]" 
/home/fis/Workspace/XGBoost/xgboost/src/common/host_device_vector.cu(402): here

[44/233] Building CUDA object src/CMakeFiles/objxgboost.dir/linear/updater_gpu_coordinate.cu.o
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/allocator/allocator_traits.inl(97): warning: calling a __host__ function from a __host__ __device__ function is not allowed
          detected during:
            instantiation of "thrust::detail::disable_if<thrust::detail::allocator_traits_detail::has_member_construct1<Alloc, T>::value, void>::type thrust::detail::allocator_traits_detail::construct(Alloc &, T *) [with Alloc=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
(310): here
            instantiation of "void thrust::detail::allocator_traits<Alloc>::construct(thrust::detail::allocator_traits<Alloc>::allocator_type &, T *) [with Alloc=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/allocator/default_construct_range.inl(46): here
            instantiation of "void thrust::detail::allocator_traits_detail::construct1_via_allocator<Allocator>::operator()(T &) [with Allocator=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/function.h(51): here
            instantiation of "Result thrust::detail::wrapped_function<Function, Result>::operator()(Argument &) const [with Function=thrust::detail::allocator_traits_detail::construct1_via_allocator<dh::XGBDeviceAllocator<xgboost::Entry>>, Result=void, Argument=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/for_each.h(57): here
            instantiation of "void thrust::cuda_cub::for_each_f<Input, UnaryOp>::operator()(Size) [with Input=thrust::device_ptr<xgboost::Entry>, UnaryOp=thrust::detail::wrapped_function<thrust::detail::allocator_traits_detail::construct1_via_allocator<dh::XGBDeviceAllocator<xgboost::Entry>>, void>, Size=unsigned long]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/parallel_for.h(98): here
            [ 9 instantiation contexts not shown ]
            instantiation of "thrust::detail::enable_if<thrust::detail::allocator_traits_detail::needs_default_construct_via_allocator<Allocator, thrust::detail::pointer_element<Pointer>::type>::value, void>::type thrust::detail::allocator_traits_detail::default_construct_range(Allocator &, Pointer, Size) [with Allocator=dh::XGBDeviceAllocator<xgboost::Entry>, Pointer=thrust::device_ptr<xgboost::Entry>, Size=std::size_t]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/allocator/default_construct_range.inl(105): here
            instantiation of "void thrust::detail::default_construct_range(Allocator &, Pointer, Size) [with Allocator=dh::XGBDeviceAllocator<xgboost::Entry>, Pointer=thrust::device_ptr<xgboost::Entry>, Size=std::size_t]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/contiguous_storage.inl(251): here
            instantiation of "void thrust::detail::contiguous_storage<T, Alloc>::default_construct_n(thrust::detail::contiguous_storage<T, Alloc>::iterator, thrust::detail::contiguous_storage<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/vector_base.inl(848): here
            instantiation of "void thrust::detail::vector_base<T, Alloc>::append(thrust::detail::vector_base<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/vector_base.inl(324): here
            instantiation of "void thrust::detail::vector_base<T, Alloc>::resize(thrust::detail::vector_base<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/home/fis/Workspace/XGBoost/xgboost/src/linear/updater_gpu_coordinate.cu(89): here

[48/233] Building CUDA object src/CMakeFiles/objxgboost.dir/data/ellpack_page.cu.o
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/allocator/allocator_traits.inl(97): warning: calling a __host__ function from a __host__ __device__ function is not allowed
          detected during:
            instantiation of "thrust::detail::disable_if<thrust::detail::allocator_traits_detail::has_member_construct1<Alloc, T>::value, void>::type thrust::detail::allocator_traits_detail::construct(Alloc &, T *) [with Alloc=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
(310): here
            instantiation of "void thrust::detail::allocator_traits<Alloc>::construct(thrust::detail::allocator_traits<Alloc>::allocator_type &, T *) [with Alloc=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/allocator/default_construct_range.inl(46): here
            instantiation of "void thrust::detail::allocator_traits_detail::construct1_via_allocator<Allocator>::operator()(T &) [with Allocator=dh::XGBDeviceAllocator<xgboost::Entry>, T=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/function.h(51): here
            instantiation of "Result thrust::detail::wrapped_function<Function, Result>::operator()(Argument &) const [with Function=thrust::detail::allocator_traits_detail::construct1_via_allocator<dh::XGBDeviceAllocator<xgboost::Entry>>, Result=void, Argument=xgboost::Entry]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/for_each.h(57): here
            instantiation of "void thrust::cuda_cub::for_each_f<Input, UnaryOp>::operator()(Size) [with Input=thrust::device_ptr<xgboost::Entry>, UnaryOp=thrust::detail::wrapped_function<thrust::detail::allocator_traits_detail::construct1_via_allocator<dh::XGBDeviceAllocator<xgboost::Entry>>, void>, Size=unsigned long]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/system/cuda/detail/parallel_for.h(98): here
            [ 10 instantiation contexts not shown ]
            instantiation of "void thrust::detail::default_construct_range(Allocator &, Pointer, Size) [with Allocator=dh::XGBDeviceAllocator<xgboost::Entry>, Pointer=thrust::device_ptr<xgboost::Entry>, Size=std::size_t]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/contiguous_storage.inl(251): here
            instantiation of "void thrust::detail::contiguous_storage<T, Alloc>::default_construct_n(thrust::detail::contiguous_storage<T, Alloc>::iterator, thrust::detail::contiguous_storage<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/vector_base.inl(220): here
            instantiation of "void thrust::detail::vector_base<T, Alloc>::default_init(thrust::detail::vector_base<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/detail/vector_base.inl(65): here
            instantiation of "thrust::detail::vector_base<T, Alloc>::vector_base(thrust::detail::vector_base<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/usr/local/cuda/bin/../targets/x86_64-linux/include/thrust/device_vector.h(95): here
            instantiation of "thrust::device_vector<T, Alloc>::device_vector(thrust::device_vector<T, Alloc>::size_type) [with T=xgboost::Entry, Alloc=dh::XGBDeviceAllocator<xgboost::Entry>]" 
/home/fis/Workspace/XGBoost/xgboost/src/data/ellpack_page.cu(398): here

[50/233] Building CUDA object src/CMakeFiles/objxgboost.dir/metric/survival_metric.cu.o
ninja: build stopped: subcommand failed.

real	1m0.871s
user	11m54.580s
sys	0m41.546s

@trivialfis
Member

I will continue the work on this on a private branch first to avoid spamming the CI.

@zhangzhang10
Contributor Author

> I will continue the work on this on a private branch first to avoid spamming the CI.

@trivialfis Are you still working on this PR? Any update? Please let me know if I can help. Thanks!

@trivialfis
Member

Sorry for leaving this PR open without an update. At the moment, no. I really don't want to have a C/C++ level dependency. But Arrow is specified such that it can chunk the data arbitrarily, which makes it really painful to write a simple adapter.

@zhangzhang10
Contributor Author

@trivialfis It's been a long while since we touched this PR. One major reason it has been stuck so long is the dependency it introduces on the Arrow C++ library. It would be a headache to maintain yet another third-party dependency, given the complex dependencies XGBoost already has. I'd like to propose a solution that completely removes the Arrow C++ dependency but still supports the Arrow data format. The Arrow C data interface is just a couple of C structs that enable a project to integrate with the Arrow format only. It makes it possible to easily exchange columnar data between different runtime components, e.g. between the Python and C++ layers. This can also be used to support other types of columnar data, such as Spark ColumnarBatch, as long as they can be marshalled into the Arrow C data interface.
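For reference, the entire interface consists of these two structs, as frozen by the Arrow specification (https://arrow.apache.org/docs/format/CDataInterface.html); the inline comments here are mine:

```cpp
// The Arrow C data interface: two plain C structs (also valid C++).
#include <cstdint>

struct ArrowSchema {
  const char* format;     // type description, e.g. "f" = float32, "g" = float64
  const char* name;
  const char* metadata;
  int64_t flags;
  int64_t n_children;
  struct ArrowSchema** children;
  struct ArrowSchema* dictionary;
  void (*release)(struct ArrowSchema*);  // producer-supplied destructor
  void* private_data;
};

struct ArrowArray {
  int64_t length;
  int64_t null_count;
  int64_t offset;
  int64_t n_buffers;
  int64_t n_children;
  const void** buffers;   // buffers[0] = validity bitmap, then data buffers
  struct ArrowArray** children;
  struct ArrowArray* dictionary;
  void (*release)(struct ArrowArray*);   // producer-supplied destructor
  void* private_data;
};
```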

Shall we continue to use this PR to further discuss this idea? Or, shall we close this and create another one? Thanks!

@trivialfis
Member

Sorry for the stalled progress. Revisiting.

@trivialfis
Member

It's not the dependency that I'm worried about, it's the specialized code.

@trivialfis
Member

I will pick up the progress starting Monday, just cleared out some other to-do items.

@trivialfis
Member

Doing some refactorings.

@zhangzhang10
Contributor Author

> It's not the dependency that I'm worried about, it's the specialized code.

If by "specialized code" you mean code written with the only purpose of supporting Arrow, then I believe the worry is unnecessary.

Firstly, my new proposal of using the Arrow C Data Interface aims at supporting a general columnar data format. Currently, XGBoost has no general support for columnar data formats, except for the cudf support on GPUs. But my proposal is a solution that all kinds of columnar data sources can benefit from. It integrates only the Arrow format; it doesn't integrate the Arrow API. Thanks to the Arrow C Data Interface, the format can be expressed with only two C structs, and their specification is frozen by the Arrow project. This means we have a robust mechanism to exchange columnar data between different XGBoost components, for example, between the C++ lib and the Python layer (with the help of Python FFI), and between the C++ lib and the JVM layer (with the help of JNI).

Secondly, many ingredients of the implementation already exist. The Arrow project itself maintains a producer that generates the aforementioned C structs from Arrow arrays or Arrow record batches; this mechanism is also available in PyArrow. Other columnar data sources can easily follow the same approach as this producer (and even reuse much of its code). On the other hand, within XGBoost, creating a DMatrix from cudf, whose data format is based on the Arrow columnar format, is already supported. We should be able to borrow some of those mechanisms to provide more general support for the columnar format.
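As an illustration of how little code the consumer side needs, here is a minimal sketch (my own, assuming the two structs shown earlier, not code from this PR) that reads a primitive float32 column exported through the interface; error handling and chunking are omitted:

```cpp
#include <cstdint>

// For primitive arrays, buffers[0] is the validity bitmap (it may be null
// when there are no nulls) and buffers[1] holds the values; `offset` applies
// to both, and validity bits are LSB-ordered within each byte.
inline bool IsValid(const ArrowArray* col, int64_t i) {
  auto* bitmap = static_cast<const uint8_t*>(col->buffers[0]);
  if (bitmap == nullptr) return true;
  const int64_t j = col->offset + i;
  return (bitmap[j / 8] >> (j % 8)) & 1;
}

inline float ValueAt(const ArrowArray* col, int64_t i) {
  auto* values = static_cast<const float*>(col->buffers[1]);
  return values[col->offset + i];
}
```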

Let me know what you think. Thanks.

@trivialfis
Member

trivialfis commented May 18, 2021

@zhangzhang10 Thanks for the detailed explanation! For the "specialized code" part, I was referring to the new Push function. One difference between pyarrow input and cuDF is that cuDF doesn't chunk the data, which makes it quite easy to get all the data in one go and merge the code path with other inputs. For Arrow, if we want complete support (instead of calling to_pandas), then there's a need to create a different Push function for chunked data, as you have accomplished in this PR. Also, Arrow is heterogeneous, which means we might also need to create yet another code path for CUDA (or maybe a future implementation on OneAPI?). Right now we have CPU adapters and GPU adapters as the stopping point of divergence between backends. This keeps the code relatively simple (which is important) and is the basic assumption of the inplace_predict function. If we put the ArrowAdapter in this PR into inplace_predict it won't compile. I'm happy to be wrong here, so feel free to point it out.

But I will look deeper and keep you posted here.

@trivialfis
Member

trivialfis commented May 18, 2021

I think @SmirnovEgorRu and @ShvetsKS are also familiar with the inplace predict implementation. Feel free to share your opinion. I can help refactor the code.

@zhangzhang10
Contributor Author

> @zhangzhang10 Thanks for the detailed explanation! For the "specialized code" part, I was referring to the new Push function. One difference between pyarrow input and cuDF is that cuDF doesn't chunk the data, which makes it quite easy to get all the data in one go and merge the code path with other inputs. For Arrow, if we want complete support (instead of calling to_pandas), then there's a need to create a different Push function for chunked data, as you have accomplished in this PR. Also, Arrow is heterogeneous, which means we might also need to create yet another code path for CUDA (or maybe a future implementation on OneAPI?). Right now we have CPU adapters and GPU adapters as the stopping point of divergence between backends. This keeps the code relatively simple (which is important) and is the basic assumption of the inplace_predict function. If we put the ArrowAdapter in this PR into inplace_predict it won't compile. I'm happy to be wrong here, so feel free to point it out.
>
> But I will look deeper and keep you posted here.

@trivialfis Thanks for your thoughts! You are correct that in order to support the Arrow data format we need a new Push function, and we also need a new Adapter class. But I don't think data chunking in pyarrow will be a problem for inplace_predict. I looked deeper into this function and realized that it depends on an adapter that can produce one row at a time for prediction. As long as the Arrow adapter complies with this requirement, it doesn't need to pull in all the data at once. Let me know if I'm mistaken here.
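A hedged sketch of that row-at-a-time contract (GetLine/GetElement/COOTuple follow xgboost's internal adapter.h; the details here are assumptions, not this PR's final code):

```cpp
#include <cstddef>

// inplace_predict only needs the batch to yield one sparse row at a time,
// so an Arrow adapter can hide chunk boundaries internally.
template <typename Batch>
void WalkRows(const Batch& batch) {
  for (std::size_t i = 0; i < batch.Size(); ++i) {
    auto line = batch.GetLine(i);  // one logical row, even across chunks
    for (std::size_t j = 0; j < line.Size(); ++j) {
      auto entry = line.GetElement(j);  // COOTuple {row_idx, column_idx, value}
      (void)entry;  // feed entry.value to the predictor ...
    }
  }
}
```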

It is true that we need to maintain the new Push function and the new Adapter class, including some auxiliary functions and data structures. However, given that the Arrow data format definition and the C data interface are both stable, there shouldn't be a lot of maintenance once the implementations are done.

As to the code path for CUDA, cudf itself is based on the Arrow data format, so it already achieves much of the goal we want to achieve: supporting the Arrow columnar data format. The work I propose should not affect the GPU backend and will not change the existing situation of having both CPU and GPU adapters. But it would be desirable to structure our implementation in a way that reuses and unifies some key data structures across the CPU and GPU code paths. One example, probably, is the DataIter class that is currently only available for the GPU code path. We could modify it to make it also work for the CPU backend to handle the chunked data from pyarrow.

Thanks.

@zhangzhang10
Contributor Author

@trivialfis @hcho3, I have a rework of this PR. Let's close this one and continue the discussion in the new PR (#7283). Thanks.

@zhangzhang10
Contributor Author

Closed. See #7283
