Skip to content

Commit

Permalink
Add pytest test cases for Arrow support
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhang Zhang committed Jun 17, 2020
1 parent dbefba8 commit 2bcf1b9
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,17 @@ else (USE_CUDA)
endif (USE_CUDA)

# Optional Arrow support for the objxgboost target, enabled via USE_ARROW.
if (USE_ARROW)
  # NOTE(review): enable_arrow() is defined outside this snippet; confirm it
  # does not duplicate the explicit package discovery performed below.
  enable_arrow(objxgboost)

  # Each package is REQUIRED, so configuration stops if one cannot be found.
  # The library variables they set are collected in SRC_LIBS, which is
  # consumed outside this snippet.
  find_package(Arrow REQUIRED)
  list(APPEND SRC_LIBS ${ARROW_SHARED_LIB})
  find_package(ArrowPython REQUIRED)
  list(APPEND SRC_LIBS ${ARROW_PYTHON_SHARED_LIB})
  find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
  list(APPEND SRC_LIBS ${Python3_LIBRARIES})

  # Headers are needed only to compile objxgboost itself, hence PRIVATE.
  target_include_directories(objxgboost
    PRIVATE
      ${ARROW_INCLUDE_DIR}
      ${ARROW_PYTHON_INCLUDE_DIR}
      ${Python3_INCLUDE_DIRS})

  # target_compile_definitions() takes bare NAME=value items; no "-D" prefix.
  target_compile_definitions(objxgboost
    PRIVATE
      XGBOOST_BUILD_ARROW_SUPPORT=1)
endif (USE_ARROW)

target_include_directories(objxgboost
Expand Down
58 changes: 58 additions & 0 deletions tests/python/test_with_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
import unittest
import pytest
import numpy as np

import testing as tm
import xgboost as xgb

try:
import pyarrow as pa
import pyarrow.csv as pc
import pandas as pd
except ImportError:
pass

# Module-wide marker: skip every test here when either pyarrow or pandas
# is reported as unavailable by the helpers in the `testing` module.
pytestmark = pytest.mark.skipif(
    tm.no_arrow()['condition'] or tm.no_pandas()['condition'],
    reason=' or '.join((tm.no_arrow()['reason'], tm.no_pandas()['reason'])))

# Relative directory that holds the demo data files read by the CSV test.
dpath = 'demo/data/'

class TestArrowTable(unittest.TestCase):
    """Checks that ``xgb.DMatrix`` can be built from ``pyarrow.Table`` inputs."""

    def test_arrow_table(self):
        """A table converted from a 2x4 pandas frame gives a 2-row, 4-column DMatrix."""
        rows = [[0, 1, 2., 3.], [1, 2, 3., 4.]]
        frame = pd.DataFrame(rows, columns=['a', 'b', 'c', 'd'])
        arrow_table = pa.Table.from_pandas(frame)
        dmat = xgb.DMatrix(arrow_table)
        assert dmat.num_row() == 2
        assert dmat.num_col() == 4

    def test_arrow_table_with_label(self):
        """With ``label='label'`` the DMatrix reports 3 columns and labels [0, 1]."""
        rows = [[0, 1, 2., 3.], [1, 2, 3., 4.]]
        frame = pd.DataFrame(rows, columns=['label', 'a', 'b', 'c'])
        arrow_table = pa.Table.from_pandas(frame)
        dmat = xgb.DMatrix(arrow_table, label='label')
        assert dmat.num_row() == 2
        assert dmat.num_col() == 3
        np.testing.assert_array_equal(dmat.get_label(), np.array([0, 1]))

    def test_arrow_table_from_np(self):
        """Three length-4 numpy rows become three Arrow columns: 4 rows, 3 columns."""
        coldata = np.array([[1., 1., 0., 0.],
                            [2., 0., 1., 0.],
                            [3., 0., 0., 1.]])
        # Each row of `coldata` is turned into one Arrow array, i.e. one column.
        columns = [pa.array(values) for values in coldata]
        arrow_table = pa.Table.from_arrays(columns, ['a', 'b', 'c'])
        dmat = xgb.DMatrix(arrow_table)
        assert dmat.num_row() == 4
        assert dmat.num_col() == 3

    def test_arrow_table_from_csv(self):
        """A table read by ``pyarrow.csv`` from the demo CSV gives 137 rows, 13 columns."""
        csv_path = dpath + 'veterans_lung_cancer.csv'
        arrow_table = pc.read_csv(csv_path)
        dmat = xgb.DMatrix(arrow_table)
        assert dmat.num_row() == 137
        assert dmat.num_col() == 13


6 changes: 6 additions & 0 deletions tests/python/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
except ImportError:
cp = None

from xgboost.compat import PYARROW_INSTALLED

# Module-level Memory object pointed at ./cachedir with verbose=0.
# NOTE(review): Memory is imported outside this excerpt (presumably
# joblib.Memory) - confirm before relying on its caching semantics.
memory = Memory('./cachedir', verbose=0)


Expand Down Expand Up @@ -188,3 +190,7 @@ def _dataset_and_weight(draw):

def non_increasing(L, tolerance=1e-4):
    """Return True when no element of ``L`` rises above its predecessor by ``tolerance`` or more.

    Every adjacent pair ``(x, y)`` must satisfy ``y - x < tolerance``. A
    sequence with fewer than two elements has no pairs and so returns True.

    Parameters
    ----------
    L : sequence supporting slicing (``L[1:]``) whose items can be subtracted.
    tolerance : largest increase between neighbours that is still accepted
        (exclusive bound).
    """
    return all((y - x) < tolerance for x, y in zip(L, L[1:]))

def no_arrow():
    """Build the skip descriptor for tests that need pyarrow.

    Returns a dict with two keys: ``'condition'`` is True when
    ``PYARROW_INSTALLED`` is falsy, and ``'reason'`` is the message to
    report for the skip.
    """
    return {'condition': not PYARROW_INSTALLED,
            'reason': 'Pyarrow is not installed'}

0 comments on commit 2bcf1b9

Please sign in to comment.