diff --git a/src/datachain/lib/udf.py b/src/datachain/lib/udf.py
index d708c0330..c59442d6b 100644
--- a/src/datachain/lib/udf.py
+++ b/src/datachain/lib/udf.py
@@ -85,7 +85,6 @@ def run(
         udf_fields: "Sequence[str]",
         udf_inputs: "Iterable[RowsOutput]",
         catalog: "Catalog",
-        is_generator: bool,
         cache: bool,
         download_cb: Callback = DEFAULT_CALLBACK,
         processed_cb: Callback = DEFAULT_CALLBACK,
diff --git a/src/datachain/query/batch.py b/src/datachain/query/batch.py
index 8f24ec895..6be29b1fe 100644
--- a/src/datachain/query/batch.py
+++ b/src/datachain/query/batch.py
@@ -7,6 +7,7 @@
 
 from datachain.data_storage.schema import PARTITION_COLUMN_ID
 from datachain.data_storage.warehouse import SELECT_BATCH_SIZE
+from datachain.query.utils import get_query_column, get_query_id_column
 
 if TYPE_CHECKING:
     from sqlalchemy import Select
@@ -23,11 +24,14 @@ class RowsOutputBatch:
 
 class BatchingStrategy(ABC):
     """BatchingStrategy provides means of batching UDF executions."""
 
+    is_batching: bool
+
     @abstractmethod
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutput, None, None]:
         """Apply the provided parameters to the UDF."""
@@ -38,11 +42,16 @@ class NoBatching(BatchingStrategy):
     batch UDF calls.
     """
 
+    is_batching = False
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[Sequence, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
         return execute(query)
@@ -52,14 +61,20 @@ class Batch(BatchingStrategy):
     is passed a sequence of multiple parameter sets.
     """
 
+    is_batching = True
+
     def __init__(self, count: int):
         self.count = count
 
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        if ids_only:
+            query = query.with_only_columns(get_query_id_column(query))
+
         # choose page size that is a multiple of the batch size
         page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count
@@ -84,19 +99,30 @@ class Partition(BatchingStrategy):
     Dataset rows need to be sorted by the grouping column.
     """
 
+    is_batching = True
+
     def __call__(
         self,
-        execute: Callable[..., Generator[Sequence, None, None]],
+        execute: Callable,
         query: "Select",
+        ids_only: bool = False,
     ) -> Generator[RowsOutputBatch, None, None]:
+        id_col = get_query_id_column(query)
+        if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None:
+            raise RuntimeError("partition column not found in query")
+
+        if ids_only:
+            query = query.with_only_columns(id_col, partition_col)
+
         current_partition: Optional[int] = None
         batch: list[Sequence] = []
 
         query_fields = [str(c.name) for c in query.selected_columns]
+        id_column_idx = query_fields.index("sys__id")
         partition_column_idx = query_fields.index(PARTITION_COLUMN_ID)
 
         ordered_query = query.order_by(None).order_by(
-            PARTITION_COLUMN_ID,
+            partition_col,
             *query._order_by_clauses,
         )
@@ -108,7 +134,7 @@ def __call__(
                 if len(batch) > 0:
                     yield RowsOutputBatch(batch)
                     batch = []
-                batch.append(row)
+                batch.append([row[id_column_idx]] if ids_only else row)
 
         if len(batch) > 0:
             yield RowsOutputBatch(batch)
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index 46058ba83..b4294ac7d 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -44,7 +44,6 @@
 from datachain.dataset import DatasetStatus, RowDict
 from datachain.error import DatasetNotFoundError, QueryScriptCancelError
 from datachain.func.base import Function
-from datachain.lib.udf import UDFAdapter
 from datachain.progress import CombinedDownloadCallback
 from datachain.sql.functions.random import rand
 from datachain.utils import (
@@ -66,7 +65,7 @@
     from datachain.catalog import Catalog
     from datachain.data_storage import AbstractWarehouse
     from datachain.dataset import DatasetRecord
-    from datachain.lib.udf import UDFResult
+    from datachain.lib.udf import UDFAdapter, UDFResult
 
 
 P = ParamSpec("P")
@@ -302,7 +301,7 @@ def adjust_outputs(
     return row
 
 
-def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]:
+def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]:
     """Optimization: Precompute UDF column types so these don't have to be
     computed in the convert_type function for each row in a loop."""
     dialect = warehouse.db.dialect
@@ -323,7 +322,7 @@ def process_udf_outputs(
     warehouse: "AbstractWarehouse",
     udf_table: "Table",
     udf_results: Iterator[Iterable["UDFResult"]],
-    udf: UDFAdapter,
+    udf: "UDFAdapter",
     batch_size: int = INSERT_BATCH_SIZE,
     cb: Callback = DEFAULT_CALLBACK,
 ) -> None:
@@ -367,7 +366,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:
 
 @frozen
 class UDFStep(Step, ABC):
-    udf: UDFAdapter
+    udf: "UDFAdapter"
     catalog: "Catalog"
     partition_by: Optional[PartitionByType] = None
     parallel: Optional[int] = None
@@ -478,7 +477,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
                 udf_fields,
                 udf_inputs,
                 self.catalog,
-                self.is_generator,
                 self.cache,
                 download_cb,
                 processed_cb,
@@ -1487,7 +1485,7 @@ def chunk(self, index: int, total: int) -> "Self":
     @detach
     def add_signals(
         self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
@@ -1531,7 +1529,7 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self":
     @detach
     def generate(
        self,
-        udf: UDFAdapter,
+        udf: "UDFAdapter",
         parallel: Optional[int] = None,
         workers: Union[bool, int] = False,
         min_task_size: Optional[int] = None,
diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py
index 5392cf491..2d85fe551 100644
--- a/src/datachain/query/dispatch.py
+++ b/src/datachain/query/dispatch.py
@@ -1,9 +1,10 @@
 import contextlib
-from collections.abc import Iterator, Sequence
+from collections.abc import Sequence
 from itertools import chain
 from multiprocessing import cpu_count
 from sys import stdin
-from typing import Optional
+from threading import Timer
+from typing import TYPE_CHECKING, Optional
 
 import attrs
 import multiprocess
@@ -13,22 +14,23 @@
 from datachain.catalog import Catalog
 from datachain.catalog.loader import get_distributed_class
-from datachain.lib.udf import UDFAdapter, UDFResult
+from datachain.query.batch import RowsOutputBatch
 from datachain.query.dataset import (
     get_download_callback,
     get_generated_callback,
     get_processed_callback,
     process_udf_outputs,
 )
-from datachain.query.queue import (
-    get_from_queue,
-    marshal,
-    msgpack_pack,
-    msgpack_unpack,
-    put_into_queue,
-    unmarshal,
-)
-from datachain.utils import batched_it
+from datachain.query.queue import get_from_queue, put_into_queue
+from datachain.query.utils import get_query_id_column
+from datachain.utils import batched, flatten
+
+if TYPE_CHECKING:
+    from sqlalchemy import Select, Table
+
+    from datachain.data_storage import AbstractMetastore, AbstractWarehouse
+    from datachain.lib.udf import UDFAdapter
+    from datachain.query.batch import BatchingStrategy
 
 DEFAULT_BATCH_SIZE = 10000
 STOP_SIGNAL = "STOP"
@@ -54,12 +56,9 @@ def udf_entrypoint() -> int:
     # Load UDF info from stdin
     udf_info = load(stdin.buffer)
 
-    (
-        warehouse_class,
-        warehouse_args,
-        warehouse_kwargs,
-    ) = udf_info["warehouse_clone_params"]
-    warehouse = warehouse_class(*warehouse_args, **warehouse_kwargs)
+    query: Select = udf_info["query"]
+    table: Table = udf_info["table"]
+    batching: BatchingStrategy = udf_info["batching"]
 
     # Parallel processing (faster for more CPU-heavy UDFs)
     dispatch = UDFDispatcher(
@@ -67,41 +66,39 @@ def udf_entrypoint() -> int:
         udf_info["udf_data"],
         udf_info["catalog_init"],
         udf_info["metastore_clone_params"],
         udf_info["warehouse_clone_params"],
+        query=query,
+        table=table,
         udf_fields=udf_info["udf_fields"],
         cache=udf_info["cache"],
         is_generator=udf_info.get("is_generator", False),
+        is_batching=batching.is_batching,
     )
 
-    query = udf_info["query"]
-    batching = udf_info["batching"]
-    table = udf_info["table"]
     n_workers = udf_info["processes"]
-    udf = loads(udf_info["udf_data"])
     if n_workers is True:
-        # Use default number of CPUs (cores)
-        n_workers = None
+        n_workers = None  # Use default number of CPUs (cores)
+
+    wh_cls, wh_args, wh_kwargs = udf_info["warehouse_clone_params"]
+    warehouse: AbstractWarehouse = wh_cls(*wh_args, **wh_kwargs)
 
     with contextlib.closing(
-        batching(warehouse.dataset_select_paginated, query)
+        batching(warehouse.db.execute, query, ids_only=True)
     ) as udf_inputs:
         download_cb = get_download_callback()
         processed_cb = get_processed_callback()
         generated_cb = get_generated_callback(dispatch.is_generator)
 
         try:
-            udf_results = dispatch.run_udf_parallel(
-                marshal(udf_inputs),
+            dispatch.run_udf_parallel(
+                udf_inputs,
                 n_workers=n_workers,
                 processed_cb=processed_cb,
                 download_cb=download_cb,
             )
-            process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb)
         finally:
             download_cb.close()
             processed_cb.close()
             generated_cb.close()
 
-    warehouse.insert_rows_done(table)
-
     return 0
@@ -120,26 +117,24 @@ def __init__(
         self,
         udf_data,
         catalog_init_params,
         metastore_clone_params,
         warehouse_clone_params,
+        query: "Select",
+        table: "Table",
         udf_fields: "Sequence[str]",
         cache: bool,
         is_generator: bool = False,
+        is_batching: bool = False,
         buffer_size: int = DEFAULT_BATCH_SIZE,
     ):
         self.udf_data = udf_data
         self.catalog_init_params = catalog_init_params
-        (
-            self.metastore_class,
-            self.metastore_args,
-            self.metastore_kwargs,
-        ) = metastore_clone_params
-        (
-            self.warehouse_class,
-            self.warehouse_args,
-            self.warehouse_kwargs,
-        ) = warehouse_clone_params
+        self.metastore_clone_params = metastore_clone_params
+        self.warehouse_clone_params = warehouse_clone_params
+        self.query = query
+        self.table = table
         self.udf_fields = udf_fields
         self.cache = cache
         self.is_generator = is_generator
+        self.is_batching = is_batching
         self.buffer_size = buffer_size
         self.catalog = None
         self.task_queue = None
@@ -148,12 +143,10 @@ def __init__(
     def _create_worker(self) -> "UDFWorker":
         if not self.catalog:
-            metastore = self.metastore_class(
-                *self.metastore_args, **self.metastore_kwargs
-            )
-            warehouse = self.warehouse_class(
-                *self.warehouse_args, **self.warehouse_kwargs
-            )
+            ms_cls, ms_args, ms_kwargs = self.metastore_clone_params
+            metastore: AbstractMetastore = ms_cls(*ms_args, **ms_kwargs)
+            ws_cls, ws_args, ws_kwargs = self.warehouse_clone_params
+            warehouse: AbstractWarehouse = ws_cls(*ws_args, **ws_kwargs)
             self.catalog = Catalog(metastore, warehouse, **self.catalog_init_params)
         self.udf = loads(self.udf_data)
         return UDFWorker(
             self.catalog,
@@ -161,7 +154,9 @@ def _create_worker(self) -> "UDFWorker":
             self.udf,
             self.task_queue,
             self.done_queue,
-            self.is_generator,
+            self.query,
+            self.table,
+            self.is_batching,
             self.cache,
             self.udf_fields,
         )
@@ -194,7 +189,7 @@ def run_udf_parallel(  # noqa: C901, PLR0912
         input_queue=None,
         processed_cb: Callback = DEFAULT_CALLBACK,
         download_cb: Callback = DEFAULT_CALLBACK,
-    ) -> Iterator[Sequence[UDFResult]]:
+    ) -> None:
         n_workers = get_n_workers_from_arg(n_workers)
 
         if self.buffer_size < n_workers:
@@ -224,6 +219,9 @@
         input_finished = False
 
         if not streaming_mode:
+            if not self.is_batching:
+                input_rows = batched(flatten(input_rows), DEFAULT_BATCH_SIZE)
+
             # Stop all workers after the input rows have finished processing
             input_data = chain(input_rows, [STOP_SIGNAL] * n_workers)
 
@@ -238,19 +236,17 @@
             # Process all tasks
             while n_workers > 0:
                 result = get_from_queue(self.done_queue)
+
+                if downloaded := result.get("downloaded"):
+                    download_cb.relative_update(downloaded)
+                if processed := result.get("processed"):
+                    processed_cb.relative_update(processed)
+
                 status = result["status"]
-                if status == NOTIFY_STATUS:
-                    if downloaded := result.get("downloaded"):
-                        download_cb.relative_update(downloaded)
-                    if processed := result.get("processed"):
-                        processed_cb.relative_update(processed)
+                if status in (OK_STATUS, NOTIFY_STATUS):
+                    pass  # progress was already reported above
                 elif status == FINISHED_STATUS:
-                    # Worker finished
-                    n_workers -= 1
-                elif status == OK_STATUS:
-                    if processed := result.get("processed"):
-                        processed_cb.relative_update(processed)
-                    yield msgpack_unpack(result["result"])
+                    n_workers -= 1  # Worker finished
                 else:  # Failed / error
                     n_workers -= 1
                     if exc := result.get("exception"):
@@ -311,11 +307,13 @@ def relative_update(self, inc: int = 1) -> None:
 
 @attrs.define
 class UDFWorker:
-    catalog: Catalog
-    udf: UDFAdapter
+    catalog: "Catalog"
+    udf: "UDFAdapter"
     task_queue: "multiprocess.Queue"
     done_queue: "multiprocess.Queue"
-    is_generator: bool
+    query: "Select"
+    table: "Table"
+    is_batching: bool
     cache: bool
     udf_fields: Sequence[str]
     cb: Callback = attrs.field()
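The hunks below complete the worker side of that design: instead of receiving marshaled row data, a worker takes a batch of ids off the task queue and re-selects the full rows itself. Roughly, with the same hypothetical `rows` table as in the earlier sketch:

```python
# Hypothetical sketch of the worker-side re-select (toy table, as above).
import sqlalchemy as sa

metadata = sa.MetaData()
rows = sa.Table(
    "rows",
    metadata,
    sa.Column("sys__id", sa.Integer, primary_key=True),
    sa.Column("path", sa.String),
)
query = sa.select(rows.c.sys__id, rows.c.path)
id_col = query.selected_columns["sys__id"]

ids = [3, 7, 42]  # one batch of ids taken from the task queue
worker_query = query.where(id_col.in_(ids))
print(worker_query)
# roughly: SELECT rows.sys__id, rows.path FROM rows WHERE rows.sys__id IN (...)
```

This mirrors `get_inputs()` below, which runs `self.query.where(col_id.in_(ids))` against the worker's own warehouse clone.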
@@ -325,31 +323,57 @@ def _default_callback(self) -> WorkerCallback:
         return WorkerCallback(self.done_queue)
 
     def run(self) -> None:
+        warehouse = self.catalog.warehouse.clone()
         processed_cb = ProcessedCallback()
+
         udf_results = self.udf.run(
             self.udf_fields,
-            unmarshal(self.get_inputs()),
+            self.get_inputs(),
             self.catalog,
-            self.is_generator,
             self.cache,
             download_cb=self.cb,
             processed_cb=processed_cb,
         )
-        for udf_output in udf_results:
-            for batch in batched_it(udf_output, DEFAULT_BATCH_SIZE):
-                put_into_queue(
-                    self.done_queue,
-                    {
-                        "status": OK_STATUS,
-                        "result": msgpack_pack(list(batch)),
-                    },
-                )
+        process_udf_outputs(
+            warehouse,
+            self.table,
+            self.notify_and_process(udf_results, processed_cb),
+            self.udf,
+            cb=processed_cb,
+        )
+        warehouse.insert_rows_done(self.table)
+
+        put_into_queue(
+            self.done_queue,
+            {"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
+        )
+
+    def notify_and_process(self, udf_results, processed_cb):
+        for row in udf_results:
             put_into_queue(
                 self.done_queue,
-                {"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
+                {"status": OK_STATUS, "processed": processed_cb.processed_rows},
             )
-        put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
+            yield row
 
     def get_inputs(self):
-        while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
-            yield batch
+        warehouse = self.catalog.warehouse.clone()
+        col_id = get_query_id_column(self.query)
+
+        if self.is_batching:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                ids = [row[0] for row in batch.rows]
+                rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
+                yield RowsOutputBatch(list(rows))
+        else:
+            while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
+                rows = warehouse.dataset_rows_select(
+                    self.query.where(col_id.in_(batch))
+                )
+                yield from rows
+
+
+class RepeatTimer(Timer):
+    def run(self):
+        while not self.finished.wait(self.interval):
+            self.function(*self.args, **self.kwargs)
diff --git a/src/datachain/query/utils.py b/src/datachain/query/utils.py
new file mode 100644
index 000000000..0d92226b1
--- /dev/null
+++ b/src/datachain/query/utils.py
@@ -0,0 +1,42 @@
+from typing import TYPE_CHECKING, Optional, Union
+
+from sqlalchemy import Column
+
+if TYPE_CHECKING:
+    from sqlalchemy import ColumnElement, Select, TextClause
+
+
+ColT = Union[Column, "ColumnElement", "TextClause"]
+
+
+def column_name(col: ColT) -> str:
+    """Returns column name from column element."""
+    return col.name if isinstance(col, Column) else str(col)
+
+
+def get_query_column(query: "Select", name: str) -> Optional[ColT]:
+    """Returns column element from query by name or None if column not found."""
+    return next((col for col in query.inner_columns if column_name(col) == name), None)
+
+
+def get_query_id_column(query: "Select") -> ColT:
+    """Returns the ID column element from query; raises if it is missing."""
+    col = get_query_column(query, "sys__id")
+    if col is None:
+        raise RuntimeError("sys__id column not found in query")
+    return col
+
+
+def select_only_columns(query: "Select", *names: str) -> "Select":
+    """Returns query selecting only the given columns."""
+    if not names:
+        return query
+
+    cols: list[ColT] = []
+    for name in names:
+        col = get_query_column(query, name)
+        if col is None:
+            raise ValueError(f"Column '{name}' not found in query")
+        cols.append(col)
+
+    return query.with_only_columns(*cols)
diff --git a/src/datachain/utils.py b/src/datachain/utils.py
index 21fcd6e49..11018df08 100644
--- a/src/datachain/utils.py
+++ b/src/datachain/utils.py
@@ -263,7 +263,7 @@ def batched_it(iterable: Iterable[_T_co], n: int) -> Iterator[Iterator[_T_co]]:
 
 def flatten(items):
     for item in items:
-        if isinstance(item, list):
+        if isinstance(item, (list, tuple)):
             yield from item
         else:
             yield item
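On the final `flatten` hunk: with `ids_only=True` the non-batching path feeds raw id rows into `batched(flatten(input_rows), DEFAULT_BATCH_SIZE)` in the dispatcher, and database drivers commonly return rows as tuples, which the old list-only check would have passed through unflattened. A self-contained illustration (this `batched` is a simplified stand-in for `datachain.utils.batched`):

```python
from itertools import islice


def flatten(items):
    # Mirrors the patched helper: unwrap lists *and* tuples.
    for item in items:
        if isinstance(item, (list, tuple)):
            yield from item
        else:
            yield item


def batched(iterable, n):
    # Simplified stand-in for datachain.utils.batched.
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk


# Id rows as they might come back from the driver: one-element tuples.
id_rows = [(1,), (2,), (3,), (4,), (5,)]
print(list(batched(flatten(id_rows), 2)))  # [(1, 2), (3, 4), (5,)]
```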