Skip to content

Commit

Permalink
Window functions (#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour authored Oct 20, 2024
1 parent 5353966 commit dfd7fb4
Show file tree
Hide file tree
Showing 8 changed files with 684 additions and 29 deletions.
25 changes: 20 additions & 5 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,9 @@ def group_by(
"""Group rows by specified set of signals and return new signals
with aggregated values.
The supported functions:
count(), sum(), avg(), min(), max(), any_value(), collect(), concat()
Example:
```py
chain = chain.group_by(
Expand Down Expand Up @@ -1071,13 +1074,22 @@ def mutate(self, **kwargs) -> "Self":
Filename: name(), parent(), file_stem(), file_ext()
Array: length(), sip_hash_64(), euclidean_distance(),
cosine_distance()
Window: row_number(), rank(), dense_rank(), first()
Example:
```py
dc.mutate(
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
area=Column("image.height") * Column("image.width"),
extension=file_ext(Column("file.name")),
dist=cosine_distance(embedding_text, embedding_image)
)
```
Window function example:
```py
window = func.window(partition_by="file.parent", order_by="file.size")
dc.mutate(
row_number=func.row_number().over(window),
)
```
Expand All @@ -1088,7 +1100,7 @@ def mutate(self, **kwargs) -> "Self":
Example:
```py
dc.mutate(
newkey=Column("oldkey")
newkey=Column("oldkey")
)
```
"""
Expand All @@ -1101,7 +1113,7 @@ def mutate(self, **kwargs) -> "Self":
"Use a different name for the new column.",
)
for col_name, expr in kwargs.items():
if not isinstance(expr, Column) and isinstance(expr.type, NullType):
if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
raise DataChainColumnError(
col_name, f"Cannot infer type with expression {expr}"
)
Expand All @@ -1113,6 +1125,9 @@ def mutate(self, **kwargs) -> "Self":
# renaming existing column
for signal in schema.db_signals(name=value.name, as_columns=True):
mutated[signal.name.replace(value.name, name, 1)] = signal # type: ignore[union-attr]
elif isinstance(value, Func):
# adding new signal
mutated[name] = value.get_column(schema)
else:
# adding new signal
mutated[name] = value
Expand Down
22 changes: 20 additions & 2 deletions src/datachain/lib/func/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
from .aggregate import any_value, avg, collect, concat, count, max, min, sum
from .func import Func
from .aggregate import (
any_value,
avg,
collect,
concat,
count,
dense_rank,
first,
max,
min,
rank,
row_number,
sum,
)
from .func import Func, window

__all__ = [
"Func",
Expand All @@ -8,7 +21,12 @@
"collect",
"concat",
"count",
"dense_rank",
"first",
"max",
"min",
"rank",
"row_number",
"sum",
"window",
]
Loading

0 comments on commit dfd7fb4

Please sign in to comment.