Skip to content

Commit

Permalink
allow merge on expressions (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon authored Sep 9, 2024
1 parent ae493e7 commit 2444964
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 59 deletions.
19 changes: 10 additions & 9 deletions examples/multimodal/clip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
from torch.utils.data import DataLoader

from datachain import C, DataChain
from datachain.sql.functions import path

source = "gs://datachain-demo/50k-laion-files/000000/00000000*"


def create_dataset():
imgs = (
DataChain.from_storage(source, type="image")
.filter(C("file.path").glob("*.jpg"))
.map(stem=lambda file: file.get_file_stem(), params=["file"], output=str)
imgs = DataChain.from_storage(source, type="image").filter(
C("file.path").glob("*.jpg")
)
captions = (
DataChain.from_storage(source, type="text")
.filter(C("file.path").glob("*.txt"))
.map(stem=lambda file: file.get_file_stem(), params=["file"], output=str)
captions = DataChain.from_storage(source, type="text").filter(
C("file.path").glob("*.txt")
)
return imgs.merge(
captions,
on=path.file_stem(imgs.c("file.path")),
right_on=path.file_stem(captions.c("file.path")),
)
return imgs.merge(captions, on="stem")


if __name__ == "__main__":
Expand Down
21 changes: 10 additions & 11 deletions examples/multimodal/wds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from datachain import C, DataChain
from datachain import DataChain
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta
from datachain.sql.functions import path
Expand All @@ -25,21 +25,20 @@
DataChain.from_parquet(PARQUET_METADATA)
.settings(cache=True)
.merge(wds_images, on="uid", right_on="laion.json.uid", inner=True)
.mutate(stem=path.file_stem(C("source.file.path")))
)

res = (
wds_npz = (
DataChain.from_storage(NPZ_METADATA)
.settings(cache=True)
.gen(emd=process_laion_meta)
.mutate(stem=path.file_stem(C("emd.file.path")))
.merge(
wds_with_pq,
on=["stem", "emd.index"],
right_on=["stem", "source.index"],
inner=True,
)
.save("wds")
)


res = wds_npz.merge(
wds_with_pq,
on=[path.file_stem(wds_npz.c("emd.file.path")), "emd.index"],
right_on=[path.file_stem(wds_with_pq.c("source.file.path")), "source.index"],
inner=True,
).save("wds")

res.show(5)
107 changes: 74 additions & 33 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
PartitionByType,
detach,
)
from datachain.query.schema import Column, DatasetRow
from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
from datachain.sql.functions import path as pathfunc
from datachain.utils import inside_notebook

Expand Down Expand Up @@ -112,11 +112,31 @@ def __init__(self, name, msg): # noqa: D107
super().__init__(f"Dataset{name} from values error: {msg}")


def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
    """Render a merge key (column name or SQL expression) for error messages."""
    if isinstance(col, str):
        return col
    if isinstance(col, sqlalchemy.Column):
        # DB column names are flattened with DEFAULT_DELIMITER; show dotted form.
        return col.name.replace(DEFAULT_DELIMITER, ".")
    has_name = isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name")
    return f"{col.name} expression" if has_name else str(col)


class DatasetMergeError(DataChainParamsError): # noqa: D101
def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str): # noqa: D107
on_str = ", ".join(on) if isinstance(on, Sequence) else ""
def __init__( # noqa: D107
self,
on: Sequence[Union[str, sqlalchemy.ColumnElement]],
right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]],
msg: str,
):
def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
if not isinstance(on, Sequence):
return str(on) # type: ignore[unreachable]
return ", ".join([_get_merge_error_str(col) for col in on])

