Skip to content

Commit

Permalink
feat: api letsql api methods (letsql#105)
Browse files Browse the repository at this point in the history
- replace the ibis api methods in the library
- add register_dataframe to support arbitrary expressions
over datafusion expressions
  • Loading branch information
mesejo authored Jun 25, 2024
1 parent 759c97d commit 37df954
Show file tree
Hide file tree
Showing 27 changed files with 2,009 additions and 251 deletions.
10 changes: 9 additions & 1 deletion python/letsql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from letsql import examples
from letsql.config import options
from letsql.expr import api
from letsql.expr.api import * # noqa: F403
from letsql.backends.let import Backend


Expand All @@ -12,7 +14,13 @@
except ModuleNotFoundError:
import importlib_metadata

__all__ = ["examples", "connect", "options"]
__all__ = [ # noqa: PLE0604
"api",
"examples",
"connect",
"options",
*api.__all__,
]


def connect() -> Backend:
Expand Down
72 changes: 16 additions & 56 deletions python/letsql/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
import letsql.internal as df
from letsql.backends.datafusion.compiler import DataFusionCompiler
from letsql.backends.datafusion.provider import IbisTableProvider
from letsql.internal import SessionConfig, SessionContext, TableProvider, Table
from letsql.internal import (
SessionConfig,
SessionContext,
TableProvider,
Table,
DataFrame,
)

if TYPE_CHECKING:
import pandas as pd
Expand All @@ -46,35 +52,11 @@ class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema,
def version(self):
return letsql.__version__

def do_connect(
self, config: Mapping[str, str | Path] | SessionContext | None = None
) -> None:
"""Create a Datafusion backend for use with Ibis.
Parameters
----------
config
Mapping of table names to files.
Examples
--------
>>> import letsql as ls
>>> config = {"t": "path/to/file.parquet", "s": "path/to/file.csv"}
>>> ls.connect(config)
"""
if isinstance(config, SessionContext):
(self.con, config) = (config, None)
else:
if config is not None and not isinstance(config, Mapping):
raise TypeError("Input to datafusion.connect must be a mapping")
if SessionConfig is not None:
df_config = SessionConfig(
{"datafusion.sql_parser.dialect": "PostgreSQL"}
).with_information_schema(True)
else:
df_config = None
self.con = SessionContext(config=df_config)
def do_connect(self, config: Mapping[str, str | Path] | None = None) -> None:
df_config = SessionConfig(
{"datafusion.sql_parser.dialect": "PostgreSQL"}
).with_information_schema(True)
self.con = SessionContext(config=df_config)

self._register_builtin_udfs()

Expand Down Expand Up @@ -185,9 +167,6 @@ def raw_sql(self, query: str | sge.Expression) -> Any:
----------
query
Raw SQL string
kwargs
Backend specific query arguments
"""
with contextlib.suppress(AttributeError):
query = query.sql(dialect=self.dialect, pretty=True)
Expand Down Expand Up @@ -305,29 +284,6 @@ def register(
kwargs
Datafusion-specific keyword arguments
Examples
--------
Register a csv:
>>> import ibis
>>> conn = ibis.datafusion.connect(config)
>>> conn.register("path/to/data.csv", "my_table")
>>> conn.table("my_table")
Register a PyArrow table:
>>> import pyarrow as pa
>>> tab = pa.table({"x": [1, 2, 3]})
>>> conn.register(tab, "my_table")
>>> conn.table("my_table")
Register a PyArrow dataset:
>>> import pyarrow.dataset as ds
>>> dataset = ds.dataset("path/to/table")
>>> conn.register(dataset, "my_table")
>>> conn.table("my_table")
"""
import pandas as pd

Expand Down Expand Up @@ -380,6 +336,10 @@ def register(
self.con.deregister_table(table_ident)
self.con.register_table(table_ident, source)
return self.table(table_name)
elif isinstance(source, DataFrame):
self.con.deregister_table(table_ident)
self.con.register_dataframe(table_ident, source)
return self.table(table_name)
else:
raise ValueError(f"Unknown `source` type {type(source)}")

Expand Down
53 changes: 34 additions & 19 deletions python/letsql/backends/let/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pyarrow_hotfix # noqa: F401
from ibis import BaseBackend
from ibis.expr import types as ir
from ibis.expr import operations as ops
from ibis.expr.schema import SchemaLike
from ibis.backends.datafusion import Backend as IbisDataFusionBackend
from sqlglot import exp, parse_one
Expand All @@ -24,7 +23,6 @@
replace_cache_table,
)
from letsql.expr.translate import sql_to_ibis
from letsql.internal import SessionContext


def _get_datafusion_table(con, table_name, database="public"):
Expand All @@ -33,6 +31,16 @@ def _get_datafusion_table(con, table_name, database="public"):
return public.table(table_name)


def _get_datafusion_dataframe(con, expr, **kwargs):
con._register_udfs(expr)
con._register_in_memory_tables(expr)

table_expr = expr.as_table()
raw_sql = con.compile(table_expr, **kwargs)

return con.con.sql(raw_sql)


class Backend(DataFusionBackend):
name = "let"

Expand All @@ -51,17 +59,13 @@ def register(
table_or_expr = source.op()
backend = source._find_backend(use_default=False)

if backend == self:
table_or_expr = self._sources.get_table_or_op(table_or_expr)
original_backend = self._sources.get_backend(table_or_expr)
is_a_datafusion_backed_table = isinstance(
original_backend, (DataFusionBackend, IbisDataFusionBackend)
) and isinstance(table_or_expr, ops.DatabaseTable)
if is_a_datafusion_backed_table:
source = _get_datafusion_table(
original_backend.con, table_or_expr.name
)
table_or_expr = None
if isinstance(backend, Backend):
if backend is self:
table_or_expr = self._sources.get_table_or_op(table_or_expr)
backend = self._sources.get_backend(table_or_expr)

if isinstance(backend, (DataFusionBackend, IbisDataFusionBackend)):
source = _get_datafusion_dataframe(backend, source)

registered_table = super().register(source, table_name=table_name, **kwargs)
self._sources[registered_table.op()] = table_or_expr or registered_table.op()
Expand Down Expand Up @@ -99,7 +103,7 @@ def read_delta(
return registered_table

def execute(self, expr: ir.Expr, **kwargs: Any):
not_multi_engine = self._get_source(expr) != self
not_multi_engine = self._get_source(expr) is not self
if (
not_multi_engine
): # this means is a single source that is not the letsql backend
Expand All @@ -112,14 +116,25 @@ def replace_table(node, _, **_kwargs):
expr = self._register_and_transform_cache_tables(expr)
backend = self._get_source(expr)

if backend == self:
if backend is self:
backend = super()

return backend.execute(expr, **kwargs)

def do_connect(
self, config: Mapping[str, str | Path] | SessionContext | None = None
) -> None:
def do_connect(self, config: Mapping[str, str | Path] | None = None) -> None:
"""Creates a connection.
Parameters
----------
config
Mapping of table names to files.
Examples
--------
>>> import letsql as ls
>>> con = ls.connect()
"""
super().do_connect(config=config)

def _get_source(self, expr: ir.Expr):
Expand Down Expand Up @@ -161,7 +176,7 @@ def fn(node, _, **kwargs):
uncached_to_expr = uncached.to_expr()
node = storage.set_default(uncached_to_expr, uncached)
table = node.to_expr()
if node.source == self:
if node.source is self:
table = _get_datafusion_table(self.con, node.name)
self.register(table, table_name=node.name)
return node
Expand Down
11 changes: 9 additions & 2 deletions python/letsql/backends/let/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def dirty(pg):
def remove_unexpected_tables(dirty):
for table in dirty.list_tables():
if table not in expected_tables:
dirty.drop_table(table)
dirty.drop_table(table, force=True)

for table in dirty.list_tables():
if table not in expected_tables:
dirty.drop_view(table, force=True)

if sorted(dirty.list_tables()) != sorted(expected_tables):
raise ValueError

Expand All @@ -72,7 +77,9 @@ def dirty_ls_con():
def ls_con(dirty_ls_con):
yield dirty_ls_con
for table_name in dirty_ls_con.list_tables():
dirty_ls_con.drop_table(table_name)
dirty_ls_con.drop_table(table_name, force=True)
for table_name in dirty_ls_con.list_tables():
dirty_ls_con.drop_view(table_name, force=True)


@pytest.fixture(scope="session")
Expand Down
60 changes: 60 additions & 0 deletions python/letsql/backends/let/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from pathlib import Path

import pytest


import letsql as ls


@pytest.fixture(scope="session")
def csv_dir():
root = Path(__file__).absolute().parents[5]
data_dir = root / "ci" / "ibis-testing-data" / "csv"
return data_dir


@pytest.fixture(scope="session")
def parquet_dir():
root = Path(__file__).absolute().parents[5]
data_dir = root / "ci" / "ibis-testing-data" / "parquet"
return data_dir


def test_register_read_csv(con, csv_dir):
api_batting = con.register(
ls.read_csv(csv_dir / "batting.csv"), table_name="api_batting"
)
result = api_batting.execute()

assert result is not None


def test_register_read_parquet(con, parquet_dir):
api_batting = con.register(
ls.read_parquet(parquet_dir / "batting.parquet"), table_name="api_batting"
)
result = api_batting.execute()

assert result is not None


def test_executed_on_original_backend(ls_con, parquet_dir, csv_dir, mocker):
con = ls.config._backend_init()
spy = mocker.spy(con, "execute")

table_name = "batting"
parquet_table = ls.read_parquet(parquet_dir / "batting.parquet")[
lambda t: t.yearID == 2015
].pipe(ls_con.register, f"parquet-{table_name}")

csv_table = ls.read_csv(csv_dir / "batting.csv")[lambda t: t.yearID == 2014].pipe(
ls_con.register, f"csv-{table_name}"
)

expr = parquet_table.join(
csv_table,
"playerID",
)

assert expr.execute() is not None
assert spy.call_count == 1
9 changes: 5 additions & 4 deletions python/letsql/backends/let/tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,14 +541,15 @@ def test_read_csv_compute_and_cache(con, csv_dir, tmp_path):
assert expr.execute() is not None


def test_multi_engine_cache(pg, ls_con, tmp_path):
db_con = ibis.duckdb.connect()
@pytest.mark.parametrize("other_con", [letsql.connect(), ibis.duckdb.connect()])
def test_multi_engine_cache(pg, ls_con, tmp_path, other_con):
other_con = ibis.duckdb.connect()

table_name = "batting"
pg_t = pg.table(table_name)[lambda t: t.yearID > 2014].pipe(
ls_con.register, f"pg-{table_name}"
)
db_t = db_con.register(pg.table(table_name).to_pyarrow(), f"{table_name}")[
db_t = other_con.register(pg.table(table_name).to_pyarrow(), f"{table_name}")[
lambda t: t.stint == 1
].pipe(ls_con.register, f"db-{table_name}")

Expand Down Expand Up @@ -625,7 +626,7 @@ def test_replace_table_matching_kwargs(pg, ls_con, tmp_path):


def test_cache_default_path_set(pg, ls_con, tmp_path):
letsql.options.cache_default_path = tmp_path
letsql.options.cache.default_path = tmp_path

storage = ParquetCacheStorage(
source=ls_con,
Expand Down
14 changes: 11 additions & 3 deletions python/letsql/backends/let/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from pytest import param

import letsql
from letsql.tests.util import (
assert_frame_equal,
)
Expand Down Expand Up @@ -341,6 +342,7 @@ def test_multiple_execution_letsql_register_table(con, csv_dir):
@pytest.mark.parametrize(
"other_con",
[
letsql.connect(),
ibis.datafusion.connect(),
ibis.duckdb.connect(),
ibis.postgres.connect(
Expand Down Expand Up @@ -497,7 +499,14 @@ def test_register_arbitrary_expression_multiple_tables(con, duckdb_con):
assert_frame_equal(result, expected, check_like=True)


def test_multiple_pipes(ls_con, pg):
@pytest.mark.parametrize(
"new_con",
[
letsql.connect(),
ibis.duckdb.connect(),
],
)
def test_multiple_pipes(ls_con, pg, new_con):
"""This test address the issue reported on bug #69
link: https://github.com/letsql/letsql/issues/69
Expand All @@ -506,12 +515,11 @@ def test_multiple_pipes(ls_con, pg):
In this test (and the rest) ls_con is a clean (no tables) letsql connection
"""

duckdb_con = ibis.duckdb.connect()
table_name = "batting"
pg_t = pg.table(table_name)[lambda t: t.yearID == 2015].pipe(
ls_con.register, f"pg-{table_name}"
)
db_t = duckdb_con.register(pg_t.to_pyarrow(), f"{table_name}")[
db_t = new_con.register(pg_t.to_pyarrow(), f"{table_name}")[
lambda t: t.yearID == 2014
].pipe(ls_con.register, f"db-{table_name}")

Expand Down
Loading

0 comments on commit 37df954

Please sign in to comment.