Optimize UDF with parallel execution
dreadatour committed Dec 13, 2024
1 parent 415454a commit 5953afe
Showing 6 changed files with 183 additions and 93 deletions.
1 change: 0 additions & 1 deletion src/datachain/lib/udf.py
@@ -85,7 +85,6 @@ def run(
         udf_fields: "Sequence[str]",
         udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
-        is_generator: bool,
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
39 changes: 33 additions & 6 deletions src/datachain/query/batch.py
@@ -7,6 +7,7 @@
 
 from datachain.data_storage.schema import PARTITION_COLUMN_ID
 from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+from datachain.query.utils import get_query_column, get_query_id_column
 
 if TYPE_CHECKING:
     from sqlalchemy import Select
@@ -23,11 +24,14 @@ class RowsOutputBatch:
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""
 
+    is_batching: bool
+
     @abstractmethod
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutput, None, None]:
        """Apply the provided parameters to the UDF."""

@@ -38,11 +42,16 @@ class NoBatching(BatchingStrategy):
     batch UDF calls.
     """
 
+    is_batching = False
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[Sequence, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
         return execute(query)
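
The with_only_columns call is plain SQLAlchemy: it swaps the SELECT list while keeping FROM and WHERE intact, which is what lets a strategy fetch only sys__id for parallel dispatch. A self-contained sketch with a toy table:

    from sqlalchemy import Integer, String, column, select, table

    rows = table("rows", column("sys__id", Integer), column("name", String))

    q = select(rows.c.sys__id, rows.c.name).where(rows.c.name != "")
    ids_q = q.with_only_columns(rows.c.sys__id)

    print(ids_q)
    # SELECT rows.sys__id
    # FROM rows
    # WHERE rows.name != :name_1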


@@ -52,14 +61,20 @@ class Batch(BatchingStrategy):
     is passed a sequence of multiple parameter sets.
     """
 
+    is_batching = True
+
     def __init__(self, count: int):
         self.count = count
 
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
+
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count

Codecov / codecov/patch: added line src/datachain/query/batch.py#L76 was not covered by tests.
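
The page-size rounding deserves a worked example: each page fetched from the warehouse splits into whole UDF batches, so only the final page of the whole result set can produce a short batch (the constant below is illustrative, not the real SELECT_BATCH_SIZE):

    import math

    SELECT_BATCH_SIZE = 10_000  # illustrative value only
    count = 300                 # UDF batch size

    page_size = math.ceil(SELECT_BATCH_SIZE / count) * count
    print(page_size)  # 10200: the smallest multiple of 300 >= 10000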

@@ -84,19 +99,31 @@ class Partition(BatchingStrategy):
     Dataset rows need to be sorted by the grouping column.
     """
 
+    is_batching = True
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        id_col = get_query_id_column(query)
+        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+            raise RuntimeError("partition column not found in query")
+
+        if ids_only:
+            query = query.with_only_columns(id_col, partition_col)
+
         current_partition: Optional[int] = None
         batch: list[Sequence] = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
+        # query_fields = [column_name(col) for col in query.inner_columns]
         id_column_idx = query_fields.index("sys__id")
         partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)
 
         ordered_query = query.order_by(None).order_by(
-            PARTITION_COLUMN_ID,
+            partition_col,
             *query._order_by_clauses,
         )

@@ -108,7 +135,7 @@ def __call__(
                 if len(batch) > 0:
                     yield RowsOutputBatch(batch)
                     batch = []
-            batch.append(row)
+            batch.append([row[id_column_idx]] if ids_only else row)
 
         if len(batch) > 0:
             yield RowsOutputBatch(batch)

Codecov / codecov/patch: added lines src/datachain/query/batch.py#L112 and #L115 were not covered by tests.
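
The grouping loop relies on rows arriving sorted by partition id; a toy rerun of the same logic over (sys__id, partition_id) tuples shows the boundary flushes, including the trailing one:

    rows = [(1, 10), (2, 10), (3, 20), (4, 20), (5, 20)]  # sorted by partition

    batches: list[list[int]] = []
    batch: list[int] = []
    current = None
    for sys_id, part in rows:
        if current != part:      # partition boundary: flush the open batch
            current = part
            if batch:
                batches.append(batch)
                batch = []
        batch.append(sys_id)     # ids_only-style: keep just the id
    if batch:
        batches.append(batch)    # flush the final partition

    assert batches == [[1, 2], [3, 4, 5]]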
14 changes: 6 additions & 8 deletions src/datachain/query/dataset.py
@@ -44,7 +44,6 @@
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
 from datachain.sql.functions.random import rand
 from datachain.utils import (
@@ -66,7 +65,7 @@
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.lib.udf import UDFResult
+    from datachain.lib.udf import UDFAdapter, UDFResult
 
 P = ParamSpec("P")
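
Moving UDFAdapter into the TYPE_CHECKING block (and quoting the annotations below) keeps the dependency visible to type checkers while removing it from the runtime import graph, the standard cure for import cycles; the pattern in isolation:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated by type checkers only, never at runtime, so importing
        # this module no longer pulls in datachain.lib.udf
        from datachain.lib.udf import UDFAdapter

    def use(udf: "UDFAdapter") -> None:  # quoted annotation, resolved lazily
        ...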

@@ -302,7 +301,7 @@ def adjust_outputs(
     return row
 
 
-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be computed
     in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
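
The docstring's optimization in miniature: resolve the per-column conversion once, then reuse it for every row (the converter table here is hypothetical, standing in for convert_type):

    # hypothetical converters, computed once outside the row loop
    col_types = {"size": int, "score": float}
    rows = [{"size": "10", "score": "0.5"}, {"size": "7", "score": "1.5"}]

    converted = [
        {name: col_types[name](value) for name, value in row.items()}
        for row in rows
    ]
    assert converted == [{"size": 10, "score": 0.5}, {"size": 7, "score": 1.5}]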
@@ -323,7 +322,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFAdapter,
+    udf: "UDFAdapter",
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -367,7 +366,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFAdapter
+    udf: "UDFAdapter"
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -478,7 +477,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
             udf_fields,
             udf_inputs,
             self.catalog,
-            self.is_generator,
             self.cache,
             download_cb,
             processed_cb,
@@ -1487,7 +1485,7 @@ def chunk(self, index: int, total: int) -> "Self":
     @detach
     def add_signals(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1531,7 +1529,7 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self":
     @detach
     def generate(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
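
Given the two signatures above, a hedged usage sketch of the shared parallelism knobs (ds, my_mapper, and my_generator are placeholders):

    # fan row processing out over local worker processes
    ds = ds.add_signals(my_mapper, parallel=4)

    # or over distributed workers, with a minimum task size per worker
    ds = ds.generate(my_generator, workers=8, min_task_size=500)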
(Diffs for the remaining 3 changed files were not loaded.)