Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize UDF with parallel execution #713

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

dreadatour
Copy link
Contributor

@dreadatour dreadatour commented Dec 13, 2024

Highlights

In this PR I am:

  • Passing IDs only instead of whole rows into parallel processes UDFs. This will prevent multiple data conversion (pickling) along the way
  • Write to DB right from parallel process instead of passing results pack into main process

This will prevent multiple types conversion.

Before it was:

  1. Read rows from DB -> convert them from raw DB types into Python types (whole raw)
  2. Pass rows into parallel processes via multiprocess.Queue -> convert rows from Python types with msgpack
  3. Read rows in parallel process from multiprocess.Queue -> convert rows to Python types (maspack)
  4. Process UDF
  5. Send rows back to the main process via multiprocess.Queue -> convert from Python type (msgpack)
  6. In main process read rows from multiprocess.Queue -> convert them back into Python type (msgpack)
  7. Write rows into DB -> convert them from Python types into DB raw types

After:

  1. Read rows IDs from DB -> convert DB ints into Python ints is quick, stable, predictable
  2. Pass rows IDs to parallel processes in batches via multiprocess.Queue -> convert list of IDs with pickle (by default)
  3. In parallel process for each batch of IDs read rows from DB -> convert them from raw DB types into Python types
  4. Process UDF
  5. Write result rows into DB right from the parallel process -> convert them from Python types into DB raw types

In the end:

  • no multiple type conversion
  • no more single point reading/writing from/to warehouse DB: read and write from parallel processes (worker machines in future)

Test scenario

Simple script to check raw parallel setting only:

import sys
from datachain import DataChain, File


def path_len(file: File) -> int:
    return len(file.path)


parallel = int(sys.argv[1]) if len(sys.argv) > 1 else None


DataChain \
    .from_dataset("50k-laion-files") \
    .settings(parallel=parallel, prefetch=0) \
    .map(path_len=path_len) \
    .save("50k-laion-files-len")

This is very simple and basic scenario, but it helps us to test parallel setting only, without any overheads.

Note prefetch is off in this case to measure parallel only

Overview

On the chart below there are two series: before optimization (blue) and after (green). On the X axis is parallel processes count, on the Y axis is number of total rows processed by UDF in parallel. This is valid for SQLite warehouse on my local machine.

image

As we can see, "before" parallel option does not affect performance at all, there is a strict limit on performance and it does not depends on number of parallel processes.

The reason is because we pass rows into UDF for each parallel process via multiprocess.Queue and get results back the same way. Queue performance is very limited. I wrote a simple script to test Queue only and it is limited indeed. I have tried different ways for IPC (Pipes, ZeroMQ) and they all have this limit. This can be solved introducing external dependencies (Redis, RabbitMQ, etc), but it is not what we want for CLI tool.

"After" performs much better, Queue is used only to pass IDs in batches and is performant enough to show the performance boost depending on parallel processes count. It is not linear, because performance of SQLite warehouse is now the limit, but it is much better, stable and predictable.

Also note "1 parallel process" performance is ~2.15 times slower than clean "no parallel processes" and this is basically overhead for using parallel processes and queues to read and pass IDs. on 2-3 parallel processes performance is the same as on "no parallel" and it is increasing over parallel processes count increasing.

Next I am going to measure the same numbers on ClickHouse DB warehouse, I suppose it is going to be much better and linear.

More measurements for those who love raw numbers

Before

Not parallel (for reference)
$ time python path-len.py
Preparing: 129136 rows [00:01, 117496.00 rows/s]
Processed: 129136 rows [00:04, 30999.24 rows/s]
Cleanup: 2 tables [00:00, 297.27 tables/s]
python path-len.py  7.14s user 3.01s system 131% cpu 7.728 total
Parallel = 1 (edge case)
$ time python path-len.py 1
Preparing: 129136 rows [00:01, 114113.84 rows/s]
Processed: 129136 rows [00:11, 11689.08 rows/s]
Cleanup: 2 tables [00:00, 310.41 tables/s]
python path-len.py 1  20.59s user 11.62s system 202% cpu 15.872 total
Parallel = 8
$ time python path-len.py 8
Preparing: 129136 rows [00:01, 117677.17 rows/s]
Processed: 129136 rows [00:10, 12675.34 rows/s]
Cleanup: 2 tables [00:00, 331.29 tables/s]
python path-len.py 8  32.53s user 15.31s system 306% cpu 15.630 total

After

