def write_table(table: pa.Table, destination: Union[str, Path], primary_key: str):
    """Write an Arrow table to ``destination`` through the lance writer.

    Thin Python-level entry point: the three arguments are forwarded,
    unchanged and in order, to ``WriteTable`` imported from ``lance.lib``.

    Parameters
    ----------
    table : pa.Table
        The Apache Arrow table to write.
    destination : str or `Path`
        Where the dataset is written.
    primary_key : str
        Name of the column that serves as the primary key.
    """
    # All of the work happens in the compiled extension.
    WriteTable(table, destination, primary_key)
def WriteTable(table: Table,
               sink: Union[str, Path],
               primary_key: str):
    """Write an Arrow table to ``sink`` via the C++ ``lance::arrow::WriteTable``.

    Parameters
    ----------
    table : pyarrow.Table
        The table to write.
    sink : str or Path
        Destination handed to pyarrow's ``get_writer`` to obtain the
        output stream.
    primary_key : str
        Name of the primary-key column; passed to C++ UTF-8 encoded.

    Raises
    ------
    TypeError
        If ``table`` is not a ``pyarrow.Table``.

    Notes
    -----
    The status returned by the C++ call is passed to pyarrow's
    ``check_status``. No write options are supplied (``nullopt``).
    """
    cdef shared_ptr[CTable] arrow_table = pyarrow_unwrap_table(table)
    cdef shared_ptr[COutputStream] out
    cdef string pk
    cdef optional[CFileWriteOptions] options = nullopt

    # pyarrow_unwrap_table yields an empty shared_ptr when the object is not
    # a Table; dereferencing that below would crash the interpreter, so
    # reject it here -- and before the sink is opened, so no file is touched.
    if arrow_table.get() == NULL:
        raise TypeError(
            f"table must be a pyarrow.Table, got {type(table).__name__}"
        )

    get_writer(sink, &out)
    pk = primary_key.encode("utf-8")

    # The C++ writer needs no Python objects, so release the GIL for the call.
    with nogil:
        check_status(CWriteTable(deref(arrow_table), out, pk, options))
def test_simple_round_trips(tmp_path: Path):
    """Write a small table with ``write_table`` and read it back with ``dataset``."""
    frame = pd.DataFrame({"label": [123, 456, 789], "values": [22, 33, 2.24]})
    table = pa.Table.from_pandas(frame)
    target = tmp_path / "test.lance"

    # "label" is passed as the primary-key column.
    write_table(table, target, "label")
    assert target.exists()

    # Reading the file back must give a table equal to the one written.
    actual = dataset(str(target)).to_table()
    assert table == actual