on_str = _get_str(on)
right_on_str = (
", right_on='" + ", ".join(right_on) + "'"
", right_on='" + _get_str(right_on) + "'"
if right_on and isinstance(right_on, Sequence)
else ""
)
Expand Down Expand Up @@ -252,13 +272,24 @@ def column(self, name: str) -> Column:
"""Returns Column instance with a type if name is found in current schema,
otherwise raises an exception.
"""
name_path = name.split(".")
if "." in name:
name_path = name.split(".")
elif DEFAULT_DELIMITER in name:
name_path = name.split(DEFAULT_DELIMITER)
else:
name_path = [name]
for path, type_, _, _ in self.signals_schema.get_flat_tree():
if path == name_path:
return Column(name, python_to_sql(type_))

raise ValueError(f"Column with name {name} not found in the schema")

def c(self, column: Union[str, Column]) -> Column:
"""Returns Column instance attached to the current chain."""
c = self.column(column) if isinstance(column, str) else self.column(column.name)
c.table = self.table
return c

def print_schema(self) -> None:
"""Print schema of the chain."""
self._effective_signals_schema.print_tree()
Expand Down Expand Up @@ -1140,8 +1171,17 @@ def remove_file_signals(self) -> "Self": # noqa: D102
def merge(
self,
right_ds: "DataChain",
on: Union[str, Sequence[str]],
right_on: Union[str, Sequence[str], None] = None,
on: Union[
str,
sqlalchemy.ColumnElement,
Sequence[Union[str, sqlalchemy.ColumnElement]],
],
right_on: Union[
str,
sqlalchemy.ColumnElement,
Sequence[Union[str, sqlalchemy.ColumnElement]],
None,
] = None,
inner=False,
rname="right_",
) -> "Self":
Expand All @@ -1166,7 +1206,7 @@ def merge(
if on is None:
raise DatasetMergeError(["None"], None, "'on' must be specified")

if isinstance(on, str):
if isinstance(on, (str, sqlalchemy.ColumnElement)):
on = [on]
elif not isinstance(on, Sequence):
raise DatasetMergeError(
Expand All @@ -1175,54 +1215,55 @@ def merge(
f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
)

signals_schema = self.signals_schema.clone_without_sys_signals()
on_columns: list[str] = signals_schema.resolve(*on).db_signals() # type: ignore[assignment]

right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
if right_on is not None:
if isinstance(right_on, str):
if isinstance(right_on, (str, sqlalchemy.ColumnElement)):
right_on = [right_on]
elif not isinstance(right_on, Sequence):
raise DatasetMergeError(
on,
right_on,
"'right_on' must be 'str' or 'Sequence' object"
f" but got type '{right_on}'",
f" but got type '{type(right_on)}'",
)

if len(right_on) != len(on):
raise DatasetMergeError(
on, right_on, "'on' and 'right_on' must have the same length'"
)

right_on_columns: list[str] = right_signals_schema.resolve(
*right_on
).db_signals() # type: ignore[assignment]

if len(right_on_columns) != len(on_columns):
on_str = ", ".join(right_on_columns)
right_on_str = ", ".join(right_on_columns)
raise DatasetMergeError(
on,
right_on,
"'on' and 'right_on' must have the same number of columns in db'."
f" on -> {on_str}, right_on -> {right_on_str}",
)
else:
right_on = on
right_on_columns = on_columns

if self == right_ds:
right_ds = right_ds.clone(new_table=True)

errors = []

def _resolve(
ds: DataChain,
col: Union[str, sqlalchemy.ColumnElement],
side: Union[str, None],
):
try:
return ds.c(col) if isinstance(col, (str, C)) else col
except ValueError:
if side:
errors.append(f"{_get_merge_error_str(col)} in {side}")

ops = [
self.c(left) == right_ds.c(right)
for left, right in zip(on_columns, right_on_columns)
_resolve(self, left, "left")
== _resolve(right_ds, right, "right" if right_on else None)
for left, right in zip(on, right_on or on)
]

if errors:
raise DatasetMergeError(
on, right_on, f"Could not resolve {', '.join(errors)}"
)

ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")

ds.feature_schema = None

signals_schema = self.signals_schema.clone_without_sys_signals()
right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
right_signals_schema, rname
)
Expand Down
8 changes: 6 additions & 2 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,8 +1168,12 @@ def attached(self) -> bool:
"""
return self.name is not None and self.version is not None

def c(self, name: Union[C, str]) -> "ColumnClause[Any]":
col = sqlalchemy.column(name) if isinstance(name, str) else name
def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
col: sqlalchemy.ColumnClause = (
sqlalchemy.column(column)
if isinstance(column, str)
else sqlalchemy.column(column.name, column.type)
)
col.table = self.table
return col

Expand Down
61 changes: 57 additions & 4 deletions tests/unit/lib/test_datachain_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import pandas as pd
import pytest
from pydantic import BaseModel
from sqlalchemy import func

from datachain.lib.dc import DataChain, DatasetMergeError
from datachain.lib.signal_schema import SignalResolvingError
from datachain.lib.dc import C, DataChain, DatasetMergeError
from datachain.sql.types import String
from tests.utils import skip_if_not_sqlite

Expand Down Expand Up @@ -196,7 +196,7 @@ def test_merge_multi_conditions(test_session):
id=delivery_ids, d_name=delivery_name, time=delivery_time, session=test_session
)

ch = ch1.merge(ch2, ("id", "name"), ("id", "d_name"))
ch = ch1.merge(ch2, ("id", "name"), ("id", C("d_name")))

res = list(ch.collect())

Expand All @@ -213,11 +213,23 @@ def test_merge_errors(test_session):
ch1 = DataChain.from_values(emp=employees, session=test_session)
ch2 = DataChain.from_values(team=team, session=test_session)

with pytest.raises(SignalResolvingError):
with pytest.raises(DatasetMergeError):
ch1.merge(ch2, "unknown")

with pytest.raises(DatasetMergeError):
ch1.merge(ch2, ["emp.person.name"], "unknown")

with pytest.raises(DatasetMergeError):
ch1.merge(ch2, ["emp.person.name"], ["unknown"])

with pytest.raises(DatasetMergeError):
ch1.merge(
ch2, ("emp.person.age", func.substr(["emp.person.name"], 2)), "unknown"
)

ch1.merge(ch2, ["emp.person.name"], ["team.sport"])
ch1.merge(ch2, ["emp.person.name"], ["team.sport"])

with pytest.raises(DatasetMergeError):
ch1.merge(ch2, ["emp.person.name"], ["team.player", "team.sport"])

Expand All @@ -240,3 +252,44 @@ def test_merge_with_itself(test_session):
count += 1

assert count == len(employees)


def test_merge_with_itself_column(test_session):
    """Self-merge on a Column key pairs every employee with itself, in order."""
    ch = DataChain.from_values(emp=employees, session=test_session)
    merged = ch.merge(ch, C("emp.id"))

    seen = 0
    for left, right in merged.collect():
        assert isinstance(left, Employee)
        assert isinstance(right, Employee)
        assert left == right == employees[seen]
        seen += 1

    assert seen == len(employees)


def test_merge_on_expression(test_session):
    """Merging on a SQL expression (trailing 4 chars of sport, i.e. "ball")
    behaves as a cross join over the team rows."""

    def _get_expr(dc):
        c = dc.c("team.sport")
        return func.substr(c, func.length(c) - 3)

    dc = DataChain.from_values(team=team, session=test_session)
    right_dc = dc.clone(new_table=True)

    # cross join on "ball" from sport
    merged = dc.merge(right_dc, on=_get_expr(dc), right_on=_get_expr(right_dc))

    expected_pairs = [(lhs, rhs) for lhs in team for rhs in team]

    idx = 0
    for left, right in merged.collect():
        assert isinstance(left, TeamMember)
        assert isinstance(right, TeamMember)
        expected_left, expected_right = expected_pairs[idx]
        assert left == expected_left
        assert right == expected_right
        idx += 1

    assert idx == len(team) * len(team)

0 comments on commit 2444964

Please sign in to comment.