Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move 'join' SQL implementation to warehouse #409

Merged
merged 6 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from sqlalchemy.dialects.sqlite import Insert
from sqlalchemy.engine.base import Engine
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql._typing import _FromClauseArgument, _OnClauseArgument
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.selectable import Join
from sqlalchemy.types import TypeEngine

from datachain.lib.file import File
Expand Down Expand Up @@ -788,6 +790,23 @@ def copy_table(
if progress_cb:
progress_cb(len(batch_ids))

def join(
self,
left: "_FromClauseArgument",
right: "_FromClauseArgument",
onclause: "_OnClauseArgument",
inner: bool = True,
) -> "Join":
"""
Join two tables together.
"""
return sqlalchemy.join(
left,
right,
onclause,
isouter=not inner,
)

def create_pre_udf_table(self, query: "Select") -> "Table":
"""
Create a temporary table from a query for use in a UDF.
Expand Down
22 changes: 19 additions & 3 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
from datachain.utils import sql_escape_like

if TYPE_CHECKING:
from sqlalchemy.sql._typing import _ColumnsClauseArgument
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql._typing import (
_ColumnsClauseArgument,
_FromClauseArgument,
_OnClauseArgument,
)
from sqlalchemy.sql.selectable import Join, Select
from sqlalchemy.types import TypeEngine

from datachain.data_storage import AbstractIDGenerator, schema
Expand Down Expand Up @@ -894,6 +898,18 @@ def copy_table(
Copy the results of a query into a table.
"""

@abstractmethod
def join(
    self,
    left: "_FromClauseArgument",
    right: "_FromClauseArgument",
    onclause: "_OnClauseArgument",
    inner: bool = True,
) -> "Join":
    """
    Join two tables together.

    Implementations return a dialect-appropriate SQLAlchemy ``Join``
    combining ``left`` and ``right`` on ``onclause``: an inner join
    when ``inner`` is True, otherwise a left outer join.
    """

@abstractmethod
def create_pre_udf_table(self, query: "Select") -> "Table":
"""
Expand Down Expand Up @@ -922,7 +938,7 @@ def cleanup_tables(self, names: Iterable[str]) -> None:
are cleaned up as soon as they are no longer needed.
"""
with tqdm(desc="Cleanup", unit=" tables") as pbar:
for name in names:
for name in set(names):
self.db.drop_table(Table(name, self.db.metadata), if_exists=True)
pbar.update(1)

Expand Down
57 changes: 38 additions & 19 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from sqlalchemy.sql.expression import label
from sqlalchemy.sql.schema import TableClause
from sqlalchemy.sql.selectable import Select
from tqdm import tqdm

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
Expand Down Expand Up @@ -899,12 +898,36 @@ def q(*columns):

@frozen
class SQLJoin(Step):
catalog: "Catalog"
query1: "DatasetQuery"
query2: "DatasetQuery"
predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
inner: bool
rname: str

def get_query(self, dq: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
    """Materialize one side of the join as a subquery.

    Queries without nested join/union steps are returned directly as a
    subquery. Queries that do contain such steps are first copied into a
    temporary table (its name is appended to ``temp_tables`` for later
    cleanup) so the nested set operation is evaluated only once.
    """
    select_q = dq.apply_steps().select()
    temp_tables.extend(dq.temp_table_names)

    has_set_step = any(isinstance(s, (SQLJoin, SQLUnion)) for s in dq.steps)
    if not has_set_step:
        return select_q.subquery(dq.table.name)

    wh = self.catalog.warehouse

    cols = []
    for c in select_q.subquery().columns:
        cols.append(c if isinstance(c, Column) else Column(c.name, c.type))

    staging = wh.create_dataset_rows_table(
        wh.temp_table_name(),
        columns=cols,
    )
    temp_tables.append(staging.name)
    wh.copy_table(staging, select_q)

    return staging.select().subquery(dq.table.name)

def validate_expression(self, exp: "ClauseElement", q1, q2):
"""
Checking if columns used in expression actually exist in left / right
Expand Down Expand Up @@ -937,10 +960,8 @@ def validate_expression(self, exp: "ClauseElement", q1, q2):
def apply(
self, query_generator: QueryGenerator, temp_tables: list[str]
) -> StepResult:
q1 = self.query1.apply_steps().select().subquery(self.query1.table.name)
temp_tables.extend(self.query1.temp_table_names)
q2 = self.query2.apply_steps().select().subquery(self.query2.table.name)
temp_tables.extend(self.query2.temp_table_names)
q1 = self.get_query(self.query1, temp_tables)
q2 = self.get_query(self.query2, temp_tables)

q1_columns = list(q1.c)
q1_column_names = {c.name for c in q1_columns}
Expand All @@ -951,7 +972,12 @@ def apply(
continue

if c.name in q1_column_names:
c = c.label(self.rname.format(name=c.name))
new_name = self.rname.format(name=c.name)
new_name_idx = 0
while new_name in q1_column_names:
new_name_idx += 1
new_name = self.rname.format(name=f"{c.name}_{new_name_idx}")
c = c.label(new_name)
q2_columns.append(c)

res_columns = q1_columns + q2_columns
Expand Down Expand Up @@ -979,16 +1005,14 @@ def apply(
self.validate_expression(join_expression, q1, q2)

def q(*columns):
    """Build the SELECT over the joined pair of subqueries.

    Delegates join construction to the warehouse so each backend can
    produce a dialect-appropriate JOIN clause.
    """
    join_query = self.catalog.warehouse.join(
        q1,
        q2,
        join_expression,
        inner=self.inner,
    )
    return sqlalchemy.select(*columns).select_from(join_query)

return step_result(
q,
Expand Down Expand Up @@ -1511,7 +1535,7 @@ def join(
if isinstance(predicates, (str, ColumnClause, ColumnElement))
else tuple(predicates)
)
new_query.steps = [SQLJoin(left, right, predicates, inner, rname)]
new_query.steps = [SQLJoin(self.catalog, left, right, predicates, inner, rname)]
return new_query

@detach
Expand Down Expand Up @@ -1687,12 +1711,7 @@ def save(

dr = self.catalog.warehouse.dataset_rows(dataset)

with tqdm(desc="Saving", unit=" rows") as pbar:
self.catalog.warehouse.copy_table(
dr.get_table(),
query.select(),
progress_cb=pbar.update,
)
self.catalog.warehouse.copy_table(dr.get_table(), query.select())

self.catalog.metastore.update_dataset_status(
dataset, DatasetStatus.COMPLETE, version=version
Expand Down
117 changes: 117 additions & 0 deletions tests/func/test_dataset_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,123 @@ def test_union(cloud_test_catalog):
assert count == 6


@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
@pytest.mark.parametrize("inner", [True, False])
def test_union_join(cloud_test_catalog, inner):
    """A union result can be joined (inner or outer) with another query."""
    catalog = cloud_test_catalog.catalog
    catalog.index([str(cloud_test_catalog.src_uri)])

    src = cloud_test_catalog.src_uri
    catalog.create_dataset_from_sources("dogs", [f"{src}/dogs/*"], recursive=True)
    catalog.create_dataset_from_sources("cats", [f"{src}/cats/*"], recursive=True)

    dogs = DatasetQuery(name="dogs", version=1, catalog=catalog)
    cats = DatasetQuery(name="cats", version=1, catalog=catalog)

    default = Int.default_value(catalog.warehouse.db.dialect)

    @udf((), {"sig1": Int})
    def signals1():
        return (1,)

    @udf((), {"sig2": Int})
    def signals2():
        return (2,)

    dogs_sig1 = dogs.add_signals(signals1)
    dogs_sig2 = dogs.add_signals(signals2)
    cats_sig1 = cats.add_signals(signals1)

    joined = (dogs_sig1 | cats_sig1).join(dogs_sig2, C.path, inner=inner)
    signals = list(joined.select("path", "sig1", "sig2").order_by("path"))

    dog_rows = [
        ("dogs/dog1", 1, 2),
        ("dogs/dog2", 1, 2),
        ("dogs/dog3", 1, 2),
        ("dogs/others/dog4", 1, 2),
    ]
    # Cats only appear on the outer join; sig2 falls back to the default.
    cat_rows = [
        ("cats/cat1", 1, default),
        ("cats/cat2", 1, default),
    ]
    expected = dog_rows if inner else cat_rows + dog_rows
    assert signals == expected


@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
@pytest.mark.parametrize("inner1", [True, False])
@pytest.mark.parametrize("inner2", [True, False])
@pytest.mark.parametrize("inner3", [True, False])
def test_multiple_join(cloud_test_catalog, inner1, inner2, inner3):
    """Chained joins of union-derived queries work for every inner/outer mix."""
    catalog = cloud_test_catalog.catalog
    catalog.index([str(cloud_test_catalog.src_uri)])

    src = cloud_test_catalog.src_uri
    catalog.create_dataset_from_sources("dogs", [f"{src}/dogs/*"], recursive=True)
    catalog.create_dataset_from_sources("cats", [f"{src}/cats/*"], recursive=True)

    dogs = DatasetQuery(name="dogs", version=1, catalog=catalog)
    cats = DatasetQuery(name="cats", version=1, catalog=catalog)

    default = Int.default_value(catalog.warehouse.db.dialect)

    @udf((), {"sig1": Int})
    def signals1():
        return (1,)

    @udf((), {"sig2": Int})
    def signals2():
        return (2,)

    dogs_and_cats = dogs | cats
    dogs_sig = dogs.add_signals(signals1)
    cats_sig = cats.add_signals(signals2)
    left = dogs_and_cats.join(dogs_sig, C.path, inner=inner1)
    right = dogs_and_cats.join(cats_sig, C.path, inner=inner2)
    joined = left.join(right, C.path, inner=inner3)

    joined_signals = list(joined.select("path", "sig1", "sig2").order_by("path"))

    dog_rows = [
        ("dogs/dog1", 1, default),
        ("dogs/dog2", 1, default),
        ("dogs/dog3", 1, default),
        ("dogs/others/dog4", 1, default),
    ]
    cat_rows = [
        ("cats/cat1", default, 2),
        ("cats/cat2", default, 2),
    ]
    if inner1 and inner2 and inner3:
        expected = []
    elif inner1:
        expected = dog_rows
    elif inner2 and inner3:
        expected = cat_rows
    else:
        expected = cat_rows + dog_rows
    assert joined_signals == expected


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
Expand Down