Not parallel (for reference)
$ time python path-len.py
Preparing: 129136 rows [00:01, 113141.62 rows/s]
Processed: 129136 rows [00:04, 30850.59 rows/s]
Cleanup: 2 tables [00:00, 275.74 tables/s]
python path-len.py  7.06s user 3.38s system 135% cpu 7.701 total
Parallel = 1 (edge case)
$ time python path-len.py 1
Preparing: 129136 rows [00:01, 115298.96 rows/s]
Processed: 129137 rows [00:08, 15122.98 rows/s]
Cleanup: 2 tables [00:00, 317.81 tables/s]
python path-len.py 1  14.27s user 9.30s system 173% cpu 13.569 total
Parallel = 8
$ time python path-len.py 8
Preparing: 129136 rows [00:01, 115127.01 rows/s]
Processed: 129144 rows [00:03, 39761.77 rows/s]
Cleanup: 2 tables [00:00, 295.43 tables/s]
python path-len.py 8  27.65s user 13.99s system 452% cpu 9.206 total

Copy link

cloudflare-workers-and-pages bot commented Dec 13, 2024

Deploying datachain-documentation with  Cloudflare Pages  Cloudflare Pages

Latest commit: 4e5b602
Status: ✅  Deploy successful!
Preview URL: https://8cc90dde.datachain-documentation.pages.dev
Branch Preview URL: https://optimize-parallel-execution.datachain-documentation.pages.dev

View logs

Copy link

codecov bot commented Dec 13, 2024

Codecov Report

Attention: Patch coverage is 78.26087% with 30 lines in your changes missing coverage. Please review.

Project coverage is 87.21%. Comparing base (10e90c5) to head (4e5b602).

Files with missing lines Patch % Lines
src/datachain/query/dispatch.py 84.61% 7 Missing and 5 partials ⚠️
src/datachain/query/utils.py 50.00% 10 Missing and 1 partial ⚠️
src/datachain/query/batch.py 62.50% 3 Missing and 3 partials ⚠️
src/datachain/utils.py 0.00% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #713      +/-   ##
==========================================
- Coverage   87.44%   87.21%   -0.24%     
==========================================
  Files         114      116       +2     
  Lines       10898    10963      +65     
  Branches     1499     1508       +9     
==========================================
+ Hits         9530     9561      +31     
- Misses        990     1024      +34     
  Partials      378      378              
