fix: equijoin is out of date (#269)
* fix: equijoin is out of date

* fix: equijoin is out of date and there are runtime errors

* chore: bump up version

shreyashankar authored Jan 8, 2025
1 parent eb995fd commit 8f1036b
Showing 9 changed files with 405 additions and 380 deletions.
2 changes: 1 addition & 1 deletion docetl/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.2"
+__version__ = "0.21"

 from docetl.runner import DSLRunner
 from docetl.builder import Optimizer
27 changes: 17 additions & 10 deletions docetl/builder.py
@@ -645,12 +645,13 @@ def optimize(self) -> float:

         for op in optimized_step["operations"]:
             changed_op = False
+            op_name = list(op.keys())[0] if isinstance(op, dict) else op
             for i, op_config in enumerate(self.optimized_config["operations"]):
-                if op_config["name"] == op:
-                    self.optimized_config["operations"][i] = step_operations[op]
+                if op_config["name"] == op_name:
+                    self.optimized_config["operations"][i] = step_operations[op_name]
                     changed_op = True
             if not changed_op:
-                self.optimized_config["operations"].append(step_operations[op])
+                self.optimized_config["operations"].append(step_operations[op_name])

         self.optimized_config["pipeline"]["steps"] = [
             step
@@ -670,7 +671,7 @@ def optimize(self) -> float:
                     if s["name"] == step_name
                 ][0],
                 "operations": [
-                    self.find_operation(op, self.optimized_config)
+                    self.find_operation(list(op.keys())[0] if isinstance(op, dict) else op, self.optimized_config)
                     for op in optimized_step["operations"]
                 ],
             }
@@ -830,8 +831,10 @@ def _optimize_step(
         # Run the pipeline
         step_ops = []
         for step_op in step.get("operations"):
-            if step_op in replacement_operations:
-                step_ops.extend(replacement_operations[step_op])
+            step_op_name = list(step_op.keys())[0] if isinstance(step_op, dict) else step_op
+
+            if step_op_name in replacement_operations:
+                step_ops.extend(replacement_operations[step_op_name])
             else:
                 step_ops.append(step_op)

@@ -852,7 +855,10 @@ def _optimize_step(
             # If optimize is False or operation type is not supported, just use the operation without optimization
             output_data = self._run_operation(op_object, input_data)
             optimized_operations[operation_name] = op_object
-            optimized_operation_names.append(operation_name)
+            if op_object.get("type") == "equijoin":
+                optimized_operation_names.append(operation)
+            else:
+                optimized_operation_names.append(operation_name)

             selectivity = len(output_data) / len(input_data)
             self.selectivities[step.get("name")][operation_name] = selectivity
@@ -934,8 +940,8 @@ def _optimize_step(
                     new_right_name,
                 ) = self._optimize_equijoin(
                     op_object,
-                    operation["left"],
-                    operation["right"],
+                    next(iter(operation.values()))["left"],
+                    next(iter(operation.values()))["right"],
                     input_data["left"],
                     input_data["right"],
                     status,
@@ -1033,6 +1039,7 @@ def _optimize_step(
             output_data = input_data

         optimized_step = step.copy()
+
         optimized_step["operations"] = optimized_operation_names
         return optimized_step, optimized_operations, output_data

@@ -1075,7 +1082,7 @@ def _get_sample_data(
             {
                 "step": step,
                 "operations": [
-                    self.find_operation(op) for op in step["operations"]
+                    self.find_operation(list(op.keys())[0] if isinstance(op, dict) else op, self.optimized_config) for op in step["operations"]
                 ],
             }
         ).encode()
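
Note: the pattern repeated throughout these builder.py hunks handles the new shape of a step's operation list, where an entry is either a bare operation name or, for equijoin, a single-key dict mapping the name to its left/right dataset config. A minimal sketch of that normalization, with a helper name of our own (not part of docetl's API):

from typing import Any, Dict, Union

def op_entry_name(op: Union[str, Dict[str, Any]]) -> str:
    """Return the operation name for a step entry of either shape."""
    # next(iter(op)) is equivalent to the list(op.keys())[0] used in the diff
    return next(iter(op)) if isinstance(op, dict) else op

assert op_entry_name("extract_title") == "extract_title"
assert op_entry_name({"join_books": {"left": "books", "right": "reviews"}}) == "join_books"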
1 change: 0 additions & 1 deletion docetl/cli.py
@@ -11,7 +11,6 @@

 app = typer.Typer()

-
 @app.command()
 def build(
     yaml_file: Path = typer.Argument(
160 changes: 137 additions & 23 deletions docetl/operations/equijoin.py
@@ -9,18 +9,17 @@
 from multiprocessing import Pool, cpu_count
 from typing import Any, Dict, List, Tuple, Optional

+from docetl import console
+from docetl.operations.utils import strict_render
+from docetl.operations.utils.progress import RichLoopBar
 import numpy as np
 from jinja2 import Template
 from litellm import model_cost
 from rich.prompt import Confirm
 from rich.console import Console

 from docetl.operations.base import BaseOperation
-from docetl.operations.utils import (
-    rich_as_completed,
-)
 from docetl.utils import completion_cost
-from pydantic import Field


 # Global variables to store shared data
@@ -40,6 +39,10 @@ def is_match(left_item: Dict[str, Any], right_item: Dict[str, Any]) -> bool:
         for condition in _blocking_conditions
     )

+# LLM-based comparison for blocked pairs
+def get_hashable_key(item: Dict) -> str:
+    return json.dumps(item, sort_keys=True)
+

 def process_left_item(
     left_item: Dict[str, Any]
@@ -69,7 +72,7 @@ class schema(BaseOperation.schema):
         limit_comparisons: Optional[int] = None
         blocking_keys: Optional[Dict[str, List[str]]] = None
         timeout: Optional[int] = None
-        litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict)
+        litellm_completion_kwargs: Dict[str, Any] = {}

     def compare_pair(
         self,
@@ -95,8 +98,11 @@ def compare_pair(
             Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison.
         """

-
-        prompt = strict_render(comparison_prompt, {"left": item1, "right": item2})
+        try:
+            prompt = strict_render(comparison_prompt, {"left": item1, "right": item2})
+        except Exception as e:
+            self.console.print(f"[red]Error rendering prompt: {e}[/red]")
+            return False, 0
         response = self.runner.api.call_llm(
             model,
             "compare",
@@ -107,10 +113,16 @@ def compare_pair(
             bypass_cache=self.config.get("bypass_cache", False),
             litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}),
         )
-        output = self.runner.api.parse_llm_response(
-            response.response, {"is_match": "bool"}
-        )[0]
-        return output["is_match"], response.total_cost
+        cost = 0
+        try:
+            cost = response.total_cost
+            output = self.runner.api.parse_llm_response(
+                response.response, {"is_match": "bool"}
+            )[0]
+        except Exception as e:
+            self.console.print(f"[red]Error parsing LLM response: {e}[/red]")
+            return False, cost
+        return output["is_match"], cost

    def syntax_check(self) -> None:
        """
@@ -215,10 +227,6 @@ def execute(
         limit_comparisons = self.config.get("limit_comparisons")
         total_cost = 0

-        # LLM-based comparison for blocked pairs
-        def get_hashable_key(item: Dict) -> str:
-            return json.dumps(item, sort_keys=True)
-
         if len(left_data) == 0 or len(right_data) == 0:
             return [], 0

@@ -244,8 +252,13 @@ def get_hashable_key(item: Dict) -> str:

         # Check if we have exceeded the pairwise comparison limit
         if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons:
-            # Sample pairs randomly
-            sampled_pairs = random.sample(blocked_pairs, limit_comparisons)
+            # Sample pairs based on cardinality and length
+            sampled_pairs = stratified_length_sample(
+                blocked_pairs,
+                limit_comparisons,
+                sample_size=1000,
+                console=self.console
+            )

             # Calculate number of dropped pairs
             dropped_pairs = len(blocked_pairs) - limit_comparisons
@@ -257,7 +270,7 @@ def get_hashable_key(item: Dict) -> str:
                 f"[yellow]Warning: {dropped_pairs} pairs will be dropped due to the comparison limit. "
                 f"Proceeding with {limit_comparisons} randomly sampled pairs. "
                 f"Do you want to continue?[/yellow]",
-                self.console,
+                console=self.console,
             ):
                 raise ValueError("Operation cancelled by user due to pair limit.")

@@ -410,6 +423,9 @@ def get_embeddings(
         results = []
         comparison_costs = 0

+        if self.status:
+            self.status.stop()
+
         with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
             future_to_pair = {
                 executor.submit(
@@ -424,12 +440,14 @@ def get_embeddings(
                 for left, right in blocked_pairs
             }

-            for future in rich_as_completed(
-                future_to_pair,
-                total=len(future_to_pair),
-                desc="Comparing pairs",
+            pbar = RichLoopBar(
+                range(len(future_to_pair)),
+                desc=f"Comparing pairs",
                 console=self.console,
-            ):
+            )
+
+            for i in pbar:
+                future = list(future_to_pair.keys())[i]
                 pair = future_to_pair[future]
                 is_match, cost = future.result()
                 comparison_costs += cost
@@ -460,6 +478,9 @@ def get_embeddings(

         total_cost += comparison_costs

+        if self.status:
+            self.status.start()
+
         # Calculate and print the join selectivity
         join_selectivity = (
             len(results) / (len(left_data) * len(right_data))
@@ -472,3 +493,96 @@ def get_embeddings(
             self.status.start()

         return results, total_cost
+
+
+def estimate_length(items: List[Dict], sample_size: int = 1000) -> float:
+    """
+    Estimates average document length in the relation.
+    Returns a normalized score (0-1) representing relative document size.
+    Args:
+        items: List of dictionary items to analyze
+        sample_size: Number of items to sample for estimation
+    Returns:
+        float: Normalized score based on average document length
+    """
+    if not items:
+        return 0.0
+
+    sample_size = min(len(items), sample_size)
+    sample = random.sample(items, sample_size)
+
+    def get_doc_length(doc: Dict) -> int:
+        """Calculate total length of all string values in document"""
+        total_len = 0
+        for value in doc.values():
+            if isinstance(value, str):
+                total_len += len(value)
+            elif isinstance(value, (list, dict)):
+                # For nested structures, use their string representation
+                total_len += len(str(value))
+        return total_len
+
+    lengths = [get_doc_length(item) for item in sample]
+    if not lengths:
+        return 0.0
+
+    avg_length = sum(lengths) / len(lengths)
+    return avg_length
+
+
+def stratified_length_sample(
+    blocked_pairs: List[Tuple[Dict, Dict]],
+    limit_comparisons: int,
+    sample_size: int = 1000,
+    console: Console = None
+) -> List[Tuple[Dict, Dict]]:
+    """
+    Samples pairs stratified by the smaller cardinality relation,
+    prioritizing longer matches within each stratum.
+    """
+    # Extract left and right items
+    left_items = [left for left, _ in blocked_pairs]
+    right_items = [right for _, right in blocked_pairs]
+
+    # Estimate length for both relations
+    left_length = estimate_length(left_items, sample_size)
+    right_length = estimate_length(right_items, sample_size)
+
+    # Group by the relation with estimated lower length
+    use_left_as_key = left_length > right_length
+    if console:
+        longer_length = max(left_length, right_length)
+        longer_side = "left" if left_length > right_length else "right"
+        console.log(f"Longer length is {longer_length:.2f} ({longer_side} side). Using {longer_side} to sample matches.")
+    groups = defaultdict(list)
+
+    for left, right in blocked_pairs:
+        key = get_hashable_key(left if use_left_as_key else right)
+        value = (left, right)
+        groups[key].append(value)
+
+    # Sort each group by length of the other relation's item
+    for key in groups:
+        groups[key].sort(
+            key=lambda x: len(x[1 if use_left_as_key else 0]),
+            reverse=True  # Prioritize longer matches
+        )
+
+    # Calculate samples per group
+    n_groups = len(groups)
+    base_samples_per_group = limit_comparisons // n_groups
+    extra_samples = limit_comparisons % n_groups
+
+    # Sample from each group
+    sampled_pairs = []
+    for i, (key, pairs) in enumerate(groups.items()):
+        # Add one extra sample to early groups if we have remainder
+        group_sample_size = min(
+            len(pairs),
+            base_samples_per_group + (1 if i < extra_samples else 0)
+        )
+        sampled_pairs.extend(pairs[:group_sample_size])
+
+    return sampled_pairs
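
A small usage sketch for the new sampler (the toy data is ours; the functions are module-level in equijoin.py as of this commit). Pairs are grouped by one side of the join, each group is sorted longest-first, and the comparison budget is split evenly across groups, with earlier groups absorbing any remainder:

from docetl.operations.equijoin import stratified_length_sample

left = [{"id": "a", "text": "short"}, {"id": "b", "text": "a much longer document " * 10}]
right = [{"sku": 1, "desc": "x"}, {"sku": 2, "desc": "y"}, {"sku": 3, "desc": "z"}]

# Pretend blocking produced the full cross product (2 x 3 = 6 pairs).
blocked_pairs = [(l, r) for l in left for r in right]

# Keep at most 4 of the 6 candidates: two groups keyed on the left side, two pairs each.
sampled = stratified_length_sample(blocked_pairs, limit_comparisons=4)
assert len(sampled) == 4

Two caveats worth noting in the code as committed: estimate_length's docstring promises a normalized 0-1 score but the function returns the raw average length (harmless here, since the value is only compared between the two sides), and the per-group sort keys on len() of the paired dict, i.e. its number of keys rather than its text length.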
2 changes: 1 addition & 1 deletion docetl/operations/map.py
@@ -40,7 +40,7 @@ class schema(BaseOperation.schema):
         batch_size: Optional[int] = None
         clustering_method: Optional[str] = None
         batch_prompt: Optional[str] = None
-        litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict)
+        litellm_completion_kwargs: Dict[str, Any] = {}
     @field_validator("drop_keys")
     def validate_drop_keys(cls, v):
         if isinstance(v, str):
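
Swapping Field(default_factory=dict) for a literal {} is safe on a Pydantic model, since Pydantic copies mutable defaults per instance (unlike plain class attributes), and it removes the need for the Field helper. A quick check of that behavior, ours rather than from the repo:

from typing import Any, Dict
from pydantic import BaseModel

class Schema(BaseModel):
    litellm_completion_kwargs: Dict[str, Any] = {}

a, b = Schema(), Schema()
a.litellm_completion_kwargs["temperature"] = 0.0
assert b.litellm_completion_kwargs == {}  # no shared mutable state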
2 changes: 1 addition & 1 deletion docetl/optimizers/join_optimizer.py
@@ -711,7 +711,7 @@ def optimize_equijoin(
                 if self.status:
                     self.status.stop()
                 # Use Rich's Confirm for input
-                if Confirm.ask("Use this rule?", self.console):
+                if Confirm.ask("Use this rule?", console=self.console):
                     selected_containment_rules.append(rule)
                 # Restart the status
                 if self.status:
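
This is one of the runtime errors the commit message refers to: rich's Confirm.ask accepts the console only as a keyword argument, so the old positional call raised a TypeError. A minimal illustration (prompts interactively when run):

from rich.console import Console
from rich.prompt import Confirm

console = Console()
# Confirm.ask("Use this rule?", console)  # TypeError: console is keyword-only
answer = Confirm.ask("Use this rule?", console=console)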
4 changes: 2 additions & 2 deletions docetl/runner.py
@@ -369,8 +369,8 @@ def execute_step(
                 self.status,
             )
             if op_object["type"] == "equijoin":
-                left_data = self.datasets[op_object["left"]].load()
-                right_data = self.datasets[op_object["right"]].load()
+                left_data = self.datasets[next(iter(operation.values()))["left"]].load()
+                right_data = self.datasets[next(iter(operation.values()))["right"]].load()
                 input_data, cost = operation_instance.execute(left_data, right_data)
             else:
                 input_data, cost = operation_instance.execute(input_data)
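
Here the runner reads the left/right dataset names from the step-level operation entry rather than from the operation object itself, matching the builder changes above. An illustrative sketch of the assumed shape (the names are ours):

operation = {
    "join_books_reviews": {  # operation name -> step-level equijoin config
        "left": "books",     # keys into self.datasets
        "right": "reviews",
    }
}

equijoin_config = next(iter(operation.values()))
assert equijoin_config["left"] == "books"
assert equijoin_config["right"] == "reviews"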
