diff --git a/.github/workflows/python_test.yaml b/.github/workflows/python_test.yaml
index 13516ff699da..e689396b5dcd 100644
--- a/.github/workflows/python_test.yaml
+++ b/.github/workflows/python_test.yaml
@@ -53,7 +53,7 @@ jobs:
           pip install -r requirements.txt
           maturin develop
-          python -m unittest discover tests
+          pytest -v .
         env:
           CARGO_HOME: "/home/runner/.cargo"
           CARGO_TARGET_DIR: "/home/runner/target"
diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt
index 6126699bbc1f..96beccd0af81 100644
--- a/dev/release/rat_exclude_files.txt
+++ b/dev/release/rat_exclude_files.txt
@@ -105,3 +105,4 @@ benchmarks/queries/q*.sql
 ballista/rust/scheduler/testdata/*
 ballista/ui/scheduler/yarn.lock
 python/rust-toolchain
+python/requirements*.txt
diff --git a/python/requirements.in b/python/requirements.in
index 3ef9f18966d4..4ff7f4ee618b 100644
--- a/python/requirements.in
+++ b/python/requirements.in
@@ -17,3 +17,4 @@ maturin
 toml
 pyarrow
+pytest
diff --git a/python/requirements.txt b/python/requirements.txt
index ff02b80cf6fc..f7ede1ebd58e 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -1,25 +1,17 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
 #
 # This file is autogenerated by pip-compile
 # To update, run:
 #
-#    pip-compile --generate-hashes
+#    pip-compile --generate-hashes requirements.in
 #
+attrs==21.2.0 \
+    --hash=sha256:149e90d6d8ac20db7a955ad60cf0e6881a3f20d37096140088356da6c716b0b1 \
+    --hash=sha256:ef6aaac3ca6cd92904cdd0d83f629a15f18053ec84e6432106f7a4d04ae4f5fb
+    # via pytest
+iniconfig==1.1.1 \
+    --hash=sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3 \
+    --hash=sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32
+    # via pytest
 maturin==0.10.6 \
     --hash=sha256:0e81496f70a4805e6ea7dda7b0425246c111ccb119a2e22c64abeff131f4dd21 \
     --hash=sha256:3b5d5429bc05a816824420d99973f0cab39d8e274f6c3647bfd9afd95a030304 \
@@ -59,6 +51,18 @@ numpy==1.20.3 \
     --hash=sha256:f1452578d0516283c87608a5a5548b0cdde15b99650efdfd85182102ef7a7c17 \
     --hash=sha256:f39a995e47cb8649673cfa0579fbdd1cdd33ea497d1728a6cb194d6252268e48
     # via pyarrow
+packaging==20.9 \
+    --hash=sha256:5b327ac1320dc863dca72f4514ecc086f31186744b84a230374cc1fd776feae5 \
+    --hash=sha256:67714da7f7bc052e064859c05c595155bd1ee9f69f76557e21f051443c20947a
+    # via pytest
+pluggy==0.13.1 \
+    --hash=sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0 \
+    --hash=sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d
+    # via pytest
+py==1.10.0 \
+    --hash=sha256:21b81bda15b66ef5e1a777a21c4dcd9c20ad3efd0b3f817e7a809035269e1bd3 \
+    --hash=sha256:3b80836aa6d1feeaa108e046da6423ab8f6ceda6468545ae8d02d9d58d18818a
+    # via pytest
 pyarrow==4.0.1 \
     --hash=sha256:04be0f7cb9090bd029b5b53bed628548fef569e5d0b5c6cd7f6d0106dbbc782d \
     --hash=sha256:0fde9c7a3d5d37f3fe5d18c4ed015e8f585b68b26d72a10d7012cad61afe43ff \
@@ -86,9 +90,18 @@ pyarrow==4.0.1 \
     --hash=sha256:fa7b165cfa97158c1e6d15c68428317b4f4ae786d1dc2dbab43f1328c1eb43aa \
     --hash=sha256:fe976695318560a97c6d31bba828eeca28c44c6f6401005e54ba476a28ac0a10
     # via -r requirements.in
+pyparsing==2.4.7 \
+    --hash=sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1 \
+    --hash=sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b
+    # via packaging
+pytest==6.2.4 \
+    --hash=sha256:50bcad0a0b9c5a72c8e4e7c9855a3ad496ca6a881a3641b4260605450772c54b \
+    --hash=sha256:91ef2131a9bd6be8f76f1f08eac5c5317221d6ad1e143ae03894b862e8976890
+    # via -r requirements.in
 toml==0.10.2 \
     --hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \
     --hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f
     # via
     #   -r requirements.in
     #   maturin
+    #   pytest
diff --git a/python/tests/generic.py b/python/tests/generic.py
index 267d6f656ce0..e61542e6ab37 100644
--- a/python/tests/generic.py
+++ b/python/tests/generic.py
@@ -16,24 +16,30 @@
 # under the License.
 
 import datetime
-import numpy
-import pyarrow
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.parquet as pq  # used to write parquet files
 
-import pyarrow.parquet
 
 
 def data():
-    data = numpy.concatenate(
-        [numpy.random.normal(0, 0.01, size=50), numpy.random.normal(50, 0.01, size=50)]
+    np.random.seed(1)
+    data = np.concatenate(
+        [
+            np.random.normal(0, 0.01, size=50),
+            np.random.normal(50, 0.01, size=50),
+        ]
     )
-    return pyarrow.array(data)
+    return pa.array(data)
 
 
 def data_with_nans():
-    data = numpy.random.normal(0, 0.01, size=50)
-    mask = numpy.random.randint(0, 2, size=50)
-    data[mask == 0] = numpy.NaN
+    np.random.seed(0)
+    data = np.random.normal(0, 0.01, size=50)
+    mask = np.random.randint(0, 2, size=50)
+    data[mask == 0] = np.NaN
     return data
 
 
@@ -43,8 +49,19 @@ def data_datetime(f):
         datetime.datetime.now() - datetime.timedelta(days=1),
         datetime.datetime.now() + datetime.timedelta(days=1),
     ]
-    return pyarrow.array(
-        data, type=pyarrow.timestamp(f), mask=numpy.array([False, True, False])
+    return pa.array(
+        data, type=pa.timestamp(f), mask=np.array([False, True, False])
+    )
+
+
+def data_date32():
+    data = [
+        datetime.date(2000, 1, 1),
+        datetime.date(1980, 1, 1),
+        datetime.date(2030, 1, 1),
+    ]
+    return pa.array(
+        data, type=pa.date32(), mask=np.array([False, True, False])
     )
 
 
@@ -54,16 +71,16 @@ def data_timedelta(f):
         datetime.timedelta(days=1),
         datetime.timedelta(seconds=1),
     ]
-    return pyarrow.array(
-        data, type=pyarrow.duration(f), mask=numpy.array([False, True, False])
+    return pa.array(
+        data, type=pa.duration(f), mask=np.array([False, True, False])
     )
 
 
 def data_binary_other():
-    return numpy.array([1, 0, 0], dtype="u4")
+    return np.array([1, 0, 0], dtype="u4")
 
 
 def write_parquet(path, data):
-    table = pyarrow.Table.from_arrays([data], names=["a"])
-    pyarrow.parquet.write_table(table, path)
-    return path
+    table = pa.Table.from_arrays([data], names=["a"])
+    pq.write_table(table, path)
+    return str(path)
diff --git a/python/tests/test_df.py b/python/tests/test_df.py
index fdafdfa7f509..5b6cbddbd74b 100644
--- a/python/tests/test_df.py
+++ b/python/tests/test_df.py
@@ -15,100 +15,98 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
-
 import pyarrow as pa
-import datafusion
+import pytest
+from datafusion import ExecutionContext
+from datafusion import functions as f
+
+
+@pytest.fixture
+def df():
+    ctx = ExecutionContext()
+
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
 
-f = datafusion.functions
+    return ctx.create_dataframe([[batch]])
 
-class TestCase(unittest.TestCase):
-    def _prepare(self):
-        ctx = datafusion.ExecutionContext()
 
+def test_select(df):
+    df = df.select(
+        f.col("a") + f.col("b"),
+        f.col("a") - f.col("b"),
+    )
 
-        # create a RecordBatch and a new DataFrame from it
-        batch = pa.RecordBatch.from_arrays(
-            [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
-            names=["a", "b"],
-        )
-        return ctx.create_dataframe([[batch]])
+    # execute and collect the first (and only) batch
+    result = df.collect()[0]
 
-    def test_select(self):
-        df = self._prepare()
+    assert result.column(0) == pa.array([5, 7, 9])
+    assert result.column(1) == pa.array([-3, -3, -3])
 
-        df = df.select(
-            f.col("a") + f.col("b"),
-            f.col("a") - f.col("b"),
-        )
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
+def test_filter(df):
+    df = df.select(
+        f.col("a") + f.col("b"),
+        f.col("a") - f.col("b"),
+    ).filter(f.col("a") > f.lit(2))
 
-        self.assertEqual(result.column(0), pa.array([5, 7, 9]))
-        self.assertEqual(result.column(1), pa.array([-3, -3, -3]))
+    # execute and collect the first (and only) batch
+    result = df.collect()[0]
 
-    def test_filter(self):
-        df = self._prepare()
+    assert result.column(0) == pa.array([9])
+    assert result.column(1) == pa.array([-3])
 
-        df = df.select(
-            f.col("a") + f.col("b"),
-            f.col("a") - f.col("b"),
-        ).filter(f.col("a") > f.lit(2))
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
+def test_sort(df):
+    df = df.sort([f.col("b").sort(ascending=False)])
 
-        self.assertEqual(result.column(0), pa.array([9]))
-        self.assertEqual(result.column(1), pa.array([-3]))
+    table = pa.Table.from_batches(df.collect())
+    expected = {"a": [3, 2, 1], "b": [6, 5, 4]}
 
-    def test_sort(self):
-        df = self._prepare()
-        df = df.sort([f.col("b").sort(ascending=False)])
+    assert table.to_pydict() == expected
 
-        table = pa.Table.from_batches(df.collect())
-        expected = {"a": [3, 2, 1], "b": [6, 5, 4]}
-        self.assertEqual(table.to_pydict(), expected)
-
-    def test_limit(self):
-        df = self._prepare()
+def test_limit(df):
+    df = df.limit(1)
 
-        df = df.limit(1)
+    # execute and collect the first (and only) batch
+    result = df.collect()[0]
 
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
+    assert len(result.column(0)) == 1
+    assert len(result.column(1)) == 1
 
-        self.assertEqual(len(result.column(0)), 1)
-        self.assertEqual(len(result.column(1)), 1)
-
-    def test_udf(self):
-        df = self._prepare()
+def test_udf(df):
+    # is_null is a pa function over arrays
+    udf = f.udf(lambda x: x.is_null(), [pa.int64()], pa.bool_())
 
-        # is_null is a pa function over arrays
-        udf = f.udf(lambda x: x.is_null(), [pa.int64()], pa.bool_())
+    df = df.select(udf(f.col("a")))
+    result = df.collect()[0].column(0)
 
-        df = df.select(udf(f.col("a")))
+    assert result == pa.array([False, False, False])
 
-        self.assertEqual(df.collect()[0].column(0), pa.array([False, False, False]))
-
-    def test_join(self):
-        ctx = datafusion.ExecutionContext()
+def test_join():
+    ctx = ExecutionContext()
 
-        batch = pa.RecordBatch.from_arrays(
-            [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
-            names=["a", "b"],
-        )
-        df = ctx.create_dataframe([[batch]])
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    df = ctx.create_dataframe([[batch]])
 
-        batch = pa.RecordBatch.from_arrays(
-            [pa.array([1, 2]), pa.array([8, 10])],
-            names=["a", "c"],
-        )
-        df1 = ctx.create_dataframe([[batch]])
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2]), pa.array([8, 10])],
+        names=["a", "c"],
+    )
+    df1 = ctx.create_dataframe([[batch]])
 
-        df = df.join(df1, on="a", how="inner")
-        df = df.sort([f.col("a").sort(ascending=True)])
-        table = pa.Table.from_batches(df.collect())
+    df = df.join(df1, on="a", how="inner")
+    df = df.sort([f.col("a").sort(ascending=True)])
+    table = pa.Table.from_batches(df.collect())
 
-        expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
-        self.assertEqual(table.to_pydict(), expected)
+    expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
+    assert table.to_pydict() == expected
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 117284973fb7..361526d06970 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -15,286 +15,182 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
-import tempfile
-import datetime
-import os.path
-import shutil
+import numpy as np
+import pyarrow as pa
+import pytest
+from datafusion import ExecutionContext
 
-import numpy
-import pyarrow
-import datafusion
+from . import generic as helpers
 
-# used to write parquet files
-import pyarrow.parquet
-from tests.generic import *
 
+@pytest.fixture
+def ctx():
+    return ExecutionContext()
 
-class TestCase(unittest.TestCase):
-    def setUp(self):
-        # Create a temporary directory
-        self.test_dir = tempfile.mkdtemp()
-        numpy.random.seed(1)
 
+def test_no_table(ctx):
+    with pytest.raises(Exception, match="DataFusion error"):
+        ctx.sql("SELECT a FROM b").collect()
 
-    def tearDown(self):
-        # Remove the directory after the test
-        shutil.rmtree(self.test_dir)
 
-    def test_no_table(self):
-        with self.assertRaises(Exception):
-            datafusion.Context().sql("SELECT a FROM b").collect()
+def test_register(ctx, tmp_path):
+    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+    ctx.register_parquet("t", path)
 
-    def test_register(self):
-        ctx = datafusion.ExecutionContext()
+    assert ctx.tables() == {"t"}
 
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data())
-        ctx.register_parquet("t", path)
 
+def test_execute(ctx, tmp_path):
+    data = [1, 1, 2, 2, 3, 11, 12]
 
-        self.assertEqual(ctx.tables(), {"t"})
 
+    # single column, "a"
+    path = helpers.write_parquet(tmp_path / "a.parquet", pa.array(data))
+    ctx.register_parquet("t", path)
 
-    def test_execute(self):
-        data = [1, 1, 2, 2, 3, 11, 12]
 
+    assert ctx.tables() == {"t"}
 
-        ctx = datafusion.ExecutionContext()
 
+    # count
+    result = ctx.sql("SELECT COUNT(a) FROM t").collect()
 
-        # single column, "a"
-        path = write_parquet(
-            os.path.join(self.test_dir, "a.parquet"), pyarrow.array(data)
-        )
-        ctx.register_parquet("t", path)
 
+    expected = pa.array([7], pa.uint64())
+    expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
+    assert result == expected
 
-        self.assertEqual(ctx.tables(), {"t"})
 
+    # where
+    expected = pa.array([2], pa.uint64())
+    expected = [pa.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
+    result = ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect()
+    assert result == expected
 
-        # count
-        result = ctx.sql("SELECT COUNT(a) FROM t").collect()
 
+    # group by
+    results = ctx.sql(
+        "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
+    ).collect()
 
-        expected = pyarrow.array([7], pyarrow.uint64())
-        expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
-        self.assertEqual(expected, result)
+    # group by returns batches
+    result_keys = []
+    result_values = []
+    for result in results:
+        pydict = result.to_pydict()
+        result_keys.extend(pydict["CAST(a AS Int32)"])
+        result_values.extend(pydict["COUNT(a)"])
 
-        # where
-        expected = pyarrow.array([2], pyarrow.uint64())
-        expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
-        self.assertEqual(
-            expected, ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect()
-        )
+    result_keys, result_values = (
+        list(t) for t in zip(*sorted(zip(result_keys, result_values)))
+    )
 
-        # group by
-        results = ctx.sql(
-            "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
-        ).collect()
-
-        # group by returns batches
-        result_keys = []
-        result_values = []
-        for result in results:
-            pydict = result.to_pydict()
-            result_keys.extend(pydict["CAST(a AS Int32)"])
-            result_values.extend(pydict["COUNT(a)"])
-
-        result_keys, result_values = (
-            list(t) for t in zip(*sorted(zip(result_keys, result_values)))
-        )
+    assert result_keys == [1, 2, 3, 11, 12]
+    assert result_values == [2, 2, 1, 1, 1]
 
-        self.assertEqual(result_keys, [1, 2, 3, 11, 12])
-        self.assertEqual(result_values, [2, 2, 1, 1, 1])
-
-        # order by
-        result = ctx.sql(
-            "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2"
-        ).collect()
-        expected_a = pyarrow.array([50.0219, 50.0152], pyarrow.float64())
-        expected_cast = pyarrow.array([50, 50], pyarrow.int32())
-        expected = [
-            pyarrow.RecordBatch.from_arrays(
-                [expected_a, expected_cast], ["a", "CAST(a AS Int32)"]
-            )
-        ]
-        numpy.testing.assert_equal(expected[0].column(1), expected[0].column(1))
-
-    def test_cast(self):
-        """
-        Verify that we can cast
-        """
-        ctx = datafusion.ExecutionContext()
-
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data())
-        ctx.register_parquet("t", path)
-
-        valid_types = [
-            "smallint",
-            "int",
-            "bigint",
-            "float(32)",
-            "float(64)",
-            "float",
-        ]
-
-        select = ", ".join(
-            [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)]
+    # order by
+    result = ctx.sql(
+        "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2"
+    ).collect()
+    expected_a = pa.array([50.0219, 50.0152], pa.float64())
+    expected_cast = pa.array([50, 50], pa.int32())
+    expected = [
+        pa.RecordBatch.from_arrays(
+            [expected_a, expected_cast], ["a", "CAST(a AS Int32)"]
         )
-
-        # can execute, which implies that we can cast
-        ctx.sql(f"SELECT {select} FROM t").collect()
-
-    def _test_udf(self, udf, args, return_type, array, expected):
-        ctx = datafusion.ExecutionContext()
-
-        # write to disk
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), array)
-        ctx.register_parquet("t", path)
-
-        ctx.register_udf("udf", udf, args, return_type)
-
-        batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect()
-
-        result = batches[0].column(0)
-
-        self.assertEqual(expected, result)
-
-    def test_udf_identity(self):
-        self._test_udf(
+    ]
+    np.testing.assert_equal(expected[0].column(1), expected[0].column(1))
+
+
+def test_cast(ctx, tmp_path):
+    """
+    Verify that we can cast
+    """
+    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+    ctx.register_parquet("t", path)
+
+    valid_types = [
+        "smallint",
+        "int",
+        "bigint",
+        "float(32)",
+        "float(64)",
+        "float",
+    ]
+
+    select = ", ".join(
+        [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)]
+    )
+
+    # can execute, which implies that we can cast
+    ctx.sql(f"SELECT {select} FROM t").collect()
+
+
+@pytest.mark.parametrize(
+    ("fn", "input_types", "output_type", "input_values", "expected_values"),
+    [
+        (
             lambda x: x,
-            [pyarrow.float64()],
-            pyarrow.float64(),
-            pyarrow.array([-1.2, None, 1.2]),
-            pyarrow.array([-1.2, None, 1.2]),
-        )
-
-    def test_udf(self):
-        self._test_udf(
+            [pa.float64()],
+            pa.float64(),
+            [-1.2, None, 1.2],
+            [-1.2, None, 1.2],
+        ),
+        (
             lambda x: x.is_null(),
-            [pyarrow.float64()],
-            pyarrow.bool_(),
-            pyarrow.array([-1.2, None, 1.2]),
-            pyarrow.array([False, True, False]),
-        )
-
-
-class TestIO(unittest.TestCase):
-    def setUp(self):
-        # Create a temporary directory
-        self.test_dir = tempfile.mkdtemp()
-
-    def tearDown(self):
-        # Remove the directory after the test
-        shutil.rmtree(self.test_dir)
-
-    def _test_data(self, data):
-        ctx = datafusion.ExecutionContext()
-
-        # write to disk
-        path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data)
-        ctx.register_parquet("t", path)
-
-        batches = ctx.sql("SELECT a AS tt FROM t").collect()
-
-        result = batches[0].column(0)
-
-        numpy.testing.assert_equal(data, result)
-
-    def test_nans(self):
-        self._test_data(data_with_nans())
-
-    def test_utf8(self):
-        array = pyarrow.array(
-            ["a", "b", "c"], pyarrow.utf8(), numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    def test_large_utf8(self):
-        array = pyarrow.array(
-            ["a", "b", "c"], pyarrow.large_utf8(), numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    # Error from Arrow
-    @unittest.expectedFailure
-    def test_datetime_s(self):
-        self._test_data(data_datetime("s"))
-
-    # C data interface missing
-    @unittest.expectedFailure
-    def test_datetime_ms(self):
-        self._test_data(data_datetime("ms"))
-
-    # C data interface missing
-    @unittest.expectedFailure
-    def test_datetime_us(self):
-        self._test_data(data_datetime("us"))
-
-    # Not writtable to parquet
-    @unittest.expectedFailure
-    def test_datetime_ns(self):
-        self._test_data(data_datetime("ns"))
-
-    # Not writtable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_s(self):
-        self._test_data(data_timedelta("s"))
-
-    # Not writtable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_ms(self):
-        self._test_data(data_timedelta("ms"))
-
-    # Not writtable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_us(self):
-        self._test_data(data_timedelta("us"))
-
-    # Not writtable to parquet
-    @unittest.expectedFailure
-    def test_timedelta_ns(self):
-        self._test_data(data_timedelta("ns"))
-
-    def test_date32(self):
-        array = pyarrow.array(
-            [
-                datetime.date(2000, 1, 1),
-                datetime.date(1980, 1, 1),
-                datetime.date(2030, 1, 1),
-            ],
-            pyarrow.date32(),
-            numpy.array([False, True, False]),
-        )
-        self._test_data(array)
-
-    def test_binary_variable(self):
-        array = pyarrow.array(
-            [b"1", b"2", b"3"], pyarrow.binary(), numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    # C data interface missing
-    @unittest.expectedFailure
-    def test_binary_fixed(self):
-        array = pyarrow.array(
-            [b"1111", b"2222", b"3333"],
-            pyarrow.binary(4),
-            numpy.array([False, True, False]),
-        )
-        self._test_data(array)
-
-    def test_large_binary(self):
-        array = pyarrow.array(
-            [b"1111", b"2222", b"3333"],
-            pyarrow.large_binary(),
-            numpy.array([False, True, False]),
-        )
-        self._test_data(array)
-
-    def test_binary_other(self):
-        self._test_data(data_binary_other())
-
-    def test_bool(self):
-        array = pyarrow.array(
-            [False, True, True], None, numpy.array([False, True, False])
-        )
-        self._test_data(array)
-
-    def test_u32(self):
-        array = pyarrow.array([0, 1, 2], None, numpy.array([False, True, False]))
-        self._test_data(array)
+            [pa.float64()],
+            pa.bool_(),
+            [-1.2, None, 1.2],
+            [False, True, False],
+        ),
+    ],
+)
+def test_udf(
+    ctx, tmp_path, fn, input_types, output_type, input_values, expected_values
+):
+    # write to disk
+    path = helpers.write_parquet(
+        tmp_path / "a.parquet", pa.array(input_values)
+    )
+    ctx.register_parquet("t", path)
+    ctx.register_udf("udf", fn, input_types, output_type)
+
+    batches = ctx.sql("SELECT udf(a) AS tt FROM t").collect()
+    result = batches[0].column(0)
+
+    assert result == pa.array(expected_values)
+
+
+_null_mask = np.array([False, True, False])
+
+
+@pytest.mark.parametrize(
+    "arr",
+    [
+        pa.array(["a", "b", "c"], pa.utf8(), _null_mask),
+        pa.array(["a", "b", "c"], pa.large_utf8(), _null_mask),
+        pa.array([b"1", b"2", b"3"], pa.binary(), _null_mask),
+        pa.array([b"1111", b"2222", b"3333"], pa.large_binary(), _null_mask),
+        pa.array([False, True, True], None, _null_mask),
+        pa.array([0, 1, 2], None),
+        helpers.data_binary_other(),
+        helpers.data_date32(),
+        helpers.data_with_nans(),
+        # C data interface missing
+        pytest.param(
+            pa.array([b"1111", b"2222", b"3333"], pa.binary(4), _null_mask),
+            marks=pytest.mark.xfail,
+        ),
+        pytest.param(helpers.data_datetime("s"), marks=pytest.mark.xfail),
+        pytest.param(helpers.data_datetime("ms"), marks=pytest.mark.xfail),
+        pytest.param(helpers.data_datetime("us"), marks=pytest.mark.xfail),
+        pytest.param(helpers.data_datetime("ns"), marks=pytest.mark.xfail),
+        # Not writable to parquet
+        pytest.param(helpers.data_timedelta("s"), marks=pytest.mark.xfail),
+        pytest.param(helpers.data_timedelta("ms"), marks=pytest.mark.xfail),
+        pytest.param(helpers.data_timedelta("us"), marks=pytest.mark.xfail),
+        pytest.param(helpers.data_timedelta("ns"), marks=pytest.mark.xfail),
+    ],
+)
+def test_simple_select(ctx, tmp_path, arr):
+    path = helpers.write_parquet(tmp_path / "a.parquet", arr)
+    ctx.register_parquet("t", path)
+
+    batches = ctx.sql("SELECT a AS tt FROM t").collect()
+    result = batches[0].column(0)
+
+    np.testing.assert_equal(result, arr)
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index e1e4f933a9b4..b24c08dbc867 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -15,12 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
-import pyarrow
-import pyarrow.compute
-import datafusion
-
-f = datafusion.functions
+import pyarrow as pa
+import pyarrow.compute as pc
+import pytest
+from datafusion import ExecutionContext
+from datafusion import functions as f
 
 
 class Accumulator:
@@ -29,63 +28,54 @@ class Accumulator:
     """
 
     def __init__(self):
-        self._sum = pyarrow.scalar(0.0)
+        self._sum = pa.scalar(0.0)
 
-    def to_scalars(self) -> [pyarrow.Scalar]:
+    def to_scalars(self) -> [pa.Scalar]:
         return [self._sum]
 
-    def update(self, values: pyarrow.Array) -> None:
-        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(values).as_py()
-        )
+    def update(self, values: pa.Array) -> None:
+        # Not nice since pyarrow scalars can't be summed yet.
+        # This breaks on `None`
+        self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
 
-    def merge(self, states: pyarrow.Array) -> None:
-        # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(states).as_py()
-        )
+    def merge(self, states: pa.Array) -> None:
+        # Not nice since pyarrow scalars can't be summed yet.
+        # This breaks on `None`
+        self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py())
 
-    def evaluate(self) -> pyarrow.Scalar:
+    def evaluate(self) -> pa.Scalar:
         return self._sum
 
 
-class TestCase(unittest.TestCase):
-    def _prepare(self):
-        ctx = datafusion.ExecutionContext()
+@pytest.fixture
+def df():
+    ctx = ExecutionContext()
 
-        # create a RecordBatch and a new DataFrame from it
-        batch = pyarrow.RecordBatch.from_arrays(
-            [pyarrow.array([1, 2, 3]), pyarrow.array([4, 4, 6])],
-            names=["a", "b"],
-        )
-        return ctx.create_dataframe([[batch]])
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 4, 6])],
+        names=["a", "b"],
+    )
+    return ctx.create_dataframe([[batch]])
 
-    def test_aggregate(self):
-        df = self._prepare()
-        udaf = f.udaf(
-            Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
-        )
+def test_aggregate(df):
+    udaf = f.udaf(Accumulator, pa.float64(), pa.float64(), [pa.float64()])
 
-        df = df.aggregate([], [udaf(f.col("a"))])
+    df = df.aggregate([], [udaf(f.col("a"))])
 
-        # execute and collect the first (and only) batch
-        result = df.collect()[0]
+    # execute and collect the first (and only) batch
+    result = df.collect()[0]
 
-        self.assertEqual(result.column(0), pyarrow.array([1.0 + 2.0 + 3.0]))
+    assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
 
-    def test_group_by(self):
-        df = self._prepare()
-        udaf = f.udaf(
-            Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
-        )
+def test_group_by(df):
+    udaf = f.udaf(Accumulator, pa.float64(), pa.float64(), [pa.float64()])
 
-        df = df.aggregate([f.col("b")], [udaf(f.col("a"))])
+    df = df.aggregate([f.col("b")], [udaf(f.col("a"))])
 
-        # execute and collect the first (and only) batch
-        batches = df.collect()
-        arrays = [batch.column(1) for batch in batches]
-        joined = pyarrow.concat_arrays(arrays)
-        self.assertEqual(joined, pyarrow.array([1.0 + 2.0, 3.0]))
+    batches = df.collect()
+    arrays = [batch.column(1) for batch in batches]
+    joined = pa.concat_arrays(arrays)
+    assert joined == pa.array([1.0 + 2.0, 3.0])