-
Notifications
You must be signed in to change notification settings - Fork 93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize UDF with parallel execution #713
base: main
Are you sure you want to change the base?
Changes from all commits
5953afe
c751005
699fb96
4e5b602
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
|
||
from datachain.data_storage.schema import PARTITION_COLUMN_ID | ||
from datachain.data_storage.warehouse import SELECT_BATCH_SIZE | ||
from datachain.query.utils import get_query_column, get_query_id_column | ||
|
||
if TYPE_CHECKING: | ||
from sqlalchemy import Select | ||
|
@@ -23,11 +24,14 @@ | |
class BatchingStrategy(ABC): | ||
"""BatchingStrategy provides means of batching UDF executions.""" | ||
|
||
is_batching: bool | ||
|
||
@abstractmethod | ||
def __call__( | ||
self, | ||
execute: Callable[..., Generator[Sequence, None, None]], | ||
execute: Callable, | ||
query: "Select", | ||
ids_only: bool = False, | ||
) -> Generator[RowsOutput, None, None]: | ||
"""Apply the provided parameters to the UDF.""" | ||
|
||
|
@@ -38,11 +42,16 @@ | |
batch UDF calls. | ||
""" | ||
|
||
is_batching = False | ||
|
||
def __call__( | ||
self, | ||
execute: Callable[..., Generator[Sequence, None, None]], | ||
execute: Callable, | ||
query: "Select", | ||
ids_only: bool = False, | ||
) -> Generator[Sequence, None, None]: | ||
if ids_only: | ||
query = query.with_only_columns(get_query_id_column(query)) | ||
return execute(query) | ||
|
||
|
||
|
@@ -52,14 +61,20 @@ | |
is passed a sequence of multiple parameter sets. | ||
""" | ||
|
||
is_batching = True | ||
|
||
def __init__(self, count: int): | ||
self.count = count | ||
|
||
def __call__( | ||
self, | ||
execute: Callable[..., Generator[Sequence, None, None]], | ||
execute: Callable, | ||
query: "Select", | ||
ids_only: bool = False, | ||
) -> Generator[RowsOutputBatch, None, None]: | ||
if ids_only: | ||
query = query.with_only_columns(get_query_id_column(query)) | ||
|
||
# choose page size that is a multiple of the batch size | ||
page_size = math.ceil(SELECT_BATCH_SIZE / self.count) * self.count | ||
|
||
|
@@ -84,19 +99,31 @@ | |
Dataset rows need to be sorted by the grouping column. | ||
""" | ||
|
||
is_batching = True | ||
|
||
def __call__( | ||
self, | ||
execute: Callable[..., Generator[Sequence, None, None]], | ||
execute: Callable, | ||
query: "Select", | ||
ids_only: bool = False, | ||
) -> Generator[RowsOutputBatch, None, None]: | ||
id_col = get_query_id_column(query) | ||
if (partition_col := get_query_column(query, PARTITION_COLUMN_ID)) is None: | ||
raise RuntimeError("partition column not found in query") | ||
|
||
if ids_only: | ||
query = query.with_only_columns(id_col, partition_col) | ||
|
||
current_partition: Optional[int] = None | ||
batch: list[Sequence] = [] | ||
|
||
query_fields = [str(c.name) for c in query.selected_columns] | ||
# query_fields = [column_name(col) for col in query.inner_columns] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Commented-out code |
||
id_column_idx = query_fields.index("sys__id") | ||
partition_column_idx = query_fields.index(PARTITION_COLUMN_ID) | ||
|
||
ordered_query = query.order_by(None).order_by( | ||
PARTITION_COLUMN_ID, | ||
partition_col, | ||
*query._order_by_clauses, | ||
) | ||
|
||
|
@@ -108,7 +135,7 @@ | |
if len(batch) > 0: | ||
yield RowsOutputBatch(batch) | ||
batch = [] | ||
batch.append(row) | ||
batch.append([row[id_column_idx]] if ids_only else row) | ||
|
||
if len(batch) > 0: | ||
yield RowsOutputBatch(batch) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,8 +43,9 @@ | |
from datachain.dataset import DatasetStatus, RowDict | ||
from datachain.error import DatasetNotFoundError, QueryScriptCancelError | ||
from datachain.func.base import Function | ||
from datachain.lib.udf import UDFAdapter | ||
from datachain.progress import CombinedDownloadCallback | ||
from datachain.query.schema import C, UDFParamSpec, normalize_param | ||
from datachain.query.session import Session | ||
from datachain.sql.functions.random import rand | ||
from datachain.utils import ( | ||
batched, | ||
|
@@ -53,9 +54,6 @@ | |
get_datachain_executable, | ||
) | ||
|
||
from .schema import C, UDFParamSpec, normalize_param | ||
from .session import Session | ||
|
||
if TYPE_CHECKING: | ||
from sqlalchemy.sql.elements import ClauseElement | ||
from sqlalchemy.sql.schema import Table | ||
|
@@ -65,7 +63,8 @@ | |
from datachain.catalog import Catalog | ||
from datachain.data_storage import AbstractWarehouse | ||
from datachain.dataset import DatasetRecord | ||
from datachain.lib.udf import UDFResult | ||
from datachain.lib.udf import UDFAdapter, UDFResult | ||
from datachain.query.udf import UdfInfo | ||
|
||
P = ParamSpec("P") | ||
|
||
|
@@ -301,7 +300,7 @@ def adjust_outputs( | |
return row | ||
|
||
|
||
def get_udf_col_types(warehouse: "AbstractWarehouse", udf: UDFAdapter) -> list[tuple]: | ||
def get_udf_col_types(warehouse: "AbstractWarehouse", udf: "UDFAdapter") -> list[tuple]: | ||
"""Optimization: Precompute UDF column types so these don't have to be computed | ||
in the convert_type function for each row in a loop.""" | ||
dialect = warehouse.db.dialect | ||
|
@@ -322,7 +321,7 @@ def process_udf_outputs( | |
warehouse: "AbstractWarehouse", | ||
udf_table: "Table", | ||
udf_results: Iterator[Iterable["UDFResult"]], | ||
udf: UDFAdapter, | ||
udf: "UDFAdapter", | ||
batch_size: int = INSERT_BATCH_SIZE, | ||
cb: Callback = DEFAULT_CALLBACK, | ||
) -> None: | ||
|
@@ -347,6 +346,8 @@ def process_udf_outputs( | |
for row_chunk in batched(rows, batch_size): | ||
warehouse.insert_rows(udf_table, row_chunk) | ||
|
||
warehouse.insert_rows_done(udf_table) | ||
|
||
|
||
def get_download_callback() -> Callback: | ||
return CombinedDownloadCallback( | ||
|
@@ -366,7 +367,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback: | |
|
||
@frozen | ||
class UDFStep(Step, ABC): | ||
udf: UDFAdapter | ||
udf: "UDFAdapter" | ||
catalog: "Catalog" | ||
partition_by: Optional[PartitionByType] = None | ||
parallel: Optional[int] = None | ||
|
@@ -440,7 +441,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: | |
raise RuntimeError( | ||
"In-memory databases cannot be used with parallel processing." | ||
) | ||
udf_info = { | ||
udf_info: UdfInfo = { | ||
"udf_data": filtered_cloudpickle_dumps(self.udf), | ||
"catalog_init": self.catalog.get_init_params(), | ||
"metastore_clone_params": self.catalog.metastore.clone_params(), | ||
|
@@ -464,8 +465,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: | |
|
||
with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process: # noqa: S603 | ||
process.communicate(process_data) | ||
if process.poll(): | ||
raise RuntimeError("UDF Execution Failed!") | ||
if ret := process.poll(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would maybe use the full variable name instead. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. "Identifiers that exist for short scopes should be short." It is consumed in the next line, so this is okay. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Alternatively, it could be renamed. |
||
raise RuntimeError(f"UDF Execution Failed! Exit code: {ret}") | ||
else: | ||
# Otherwise process single-threaded (faster for smaller UDFs) | ||
warehouse = self.catalog.warehouse | ||
|
@@ -479,7 +480,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: | |
udf_fields, | ||
udf_inputs, | ||
self.catalog, | ||
self.is_generator, | ||
self.cache, | ||
download_cb, | ||
processed_cb, | ||
|
@@ -496,8 +496,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: | |
processed_cb.close() | ||
generated_cb.close() | ||
|
||
warehouse.insert_rows_done(udf_table) | ||
|
||
except QueryScriptCancelError: | ||
self.catalog.warehouse.close() | ||
sys.exit(QUERY_SCRIPT_CANCELED_EXIT_CODE) | ||
|
@@ -1491,7 +1489,7 @@ def chunk(self, index: int, total: int) -> "Self": | |
@detach | ||
def add_signals( | ||
self, | ||
udf: UDFAdapter, | ||
udf: "UDFAdapter", | ||
parallel: Optional[int] = None, | ||
workers: Union[bool, int] = False, | ||
min_task_size: Optional[int] = None, | ||
|
@@ -1535,7 +1533,7 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self": | |
@detach | ||
def generate( | ||
self, | ||
udf: UDFAdapter, | ||
udf: "UDFAdapter", | ||
parallel: Optional[int] = None, | ||
workers: Union[bool, int] = False, | ||
min_task_size: Optional[int] = None, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not used anywhere