refactor: optimizer is now using the new pull based execution model

shreyashankar committed Jan 10, 2025
1 parent 18cae16 commit 74e2ab9
Showing 12 changed files with 1,411 additions and 251 deletions.
509 changes: 509 additions & 0 deletions docetl/containers.py

Large diffs are not rendered by default.
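docetl/containers.py is new and carries the pull-based model named in the commit title, but its diff is not rendered here. The sketch below is therefore speculative: the class and method names are guesses, not taken from the file; only the execute(input_data) -> (results, cost) convention is borrowed from the diffs further down.

class OpContainer:
    """Sketch only: one pipeline step that pulls its input from upstream nodes."""

    def __init__(self, op, children):
        self.op = op              # operation exposing execute(input_data) -> (results, cost)
        self.children = children  # upstream containers this node pulls from
        self._result = None       # materialize each node at most once

    def next(self):
        # Pull from children on demand, run this operation, and cache the result,
        # rather than having the runner eagerly push data through every operation.
        if self._result is None:
            input_data, cost = [], 0.0
            for child in self.children:
                child_data, child_cost = child.next()
                input_data.extend(child_data)
                cost += child_cost
            results, op_cost = self.op.execute(input_data)
            self._result = (results, cost + op_cost)
        return self._result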

3 changes: 2 additions & 1 deletion docetl/operations/filter.py
@@ -114,6 +114,7 @@ def execute(
         results, total_cost = super().execute(input_data)

         # Drop records with filter_key values that are False
-        results = [result for result in results if result[filter_key]]
+        if not is_build:
+            results = [result for result in results if result[filter_key]]

         return results, total_cost
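This is the one behavioral change in filter.py: during an optimizer build pass, records whose filter_key is False are kept so the optimizer can inspect both outcomes. A hedged usage sketch follows; treating is_build as a keyword argument of execute is an assumption, since the hunk only shows the variable in scope.

# Assumed signature: the hunk shows `is_build` in scope but not where it is set.
results, total_cost = filter_op.execute(sample_docs, is_build=True)
# In a build pass, failing records survive with their filter_key still False:
passed = [r for r in results if r["_passed"]]       # "_passed" is a stand-in filter_key
failed = [r for r in results if not r["_passed"]]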
1 change: 0 additions & 1 deletion docetl/operations/resolve.py
@@ -541,7 +541,6 @@ def auto_batch() -> int:
                         "comparison_prompts"
                     ].append(prompt)

-            pbar.update(last_processed // batch_size)
             total_cost += pair_costs

         # Collect final clusters
701 changes: 701 additions & 0 deletions docetl/optimizer.py

Large diffs are not rendered by default.
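docetl/optimizer.py is also new and unrendered, but the diffs below repeatedly dereference runner.optimizer.llm_client, runner.optimizer.timeout, and runner.optimizer.captured_output, so it evidently centralizes that state. A speculative skeleton; everything beyond those three attribute names is assumed.

from typing import Any

class Optimizer:
    """Sketch only: field names inferred from runner.optimizer.* references in this commit."""

    def __init__(self, runner: Any, llm_client: Any, captured_output: Any, timeout: int = 10):
        self.runner = runner
        self.llm_client = llm_client            # shared agent LLM for all sub-optimizers
        self.timeout = timeout                  # per-plan execution timeout
        self.captured_output = captured_output  # per-stage artifact store (sketched further down)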

5 changes: 5 additions & 0 deletions docetl/optimizers/__init__.py
@@ -0,0 +1,5 @@
+from docetl.optimizers.join_optimizer import JoinOptimizer
+from docetl.optimizers.map_optimizer import MapOptimizer
+from docetl.optimizers.reduce_optimizer import ReduceOptimizer
+
+__all__ = ["JoinOptimizer", "MapOptimizer", "ReduceOptimizer"]
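These re-exports let callers import all three sub-optimizers from the package root:

from docetl.optimizers import JoinOptimizer, MapOptimizer, ReduceOptimizer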
16 changes: 5 additions & 11 deletions docetl/optimizers/join_optimizer.py
@@ -18,31 +18,26 @@ class JoinOptimizer:
     def __init__(
         self,
         runner,
-        config: Dict[str, Any],
         op_config: Dict[str, Any],
-        console: Console,
-        llm_client: Any,
-        max_threads: int,
         target_recall: float = 0.95,
         sample_size: int = 500,
         sampling_weight: float = 20,
         agent_max_retries: int = 5,
         estimated_selectivity: float = None,
-        status: Status = None,
     ):
         self.runner = runner
-        self.config = config
+        self.config = runner.config
         self.op_config = op_config
-        self.llm_client = llm_client
-        self.max_threads = max_threads
-        self.console = console
+        self.llm_client = runner.optimizer.llm_client
+        self.max_threads = runner.max_threads
+        self.console = runner.console
         self.target_recall = target_recall
         self.sample_size = sample_size
         self.sampling_weight = sampling_weight
         self.agent_max_retries = agent_max_retries
         self.estimated_selectivity = estimated_selectivity
         self.console.log(f"Target Recall: {self.target_recall}")
-        self.status = status
+        self.status = self.runner.status
         # if self.estimated_selectivity is not None:
         #     self.console.log(
         #         f"[yellow]Using estimated selectivity of {self.estimated_selectivity}[/yellow]"
@@ -443,7 +438,6 @@ def should_optimize(self, input_data: List[Dict[str, Any]]) -> Tuple[bool, str]:
     def optimize_resolve(
         self, input_data: List[Dict[str, Any]]
     ) -> Tuple[Dict[str, Any], float]:
-
         # Check if the operation is marked as empty
         if self.op_config.get("empty", False):
             # Extract the map prompt from the intermediates
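The net effect on call sites, as a hypothetical example: only the runner and operation-specific arguments remain, since config, console, LLM client, thread count, and status are all pulled from the runner.

# Hypothetical call site under the new signature.
join_optimizer = JoinOptimizer(runner, op_config, target_recall=0.95)
optimized_config, cost = join_optimizer.optimize_resolve(input_data)  # return shape per the type hint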
52 changes: 25 additions & 27 deletions docetl/optimizers/map_optimizer/optimizer.py
@@ -40,10 +40,6 @@ class MapOptimizer:
     def __init__(
         self,
         runner,
-        config: Dict[str, Any],
-        console: Console,
-        llm_client: LLMClient,
-        max_threads: int,
         run_operation: Callable,
         timeout: int = 10,
         is_filter: bool = False,
@@ -53,45 +49,47 @@ def __init__(
         Initialize the MapOptimizer.
         Args:
-            config (Dict[str, Any]): The configuration dictionary for the optimizer.
-            console (Console): A Rich console object for pretty printing.
-            llm_client (LLMClient): A client for interacting with a language model.
-            max_threads (int): The maximum number of threads to use for parallel execution.
+            runner (Runner): The runner object.
             run_operation (Callable): A function to execute operations.
             timeout (int, optional): The timeout in seconds for operation execution. Defaults to 10.
             is_filter (bool, optional): If True, the operation is a filter operation. Defaults to False.
         """
         self.runner = runner
-        self.config = config
-        self.console = console
-        self.llm_client = llm_client
+        self.config = runner.config
+        self.console = runner.console
+        self.llm_client = runner.optimizer.llm_client
         self._run_operation = run_operation
-        self.max_threads = max_threads
-        self.timeout = timeout
+        self.max_threads = runner.max_threads
+        self.timeout = runner.optimizer.timeout
         self._num_plans_to_evaluate_in_parallel = 5
         self.is_filter = is_filter
         self.k_to_pairwise_compare = 6

         self.plan_generator = PlanGenerator(
             runner,
-            llm_client,
-            console,
-            config,
+            self.llm_client,
+            self.console,
+            self.config,
             run_operation,
-            max_threads,
+            self.max_threads,
             is_filter,
             depth,
         )
         self.evaluator = Evaluator(
-            llm_client,
-            console,
-            run_operation,
-            timeout,
+            self.llm_client,
+            self.console,
+            self._run_operation,
+            self.timeout,
             self._num_plans_to_evaluate_in_parallel,
-            is_filter,
+            self.is_filter,
         )
         self.prompt_generator = PromptGenerator(
-            runner, llm_client, console, config, max_threads, is_filter
+            self.runner,
+            self.llm_client,
+            self.console,
+            self.config,
+            self.max_threads,
+            self.is_filter,
         )

     def should_optimize(
@@ -180,7 +178,7 @@ def _should_optimize_helper(
         no_change_runtime = time.time() - no_change_start

         # Capture output for the sample run
-        self.runner.captured_output.save_optimizer_output(
+        self.runner.optimizer.captured_output.save_optimizer_output(
             stage_type=StageType.SAMPLE_RUN,
             output={
                 "operation_config": op_config,
@@ -215,7 +213,7 @@ def _should_optimize_helper(
         )
         self.console.log("\n")  # Add a newline for better readability

-        self.runner.captured_output.save_optimizer_output(
+        self.runner.optimizer.captured_output.save_optimizer_output(
             stage_type=StageType.SHOULD_OPTIMIZE,
             output={
                 "validator_prompt": validator_prompt,
@@ -384,7 +382,7 @@ def optimize(
         plans_list = list(candidate_plans.items())

         # Capture candidate plans
-        self.runner.captured_output.save_optimizer_output(
+        self.runner.optimizer.captured_output.save_optimizer_output(
             stage_type=StageType.CANDIDATE_PLANS,
             output=candidate_plans,
         )
@@ -512,7 +510,7 @@ def optimize(
         ratings = {k: v[0] for k, v in results.items()}
         runtime = {k: v[1] for k, v in results.items()}
         sample_outputs = {k: v[2] for k, v in results.items()}
-        self.runner.captured_output.save_optimizer_output(
+        self.runner.optimizer.captured_output.save_optimizer_output(
             stage_type=StageType.EVALUATION_RESULTS,
             output={
                 "input_data": evaluation_samples,
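All four stages captured above (SAMPLE_RUN, SHOULD_OPTIMIZE, CANDIDATE_PLANS, EVALUATION_RESULTS) now route through runner.optimizer.captured_output. A minimal compatible store, assuming a simple per-stage dictionary; the real class is not shown in this diff.

class CapturedOutput:
    """Sketch only: keeps the latest artifact per optimization stage."""

    def __init__(self):
        self.optimizer_output = {}

    def save_optimizer_output(self, stage_type, output):
        # Overwrite with the most recent artifact for this stage so later tooling
        # can replay the optimization run stage by stage.
        self.optimizer_output[stage_type] = output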
8 changes: 0 additions & 8 deletions docetl/optimizers/map_optimizer/plan_generators.py
@@ -272,10 +272,6 @@ def determine_metadata_with_retry():
             try:
                 optimized_reduce_ops, _, cost = ReduceOptimizer(
                     self.runner,
-                    self.config,
-                    self.console,
-                    self.llm_client,
-                    self.max_threads,
                     self._run_operation,
                 ).optimize(reduce_op, sample_output)
                 self.subplan_optimizer_cost += cost
@@ -918,10 +914,6 @@ def _recursively_optimize_subtask(

         subtask_optimizer = MapOptimizer(
             self.runner,
-            self.config,
-            self.console,
-            self.llm_client,
-            self.max_threads,
             self._run_operation,
             is_filter=self.is_filter,
             depth=self.depth + 1,
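Both nested optimizers now take only the runner and a run_operation callable. In the recursive case, passing depth + 1 presumably lets MapOptimizer bound how deep sub-task decomposition goes; a hypothetical illustration:

# Hypothetical: each recursion level increments depth so decomposition can stop
# once some maximum depth is reached.
subtask_optimizer = MapOptimizer(
    runner,
    run_operation,
    is_filter=False,
    depth=current_depth + 1,
)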
17 changes: 7 additions & 10 deletions docetl/optimizers/reduce_optimizer.py
@@ -38,14 +38,9 @@ class ReduceOptimizer:
     def __init__(
         self,
         runner,
-        config: Dict[str, Any],
-        console: Console,
-        llm_client: LLMClient,
-        max_threads: int,
         run_operation: Callable,
         num_fold_prompts: int = 1,
         num_samples_in_validation: int = 10,
-        status: Optional[Status] = None,
     ):
         """
         Initialize the ReduceOptimizer.
@@ -60,14 +55,14 @@ def __init__(
             num_samples_in_validation (int, optional): Number of samples to use in validation. Defaults to 10.
         """
         self.runner = runner
-        self.config = config
-        self.console = console
-        self.llm_client = llm_client
+        self.config = self.runner.config
+        self.console = self.runner.console
+        self.llm_client = self.runner.optimizer.llm_client
         self._run_operation = run_operation
-        self.max_threads = max_threads
+        self.max_threads = self.runner.max_threads
         self.num_fold_prompts = num_fold_prompts
         self.num_samples_in_validation = num_samples_in_validation
-        self.status = status
+        self.status = self.runner.status

     def should_optimize_helper(
         self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]]
@@ -1108,6 +1103,8 @@ def _generate_validator_prompt(
         4. How well does the output adhere to any specific formatting requirements mentioned in the original prompt, such as character limits for summaries or specific data types for aggregated values?

         Note that the output may reflect more than just the input provided, since we only provide a one-item sample input. Provide your response as a single string containing the custom validator prompt. The prompt should be tailored to the task and avoid generic criteria. The prompt should not reference a specific value in the sample input, but rather a general property.
+
+        Your prompt should not have any placeholders like {{ reduce_key }} or {{ input_key }}. It should just be a string.
         """

         parameters = {
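As with the other sub-optimizers, construction collapses to the runner plus a run_operation callable. A hypothetical call site, with the three-element return shape taken from the plan_generators.py call above:

reduce_optimizer = ReduceOptimizer(runner, run_operation)
optimized_ops, optimized_output, cost = reduce_optimizer.optimize(reduce_op, sample_output)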