diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 99a80c33bc9e..3638e0387d41 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -52,7 +52,17 @@ else (USE_CUDA) endif (USE_CUDA) if (USE_ARROW) - enable_arrow(objxgboost) + find_package(Arrow REQUIRED) + list(APPEND SRC_LIBS ${ARROW_SHARED_LIB}) + find_package(ArrowPython REQUIRED) + list(APPEND SRC_LIBS ${ARROW_PYTHON_SHARED_LIB}) + find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + list(APPEND SRC_LIBS ${Python3_LIBRARIES}) + target_include_directories(objxgboost PRIVATE + ${ARROW_INCLUDE_DIR} + ${ARROW_PYTHON_INCLUDE_DIR} + ${Python3_INCLUDE_DIRS}) + target_compile_definitions(objxgboost PRIVATE -DXGBOOST_BUILD_ARROW_SUPPORT=1) endif (USE_ARROW) target_include_directories(objxgboost diff --git a/tests/python/test_with_arrow.py b/tests/python/test_with_arrow.py new file mode 100644 index 000000000000..66cee92ee93f --- /dev/null +++ b/tests/python/test_with_arrow.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +import unittest +import pytest +import numpy as np + +import testing as tm +import xgboost as xgb + +try: + import pyarrow as pa + import pyarrow.csv as pc + import pandas as pd +except ImportError: + pass + +pytestmark = pytest.mark.skipif( + tm.no_arrow()['condition'] or tm.no_pandas()['condition'], + reason=tm.no_arrow()['reason'] + ' or ' + tm.no_pandas()['reason']) + +dpath = 'demo/data/' + +class TestArrowTable(unittest.TestCase): + + def test_arrow_table(self): + df = pd.DataFrame([[0, 1, 2., 3.], [1, 2, 3., 4.]], + columns=['a', 'b', 'c', 'd']) + table = pa.Table.from_pandas(df) + dm = xgb.DMatrix(table) + assert dm.num_row() == 2 + assert dm.num_col() == 4 + + def test_arrow_table_with_label(self): + df = pd.DataFrame([[0, 1, 2., 3.], [1, 2, 3., 4.]], + columns=['label', 'a', 'b', 'c']) + table = pa.Table.from_pandas(df) + dm = xgb.DMatrix(table, label='label') + assert dm.num_row() == 2 + assert dm.num_col() == 3 + np.testing.assert_array_equal(dm.get_label(), np.array([0, 1])) + + def test_arrow_table_from_np(self): + coldata = np.array([[1., 1., 0., 0.], + [2., 0., 1., 0.], + [3., 0., 0., 1.]]) + cols = list(map(pa.array, coldata)) + table = pa.Table.from_arrays(cols, ['a', 'b', 'c']) + dm = xgb.DMatrix(table) + assert dm.num_row() == 4 + assert dm.num_col() == 3 + + def test_arrow_table_from_csv(self): + dfile = dpath + 'veterans_lung_cancer.csv' + table = pc.read_csv(dfile) + dm = xgb.DMatrix(table) + assert dm.num_row() == 137 + assert dm.num_col() == 13 + + diff --git a/tests/python/testing.py b/tests/python/testing.py index 4f9f3394aadc..81dc845269ad 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -13,6 +13,8 @@ except ImportError: cp = None +from xgboost.compat import PYARROW_INSTALLED + memory = Memory('./cachedir', verbose=0) @@ -188,3 +190,7 @@ def _dataset_and_weight(draw): def non_increasing(L, tolerance=1e-4): return all((y - x) < tolerance for x, y in zip(L, L[1:])) + +def no_arrow(): + return {'condition': not PYARROW_INSTALLED, + 'reason': 'Pyarrow is not installed'}