Use cloudpickle for parallel UDF processing (#65)
dtulga authored Jul 18, 2024
1 parent 00c846a commit 8ae9f8d
Showing 9 changed files with 373 additions and 129 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
"sqlalchemy>=2",
"multiprocess==0.70.16",
"dill==0.3.8",
"cloudpickle",
"ujson>=5.9.0",
"pydantic>=2,<3",
"jmespath>=1.0",
113 changes: 15 additions & 98 deletions src/datachain/query/dataset.py
@@ -1,4 +1,3 @@
import ast
import contextlib
import datetime
import inspect
@@ -10,7 +9,6 @@
import string
import subprocess
import sys
import types
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Iterator, Sequence
from copy import copy
@@ -28,9 +26,7 @@
import attrs
import sqlalchemy
from attrs import frozen
from dill import dumps, source
from fsspec.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy.sql import func as f
from sqlalchemy.sql.elements import ColumnClause, ColumnElement
@@ -54,7 +50,11 @@
from datachain.progress import CombinedDownloadCallback
from datachain.sql.functions import rand
from datachain.storage import Storage, StorageURI
from datachain.utils import batched, determine_processes
from datachain.utils import (
batched,
determine_processes,
filtered_cloudpickle_dumps,
)

from .metrics import metrics
from .schema import C, UDFParamSpec, normalize_param
@@ -490,7 +490,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
elif processes:
# Parallel processing (faster for more CPU-heavy UDFs)
udf_info = {
"udf": self.udf,
"udf_data": filtered_cloudpickle_dumps(self.udf),
"catalog_init": self.catalog.get_init_params(),
"id_generator_clone_params": (
self.catalog.id_generator.clone_params()
@@ -511,16 +511,15 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

envs = dict(os.environ)
envs.update({"PYTHONPATH": os.getcwd()})
with self.process_feature_module():
process_data = dumps(udf_info, recurse=True)
result = subprocess.run( # noqa: S603
[datachain_exec_path, "--internal-run-udf"],
input=process_data,
check=False,
env=envs,
)
if result.returncode != 0:
raise RuntimeError("UDF Execution Failed!")
process_data = filtered_cloudpickle_dumps(udf_info)
result = subprocess.run( # noqa: S603
[datachain_exec_path, "--internal-run-udf"],
input=process_data,
check=False,
env=envs,
)
if result.returncode != 0:
raise RuntimeError("UDF Execution Failed!")

else:
# Otherwise process single-threaded (faster for smaller UDFs)
@@ -569,57 +568,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
self.catalog.warehouse.close()
raise

@contextlib.contextmanager
def process_feature_module(self):
# Generate a random name for the feature module
feature_module_name = "tmp" + _random_string(10)
# Create a dynamic module with the generated name
dynamic_module = types.ModuleType(feature_module_name)
# Get the import lines for the necessary objects from the main module
main_module = sys.modules["__main__"]
if getattr(main_module, "__file__", None):
import_lines = list(get_imports(main_module))
else:
import_lines = [
source.getimport(obj, alias=name)
for name, obj in main_module.__dict__.items()
if _imports(obj) and not (name.startswith("__") and name.endswith("__"))
]

# Get the feature classes from the main module
feature_classes = {
name: obj
for name, obj in main_module.__dict__.items()
if _feature_predicate(obj)
}
if not feature_classes:
yield None
return

# Get the source code of the feature classes
feature_sources = [source.getsource(cls) for _, cls in feature_classes.items()]
# Set the module name for the feature classes to the generated name
for name, cls in feature_classes.items():
cls.__module__ = feature_module_name
setattr(dynamic_module, name, cls)
# Add the dynamic module to the sys.modules dictionary
sys.modules[feature_module_name] = dynamic_module
# Combine the import lines and feature sources
feature_file = "\n".join(import_lines) + "\n" + "\n".join(feature_sources)

# Write the module content to a .py file
with open(f"{feature_module_name}.py", "w") as module_file:
module_file.write(feature_file)

try:
yield feature_module_name
finally:
for cls in feature_classes.values():
cls.__module__ = main_module.__name__
os.unlink(f"{feature_module_name}.py")
# Remove the dynamic module from sys.modules
del sys.modules[feature_module_name]

def create_partitions_table(self, query: Select) -> "Table":
"""
Create temporary table with group by partitions.
@@ -1829,34 +1777,3 @@ def _random_string(length: int) -> str:
random.choice(string.ascii_letters + string.digits) # noqa: S311
for i in range(length)
)


def _feature_predicate(obj):
return (
inspect.isclass(obj) and source.isfrommain(obj) and issubclass(obj, BaseModel)
)


def _imports(obj):
return not source.isfrommain(obj)


def get_imports(m):
root = ast.parse(inspect.getsource(m))

for node in ast.iter_child_nodes(root):
if isinstance(node, ast.Import):
module = None
elif isinstance(node, ast.ImportFrom):
module = node.module
else:
continue

for n in node.names:
import_script = ""
if module:
import_script += f"from {module} "
import_script += f"import {n.name}"
if n.asname:
import_script += f" as {n.asname}"
yield import_script
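
For context on the new parallel flow above: the parent now serializes the whole udf_info payload with filtered_cloudpickle_dumps and pipes the raw bytes to the `datachain --internal-run-udf` subprocess over stdin, so no temporary feature module is written to disk anymore. Below is a minimal, self-contained sketch of that stdin handshake, assuming only that cloudpickle is installed; the `python -c` child and the lambda UDF are stand-ins for illustration, not datachain code.

import subprocess
import sys

import cloudpickle

# Parent side: serialize the payload once and pipe the raw bytes over stdin.
udf_info = {"udf_data": cloudpickle.dumps(lambda x: x * 2), "processes": 2}

# Child side (stand-in for `datachain --internal-run-udf`): read the payload
# back from stdin, then rebuild the UDF itself from the nested bytes.
worker_code = (
    "import sys, cloudpickle\n"
    "info = cloudpickle.load(sys.stdin.buffer)\n"
    "udf = cloudpickle.loads(info['udf_data'])\n"
    "print(udf(21))\n"
)

result = subprocess.run(
    [sys.executable, "-c", worker_code],
    input=cloudpickle.dumps(udf_info),
    check=False,
)
if result.returncode != 0:
    raise RuntimeError("UDF Execution Failed!")
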
28 changes: 15 additions & 13 deletions src/datachain/query/dispatch.py
@@ -10,7 +10,7 @@

import attrs
import multiprocess
from dill import load
from cloudpickle import load, loads
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from multiprocess import get_context

@@ -84,7 +84,7 @@ def put_into_queue(queue: Queue, item: Any) -> None:

def udf_entrypoint() -> int:
# Load UDF info from stdin
udf_info = load(stdin.buffer) # noqa: S301
udf_info = load(stdin.buffer)

(
warehouse_class,
@@ -95,7 +95,7 @@ def udf_entrypoint() -> int:

# Parallel processing (faster for more CPU-heavy UDFs)
dispatch = UDFDispatcher(
udf_info["udf"],
udf_info["udf_data"],
udf_info["catalog_init"],
udf_info["id_generator_clone_params"],
udf_info["metastore_clone_params"],
@@ -108,7 +108,7 @@ def udf_entrypoint() -> int:
batching = udf_info["batching"]
table = udf_info["table"]
n_workers = udf_info["processes"]
udf = udf_info["udf"]
udf = loads(udf_info["udf_data"])
if n_workers is True:
# Use default number of CPUs (cores)
n_workers = None
@@ -146,7 +146,7 @@ class UDFDispatcher:

def __init__(
self,
udf,
udf_data,
catalog_init_params,
id_generator_clone_params,
metastore_clone_params,
@@ -155,14 +155,7 @@ def __init__(
is_generator=False,
buffer_size=DEFAULT_BATCH_SIZE,
):
# isinstance cannot be used here, as dill packages the entire class definition,
# and so these two types are not considered exactly equal,
# even if they have the same import path.
if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
self.udf = udf
else:
self.udf = None
self.udf_factory = udf
self.udf_data = udf_data
self.catalog_init_params = catalog_init_params
(
self.id_generator_class,
@@ -214,6 +207,15 @@ def _create_worker(self) -> "UDFWorker":
self.catalog = Catalog(
id_generator, metastore, warehouse, **self.catalog_init_params
)
udf = loads(self.udf_data)
# isinstance cannot be used here, as cloudpickle packages the entire class
# definition, and so these two types are not considered exactly equal,
# even if they have the same import path.
if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
self.udf = udf
else:
self.udf = None
self.udf_factory = udf
if not self.udf:
self.udf = self.udf_factory()

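
The dispatcher applies the same idea one level down: UDFDispatcher now stores the pickled bytes (udf_data) and each worker rebuilds its own UDF with cloudpickle.loads inside _create_worker, rather than receiving a live UDF object. A rough sketch of that per-worker deserialization pattern, using the multiprocess package the dispatcher already builds on; make_udf and worker are hypothetical names for illustration, not datachain APIs.

import cloudpickle
from multiprocess import get_context

def make_udf():
    prefix = "row:"  # closure state that travels inside the pickled bytes
    return lambda x: f"{prefix}{x}"

def worker(udf_data, item):
    # Deserialize inside the worker so every process rebuilds its own UDF.
    udf = cloudpickle.loads(udf_data)
    return udf(item)

if __name__ == "__main__":
    udf_data = cloudpickle.dumps(make_udf())  # serialize once in the driver
    with get_context("spawn").Pool(2) as pool:
        print(pool.starmap(worker, [(udf_data, i) for i in range(4)]))
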
39 changes: 39 additions & 0 deletions src/datachain/utils.py
@@ -1,5 +1,6 @@
import glob
import importlib.util
import io
import json
import os
import os.path as osp
@@ -13,8 +14,10 @@
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import UUID

import cloudpickle
from dateutil import tz
from dateutil.parser import isoparse
from pydantic import BaseModel

if TYPE_CHECKING:
import pandas as pd
@@ -388,3 +391,39 @@ def inside_notebook() -> bool:
return False

return False


def get_all_subclasses(cls):
"""Return all subclasses of a given class.
Can return duplicates due to multiple inheritance."""
for subclass in cls.__subclasses__():
yield from get_all_subclasses(subclass)
yield subclass


def filtered_cloudpickle_dumps(obj: Any) -> bytes:
"""Equivalent to cloudpickle.dumps, but this supports Pydantic models."""
model_namespaces = {}

with io.BytesIO() as f:
pickler = cloudpickle.CloudPickler(f)

for model_class in get_all_subclasses(BaseModel):
# This "is not None" check is needed, because due to multiple inheritance,
# it is theoretically possible to get the same class twice from
# get_all_subclasses.
if model_class.__pydantic_parent_namespace__ is not None:
# __pydantic_parent_namespace__ can contain many unnecessary and
# unpickleable entities, so should be removed for serialization.
model_namespaces[model_class] = (
model_class.__pydantic_parent_namespace__
)
model_class.__pydantic_parent_namespace__ = None

try:
pickler.dump(obj)
return f.getvalue()
finally:
for model_class, namespace in model_namespaces.items():
# Restore original __pydantic_parent_namespace__ locally.
model_class.__pydantic_parent_namespace__ = namespace
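
A short usage sketch for the new helper, assuming datachain (with this change) and pydantic are installed; the Signal model and udf function are made-up examples. It round-trips a UDF that returns Pydantic models, which, per the comments above, plain cloudpickle.dumps can trip over when __pydantic_parent_namespace__ holds unpicklable entries.

import cloudpickle
from pydantic import BaseModel

from datachain.utils import filtered_cloudpickle_dumps

class Signal(BaseModel):  # hypothetical model defined in the user's script
    path: str
    score: float

def udf(name: str) -> Signal:  # hypothetical UDF that builds model instances
    return Signal(path=name, score=1.0)

payload = filtered_cloudpickle_dumps(udf)  # bytes safe to hand to a worker
restored = cloudpickle.loads(payload)      # plain cloudpickle on the way back
print(restored("a.jpg"))
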