Flag Coverage Δ
datachain 87.14% <78.26%> (-0.24%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -85,7 +85,6 @@ def run(
udf_fields: "Sequence[str]",
udf_inputs: "Iterable[RowsOutput]",
catalog: "Catalog",
is_generator: bool,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used anywhere


with contextlib.closing(
batching(warehouse.dataset_select_paginated, query)
batching(warehouse.db.execute, query, ids_only=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure yet, but it looks like we don't need pagination here since we are only selecting IDs.
Should be tested on bigger scale and confirmed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

our sys__id is 8 bytes, so on 1B scale this will take 8GB of memory by my calculation. I would still maybe leave it paginated.

n_workers=n_workers,
processed_cb=processed_cb,
download_cb=download_cb,
)
process_udf_outputs(warehouse, table, udf_results, udf, cb=generated_cb)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are processing results and inserting them into DB in parallel processes now.

download_cb.relative_update(downloaded)
if processed := result.get("processed"):
processed_cb.relative_update(processed)
if status in (OK_STATUS, NOTIFY_STATUS):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are doing updates above now for all types of signals, no need to process these signals here.

Comment on lines 337 to 349
process_udf_outputs(
warehouse,
self.table,
self.notify_and_process(udf_results, processed_cb),
self.udf,
cb=processed_cb,
)
warehouse.insert_rows_done(self.table)

put_into_queue(
self.done_queue,
{"status": FINISHED_STATUS, "processed": processed_cb.processed_rows},
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not pass results into main process, write them right into DB here.

Comment on lines +351 to +357
def notify_and_process(self, udf_results, processed_cb):
for row in udf_results:
put_into_queue(
self.done_queue,
{"status": NOTIFY_STATUS, "processed": processed_cb.processed_rows},
{"status": OK_STATUS, "processed": processed_cb.processed_rows},
)
put_into_queue(self.done_queue, {"status": FINISHED_STATUS})
yield row
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Helper function to notify process before writing results into DB.

@amritghimire
Copy link
Contributor

Also note "1 parallel process" performance is ~2.15 times slower than clean "no parallel processes" and this is basically overhead for using parallel processes and queues to read and pass IDs. on 2-3 parallel processes performance is the same as on "no parallel" and it is increasing over parallel processes count increasing.

Is it possible to treat 1 parallel process as no parallel processes or raise error when only one parallel is specified?

@skshetry
Copy link
Member

@dreadatour, while you are working on this, could you please also take a look at this example test:

https://github.com/iterative/datachain/actions/runs/12326251928/job/34406754039?pr=713#step:7:29

862.81s call tests/examples/test_examples.py::test_get_started_examples[examples/get_started/udfs/parallel.py]

Does this PR improve that test? Should it take that long?

@dreadatour
Copy link
Contributor Author

Is it possible to treat 1 parallel process as no parallel processes or raise error when only one parallel is specified?

Sure, sounds reasonable 👍

@dreadatour, while you are working on this, could you please also take a look at this example test:

https://github.com/iterative/datachain/actions/runs/12326251928/job/34406754039?pr=713#step:7:29

862.81s call tests/examples/test_examples.py::test_get_started_examples[examples/get_started/udfs/parallel.py]

Does this PR improve that test? Should it take that long?

Nice catch, let me take a look 🙏

@dreadatour
Copy link
Contributor Author

@dreadatour, while you are working on this, could you please also take a look at this example test:

https://github.com/iterative/datachain/actions/runs/12326251928/job/34406754039?pr=713#step:7:29

862.81s call tests/examples/test_examples.py::test_get_started_examples[examples/get_started/udfs/parallel.py]

Does this PR improve that test? Should it take that long?

Found an issue. This is because of this (basically everything runs in single process because batch size is 10k and number of records is 400).

Couple tests:

main branch
$ time python examples/get_started/udfs/parallel.py
Preparing: 400 rows [00:00, 75699.21 rows/s]
Processed: 400 rows [00:35, 11.34 rows/s]
Cleanup: 2 tables [00:00, 10525.23 tables/s]
python examples/get_started/udfs/parallel.py  391.85s user 15.58s system 948% cpu 42.941 total
this branch

Basically single process

$ time python examples/get_started/udfs/parallel.py
Preparing: 400 rows [00:00, 107926.77 rows/s]
Processed: 401 rows [04:01,  1.66 rows/s]
Cleanup: 2 tables [00:00, 10512.04 tables/s]
python examples/get_started/udfs/parallel.py  260.37s user 15.32s system 110% cpu 4:09.27 total
this branch but batch size is set to 10
$ time python examples/get_started/udfs/parallel.py
Preparing: 400 rows [00:00, 110894.41 rows/s]
Processed: 412 rows [00:38, 10.64 rows/s]
Cleanup: 2 tables [00:00, 10356.31 tables/s]
python examples/get_started/udfs/parallel.py  391.61s user 18.10s system 885% cpu 46.277 total

@dreadatour
Copy link
Contributor Author

@dreadatour, while you are working on this, could you please also take a look at this example test:

https://github.com/iterative/datachain/actions/runs/12326251928/job/34406754039?pr=713#step:7:29

862.81s call tests/examples/test_examples.py::test_get_started_examples[examples/get_started/udfs/parallel.py]

Does this PR improve that test? Should it take that long?

Fixed: https://github.com/iterative/datachain/actions/runs/12372015291/job/34529325219?pr=713 (409 sec)
Current main for reference: https://github.com/iterative/datachain/actions/runs/12367070520/job/34514761451 (same 410 sec)

Copy link
Contributor

@ilongin ilongin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good work! I'\ve added a couple of small comments and a few thoughts / questions.

current_partition: Optional[int] = None
batch: list[Sequence] = []

query_fields = [str(c.name) for c in query.selected_columns]
# query_fields = [column_name(col) for col in query.inner_columns]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented out code

@@ -464,8 +465,8 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

with subprocess.Popen(cmd, env=envs, stdin=subprocess.PIPE) as process: # noqa: S603
process.communicate(process_data)
if process.poll():
raise RuntimeError("UDF Execution Failed!")
if ret := process.poll():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would maybe put full variable name as ret is not so common shortcut IMO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

retval may be? 👀

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Identifiers that exist for short scopes should be short." It is consumed in the next line.

So, this is okay.

https://github.com/iterative/studio/wiki/BE-review-cheatsheet#match-a-var-name-length-to-its-scope-

Copy link
Member

@skshetry skshetry Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, it could be renamed retval (as you have proposed), exitcode, retcode, etc. But it's not necessary imo.


with contextlib.closing(
batching(warehouse.dataset_select_paginated, query)
batching(warehouse.db.execute, query, ids_only=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

our sys__id is 8 bytes, so on 1B scale this will take 8GB of memory by my calculation. I would still maybe leave it paginated.

if self.is_batching:
while (batch := get_from_queue(self.task_queue)) != STOP_SIGNAL:
ids = [row[0] for row in batch.rows]
rows = warehouse.dataset_rows_select(self.query.where(col_id.in_(ids)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing that I'm worried about is this query in Clickhouse. There we have rows sorted and "packed" in granules by sys__id which is primary key. It would be ideal if these id batches are all sorted and "close" to each other, as otherwise we could end up in situation where big chunk of DB is read for every batch just because one id ended up in first granule / part, other id was in second one etc. This is because CH reads whole part / granule even if there is we need only one record from it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants