Skip to content

Commit

Permalink
Move 'join' SQL implementation to warehouse
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Sep 9, 2024
1 parent d36d06f commit a611327
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 19 deletions.
19 changes: 19 additions & 0 deletions src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from sqlalchemy.dialects.sqlite import Insert
from sqlalchemy.engine.base import Engine
from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql._typing import _FromClauseArgument, _OnClauseArgument
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.selectable import Join
from sqlalchemy.types import TypeEngine


Expand Down Expand Up @@ -779,6 +781,23 @@ def copy_table(
if progress_cb:
progress_cb(len(batch_ids))

def join(
    self,
    left: "_FromClauseArgument",
    right: "_FromClauseArgument",
    onclause: "_OnClauseArgument",
    inner: bool = True,
) -> "Join":
    """
    Build a SQLAlchemy JOIN construct between ``left`` and ``right``.

    ``onclause`` is passed through unchanged as the join's ON condition.
    When ``inner`` is true (the default) an inner join is produced;
    otherwise the join is built with ``isouter=True``.
    """
    # sqlalchemy.join() is phrased in terms of "outer", so invert the flag.
    is_outer = not inner
    return sqlalchemy.join(left, right, onclause, isouter=is_outer)

def create_pre_udf_table(self, query: "Select") -> "Table":
"""
Create a temporary table from a query for use in a UDF.
Expand Down
20 changes: 18 additions & 2 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@
from datachain.utils import sql_escape_like

if TYPE_CHECKING:
from sqlalchemy.sql._typing import _ColumnsClauseArgument
from sqlalchemy.sql._typing import (
_ColumnsClauseArgument,
_FromClauseArgument,
_OnClauseArgument,
)
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.selectable import Select
from sqlalchemy.sql.selectable import Join, Select
from sqlalchemy.types import TypeEngine

from datachain.data_storage import AbstractIDGenerator, schema
Expand Down Expand Up @@ -911,6 +915,18 @@ def copy_table(
Copy the results of a query into a table.
"""

@abstractmethod
def join(
self,
left: "_FromClauseArgument",
right: "_FromClauseArgument",
onclause: "_OnClauseArgument",
inner: bool = True,
) -> "Join":
"""
Join two tables together.

Abstract hook: this base declaration has no implementation, so each
concrete warehouse must provide its own.

``left`` and ``right`` are the two FROM-clause arguments to combine and
``onclause`` is the join condition. The result is returned as a
SQLAlchemy ``Join``.

``inner`` defaults to True.
NOTE(review): presumably ``inner=False`` requests an outer join; confirm
against the concrete warehouse implementations.
"""

@abstractmethod
def create_pre_udf_table(self, query: "Select") -> "Table":
"""
Expand Down
50 changes: 33 additions & 17 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from sqlalchemy.sql.expression import label
from sqlalchemy.sql.schema import TableClause
from sqlalchemy.sql.selectable import Select
from tqdm import tqdm

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE, get_catalog
Expand Down Expand Up @@ -903,12 +902,38 @@ def q(*columns):

@frozen
class SQLJoin(Step):
catalog: "Catalog"
query1: "DatasetQuery"
query2: "DatasetQuery"
predicates: Union[JoinPredicateType, tuple[JoinPredicateType, ...]]
inner: bool
rname: str

def get_query(self, query: "DatasetQuery", temp_tables: list[str]) -> sa.Subquery:
select_query = query.apply_steps().select()
temp_tables.extend(query.temp_table_names)

if not any(isinstance(step, (SQLJoin, SQLUnion)) for step in query.steps):
return select_query.subquery(query.table.name)

warehouse = self.catalog.warehouse

Check warning on line 919 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L919

Added line #L919 was not covered by tests

table_name = warehouse.temp_table_name()

Check warning on line 921 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L921

Added line #L921 was not covered by tests
columns = [
c if isinstance(c, Column) else Column(c.name, c.type)
for c in select_query.columns
]
temp_table = warehouse.create_dataset_rows_table(

Check warning on line 926 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L926

Added line #L926 was not covered by tests
table_name,
columns=columns,
if_not_exists=False,
)
temp_tables.append(temp_table.name)

Check warning on line 931 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L931

Added line #L931 was not covered by tests

warehouse.copy_table(temp_table, select_query)

Check warning on line 933 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L933

Added line #L933 was not covered by tests

return temp_table.select().subquery(query.table.name)

Check warning on line 935 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L935

Added line #L935 was not covered by tests

def validate_expression(self, exp: "ClauseElement", q1, q2):
"""
Checking if columns used in expression actually exist in left / right
Expand Down Expand Up @@ -941,10 +966,8 @@ def validate_expression(self, exp: "ClauseElement", q1, q2):
def apply(
self, query_generator: QueryGenerator, temp_tables: list[str]
) -> StepResult:
q1 = self.query1.apply_steps().select().subquery(self.query1.table.name)
temp_tables.extend(self.query1.temp_table_names)
q2 = self.query2.apply_steps().select().subquery(self.query2.table.name)
temp_tables.extend(self.query2.temp_table_names)
q1 = self.get_query(self.query1, temp_tables)
q2 = self.get_query(self.query2, temp_tables)

q1_columns = list(q1.c)
q1_column_names = {c.name for c in q1_columns}
Expand Down Expand Up @@ -983,15 +1006,13 @@ def apply(
self.validate_expression(join_expression, q1, q2)

def q(*columns):
join_query = sqlalchemy.join(
join_query = self.catalog.warehouse.join(
q1,
q2,
join_expression,
isouter=not self.inner,
inner=self.inner,
)

res = sqlalchemy.select(*columns).select_from(join_query)
subquery = res.subquery()
subquery = sqlalchemy.select(*columns).select_from(join_query).subquery()
return sqlalchemy.select(*subquery.c).select_from(subquery)

return step_result(
Expand Down Expand Up @@ -1515,7 +1536,7 @@ def join(
if isinstance(predicates, (str, ColumnClause, ColumnElement))
else tuple(predicates)
)
new_query.steps = [SQLJoin(left, right, predicates, inner, rname)]
new_query.steps = [SQLJoin(self.catalog, left, right, predicates, inner, rname)]
return new_query

@detach
Expand Down Expand Up @@ -1691,12 +1712,7 @@ def save(

dr = self.catalog.warehouse.dataset_rows(dataset)

with tqdm(desc="Saving", unit=" rows") as pbar:
self.catalog.warehouse.copy_table(
dr.get_table(),
query.select(),
progress_cb=pbar.update,
)
self.catalog.warehouse.copy_table(dr.get_table(), query.select())

self.catalog.metastore.update_dataset_status(
dataset, DatasetStatus.COMPLETE, version=version
Expand Down

0 comments on commit a611327

Please sign in to comment.