Fix #229: Add cloudpickle support for type-annotated parse_func #305

Closed
1 change: 1 addition & 0 deletions .python-version
@@ -0,0 +1 @@
3.11.7
14 changes: 9 additions & 5 deletions README.md
@@ -70,7 +70,7 @@ print(poems)
```
Note that retries and caching are enabled by default.
So now if you run the same prompt again, you will get the same response, pretty much instantly.
You can delete the cache at `~/.cache/curator`.
You can delete the cache at the path specified by `os.path.join(os.path.expanduser("~"), ".cache", "curator")`.
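For example, a minimal sketch of clearing that cache from Python (it simply removes the default cache directory noted above):
```python
import os
import shutil

# Default curator cache location, as noted above.
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "curator")

# Remove the cached responses; ignore_errors avoids failing if the directory does not exist yet.
shutil.rmtree(cache_dir, ignore_errors=True)
```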

#### Use LiteLLM backend for calling other models
You can use the [LiteLLM](https://docs.litellm.ai/docs/providers) backend for calling other models.
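For instance, a minimal sketch of switching backends (based on the `backend` parameter documented in `llm.py` later in this diff; the model identifier below is illustrative, not prescriptive):
```python
from bespokelabs import curator

poet = curator.LLM(
    prompt_func=lambda row: f"Write two poems about {row['topic']}.",
    model_name="anthropic/claude-3-5-sonnet-20240620",  # illustrative LiteLLM-style model name
    backend="litellm",
)
```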
@@ -127,7 +127,11 @@ poet = curator.LLM(
Here:
* `prompt_func` takes a row of the dataset as input and returns the prompt for the LLM.
* `response_format` is the structured output class we defined above.
* `parse_func` takes the input (`row`) and the structured output (`poems`) and converts it to a list of dictionaries. This is so that we can easily convert the output to a HuggingFace Dataset object.
* `parse_func` takes the input (`row`) and the structured output (`poems`) and converts it to a list of dictionaries. This is so that we can easily convert the output to a HuggingFace Dataset object. Best practices (see the sketch after this list):
  * Define `parse_func` as a module-level function rather than a lambda to ensure proper serialization.
  * Use the `_DictOrBaseModel` type alias for input/output types: `def parse_func(row: _DictOrBaseModel, response: _DictOrBaseModel) -> _DictOrBaseModel`.
  * Type annotations are fully supported through the CustomPickler implementation.
  * Function hashing is path-independent, ensuring consistent caching across different environments (e.g., Ray clusters).
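A minimal sketch of such a module-level `parse_func`, using the type alias named above (the alias and the `Poems` model are repeated here only to keep the snippet self-contained):
```python
from typing import Any, Dict, List, Union

from pydantic import BaseModel

# Type alias for parse_func inputs/outputs, mirroring the examples in this PR.
_DictOrBaseModel = Union[Dict[str, Any], BaseModel]


class Poems(BaseModel):
    poems_list: List[str]


def parse_func(row: _DictOrBaseModel, poems: _DictOrBaseModel) -> _DictOrBaseModel:
    """Pair the input topic with each generated poem so the rows map cleanly to a Dataset."""
    return [{"topic": row["topic"], "poem": p} for p in poems.poems_list]
```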

Now we can apply the `LLM` object to the dataset, which reads very pythonic.
```python
@@ -142,8 +146,8 @@ print(poem.to_pandas())
```
Note that `topics` can be created with `curator.LLM` as well,
and we can scale this up to create tens of thousands of diverse poems.
You can see a more detailed example in the [examples/poem.py](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples/poem.py) file,
and other examples in the [examples](https://github.com/bespokelabsai/curator/blob/mahesh/update_doc/examples) directory.
You can see a more detailed example in the [examples/poem-generation/poem.py](https://github.com/bespokelabsai/curator/blob/main/examples/poem-generation/poem.py) file,
and other examples in the [examples](https://github.com/bespokelabsai/curator/blob/main/examples) directory.

See the [docs](https://docs.bespokelabs.ai/) for more details as well as
for troubleshooting information.
@@ -201,4 +205,4 @@ npm -v # should print `10.9.0`
```

## Contributing
Contributions are welcome!
Contributions are welcome!
25 changes: 23 additions & 2 deletions examples/poem-generation/poem.py
@@ -35,15 +35,36 @@ class Poems(BaseModel):
    poems_list: List[str] = Field(description="A list of poems.")


from typing import Any, Dict, List, Union

# Type alias for input/output types
_DictOrBaseModel = Union[Dict[str, Any], BaseModel]


def parse_poems(row: _DictOrBaseModel, poems: _DictOrBaseModel) -> _DictOrBaseModel:
"""Parse the poems from the LLM response.

Args:
row: The input row containing the topic
poems: The structured output from the LLM (Poems model)

Returns:
A list of dictionaries containing the topic and poem
"""
if isinstance(poems, Poems):
return [{"topic": row["topic"], "poem": p} for p in poems.poems_list]
return [] # Handle edge case where poems is not a Poems instance


# We define an `LLM` object that generates poems which gets applied to the topics dataset.
poet = curator.LLM(
    # The prompt_func takes a row of the dataset as input.
    # The row is a dictionary with a single key 'topic' in this case.
    prompt_func=lambda row: f"Write two poems about {row['topic']}.",
    model_name="gpt-4o-mini",
    response_format=Poems,
    # `row` is the input row, and `poems` is the Poems class which is parsed from the structured output from the LLM.
    parse_func=lambda row, poems: [{"topic": row["topic"], "poem": p} for p in poems.poems_list],
    # Use the module-level parse function which supports type annotations
    parse_func=parse_poems,
)

# We apply the prompter to the topics dataset.
229 changes: 117 additions & 112 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ keywords = ["ai", "curator", "bespoke"]
python = "^3.10"
pydantic = ">=2.9.2"
datasets = "^3.0.2"
cloudpickle = "^3.0.0"
instructor = "^1.6.3"
pytest = "^8.3.3"
pytest-asyncio = "^0.24.0"
16 changes: 12 additions & 4 deletions src/bespokelabs/curator/llm/llm.py
@@ -8,10 +8,11 @@
from typing import Any, Callable, Dict, Iterable, Optional, Type, TypeVar, Union

from datasets import Dataset
from datasets.utils._dill import Pickler
from pydantic import BaseModel
from xxhash import xxh64

from bespokelabs.curator.utils.custom_pickler import CustomPickler

from bespokelabs.curator.db import MetadataDB
from bespokelabs.curator.llm.prompt_formatter import PromptFormatter
from bespokelabs.curator.request_processor import (
@@ -62,7 +63,9 @@ def __init__(
            prompt_func: A function that takes a single row
                and returns either a string (assumed to be a user prompt) or messages list
            parse_func: A function that takes the input row and
                response object and returns the parsed output
                response object and returns the parsed output. Can use type annotations
                (e.g., `def parse_func(row, response: ResponseModel) -> OutputType`)
                as the function is serialized using cloudpickle for proper type annotation support.
            response_format: A Pydantic model specifying the
                response format from the LLM.
            backend: The backend to use ("openai" or "litellm"). If None, will be auto-determined
@@ -281,12 +284,17 @@ def __call__(


def _get_function_hash(func) -> str:
"""Get a hash of a function's source code."""
"""Get a hash of a function's source code.

Uses CustomPickler to properly handle both:
1. Path normalization (from HuggingFace's Pickler)
2. Type annotations and closure variables (from cloudpickle)
"""
if func is None:
return xxh64("").hexdigest()

file = BytesIO()
Pickler(file, recurse=True).dump(func)
CustomPickler(file, recurse=True).dump(func)
return xxh64(file.getvalue()).hexdigest()


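As a standalone illustration of the hashing idea (this is not the PR's code path; it uses cloudpickle directly instead of `CustomPickler`), serializing a type-annotated function and hashing the bytes yields a digest that can serve as a cache key:
```python
import cloudpickle
from pydantic import BaseModel
from xxhash import xxh64


class Response(BaseModel):
    text: str


def parse(row: dict, response: Response) -> list:
    # A type-annotated parse function of the kind issue #229 is about.
    return [{"text": response.text}]


# Serialize the annotated function and hash the bytes to get a cache key.
digest = xxh64(cloudpickle.dumps(parse)).hexdigest()
print(digest)
```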
79 changes: 79 additions & 0 deletions src/bespokelabs/curator/utils/custom_pickler.py
@@ -0,0 +1,79 @@
"""Custom Pickler that combines HuggingFace's path normalization with type annotation support.

This module provides a CustomPickler class that extends HuggingFace's Pickler to support
both path normalization (for consistent function hashing across different environments)
and type annotations in function signatures.
"""

import os
from io import BytesIO
from typing import Any, Optional, Type, TypeVar, Union

import cloudpickle
from datasets.utils._dill import Pickler as HFPickler


class CustomPickler(HFPickler):
    """A custom pickler that combines HuggingFace's path normalization with type annotation support.

    This pickler extends HuggingFace's Pickler to:
    1. Preserve path normalization for consistent function hashing
    2. Support type annotations in function signatures
    3. Handle closure variables properly
    """

    def __init__(self, file: BytesIO, recurse: bool = True):
        """Initialize the CustomPickler.

        Args:
            file: The file-like object to pickle to
            recurse: Whether to recursively handle object attributes
        """
        super().__init__(file, recurse=recurse)

    def save(self, obj: Any, save_persistent_id: bool = True) -> None:
        """Save an object, handling type annotations properly.

        This method attempts to use cloudpickle's type annotation handling while
        preserving HuggingFace's path normalization logic.

        Args:
            obj: The object to pickle
            save_persistent_id: Whether to save persistent IDs
        """
        try:
            # First try HuggingFace's pickler for path normalization
            super().save(obj, save_persistent_id=save_persistent_id)
        except Exception as e:
            if "No default __reduce__ due to non-trivial __cinit__" in str(e):
                # If HF's pickler fails with type annotation error, use cloudpickle
                cloudpickle.dump(obj, self._file)
            else:
                # Re-raise other exceptions
                raise


def dumps(obj: Any) -> bytes:
    """Pickle an object to bytes using CustomPickler.

    Args:
        obj: The object to pickle

    Returns:
        The pickled object as bytes
    """
    file = BytesIO()
    CustomPickler(file, recurse=True).dump(obj)
    return file.getvalue()


def loads(data: bytes) -> Any:
    """Unpickle an object from bytes.

    Args:
        data: The pickled data

    Returns:
        The unpickled object
    """
    return cloudpickle.loads(data)
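A small usage sketch of the helpers this file adds (assuming this branch is installed so that `bespokelabs.curator.utils.custom_pickler` is importable; the model and function below are made up for illustration):
```python
from typing import List

from pydantic import BaseModel

from bespokelabs.curator.utils.custom_pickler import dumps, loads


class Numbers(BaseModel):
    items: List[int]


def total(x: Numbers) -> int:
    # A function with Pydantic type annotations, the case this PR is fixing.
    return sum(x.items)


data = dumps(total)       # serialize with CustomPickler
restored = loads(data)    # round-trip back to a callable
print(restored(Numbers(items=[1, 2, 3])))  # prints 6
```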
101 changes: 101 additions & 0 deletions tests/test_custom_pickler.py
@@ -0,0 +1,101 @@
import os
import pytest
from io import BytesIO
from typing import List
from pydantic import BaseModel

from bespokelabs.curator.utils.custom_pickler import CustomPickler, dumps, loads


class TestModel(BaseModel):
    value: str
    items: List[int]


def test_custom_pickler_type_annotations():
"""Test CustomPickler handles type annotations correctly."""

def func(x: TestModel) -> List[int]:
return x.items

# Test pickling and unpickling
pickled = dumps(func)
unpickled = loads(pickled)

# Test function still works
test_input = TestModel(value="test", items=[1, 2, 3])
assert unpickled(test_input) == [1, 2, 3]


def test_custom_pickler_path_normalization():
"""Test CustomPickler normalizes paths in function source."""
import tempfile
from pathlib import Path

def create_test_function(pkg_dir: Path):
"""Create a test function in a specific package directory."""
# Create a module file in the package directory
module_path = pkg_dir / "test_module.py"
with open(module_path, "w") as f:
f.write(
"""
def func():
path = os.path.join("/home", "user", "file.txt")
return path
"""
)

# Import the function from the file
import importlib.util
from types import ModuleType

spec = importlib.util.spec_from_file_location("test_module", str(module_path))
if spec is None or spec.loader is None:
raise ImportError(f"Could not load module from {module_path}")

module: ModuleType = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

if not hasattr(module, "func"):
raise AttributeError(f"Module {module_path} does not have 'func' attribute")

return module.func

# Create two identical functions in different Ray-like package directories
with tempfile.TemporaryDirectory() as tmp_dir1, tempfile.TemporaryDirectory() as tmp_dir2:
# Simulate Ray package paths
ray_pkg_dir1 = Path(tmp_dir1) / "ray" / "ray_pkg_123"
ray_pkg_dir2 = Path(tmp_dir2) / "ray" / "ray_pkg_456"
ray_pkg_dir1.mkdir(parents=True, exist_ok=True)
ray_pkg_dir2.mkdir(parents=True, exist_ok=True)

# Create and pickle functions from different directories
func1 = create_test_function(ray_pkg_dir1)
func2 = create_test_function(ray_pkg_dir2)

# Get hashes for both functions
pickled1 = dumps(func1)
pickled2 = dumps(func2)

# Hashes should match despite different Ray package paths
assert pickled1 == pickled2, "Function hashes should match regardless of Ray package path"


def test_custom_pickler_hybrid_serialization():
"""Test CustomPickler falls back to cloudpickle for type annotations."""

def func(x: TestModel, items: List[int]) -> List[int]:
return [i for i in items if i > int(x.value)]

# Test pickling with both type annotations and path-dependent code
file = BytesIO()
pickler = CustomPickler(file, recurse=True)
pickler.dump(func)

# Test unpickling
file.seek(0)
unpickled = loads(file.getvalue())

# Test function works
test_input = TestModel(value="2", items=[1, 2, 3])
assert unpickled(test_input, [1, 2, 3, 4]) == [3, 4]
35 changes: 30 additions & 5 deletions tests/test_prompt.py
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Any, Dict, Optional, Union
from unittest.mock import patch, MagicMock

import pytest
@@ -8,6 +8,9 @@

from bespokelabs.curator import LLM

# Type alias for input/output types
_DictOrBaseModel = Union[Dict[str, Any], BaseModel]


class MockResponseFormat(BaseModel):
"""Mock response format for testing."""
@@ -93,8 +96,16 @@ def test_single_completion_batch(prompter: LLM):
    """

    # Create a prompter with batch=True
    def simple_prompt_func():
        return [
    def simple_prompt_func(row: _DictOrBaseModel) -> _DictOrBaseModel:
        """Generate a simple prompt for testing.

        Args:
            row: The input row (unused in this test)

        Returns:
            A list of messages for the LLM
        """
        messages = [
            {
                "role": "user",
                "content": "Write a test message",
@@ -104,7 +115,10 @@ def simple_prompt_func():
"content": "You are a helpful assistant.",
},
]
return {"messages": messages}

# Set dummy OpenAI API key for testing
os.environ["OPENAI_API_KEY"] = "test-key"
batch_prompter = LLM(
model_name="gpt-4o-mini",
prompt_func=simple_prompt_func,
@@ -142,8 +156,16 @@ def test_single_completion_no_batch(prompter: LLM):
    """

    # Create a prompter without batch parameter
    def simple_prompt_func():
        return [
    def simple_prompt_func(row: _DictOrBaseModel) -> _DictOrBaseModel:
        """Generate a simple prompt for testing.

        Args:
            row: The input row (unused in this test)

        Returns:
            A list of messages for the LLM
        """
        messages = [
            {
                "role": "user",
                "content": "Write a test message",
@@ -153,7 +175,10 @@ def simple_prompt_func():
"content": "You are a helpful assistant.",
},
]
return {"messages": messages}

# Set dummy OpenAI API key for testing
os.environ["OPENAI_API_KEY"] = "test-key"
non_batch_prompter = LLM(
model_name="gpt-4o-mini",
prompt_func=simple_prompt_func,