diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 60a0222f..5907a302 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,10 +20,10 @@ repos: - id: check-docstring-first - repo: https://github.com/hadialqattan/pycln - rev: v2.4.0 + rev: v2.5.0 hooks: - id: pycln - args: [--all] + args: [--all, --exclude, "__init__.py$", --include, "^docetl/"] - repo: https://github.com/psf/black rev: 24.1.1 diff --git a/docetl/__init__.py b/docetl/__init__.py index 8c0e7d8a..bdb99b1e 100644 --- a/docetl/__init__.py +++ b/docetl/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.2.1" from docetl.runner import DSLRunner -from docetl.builder import Optimizer +from docetl.optimizer import Optimizer __all__ = ["DSLRunner", "Optimizer"] diff --git a/docetl/api.py b/docetl/api.py index 5ec744aa..1a5c2028 100644 --- a/docetl/api.py +++ b/docetl/api.py @@ -49,32 +49,31 @@ result = optimized_pipeline.run() """ -import os import inspect -from typing import Any, Dict, List, Optional, Callable, Union +import os +from typing import Any, Callable, Dict, List, Optional, Union import yaml from rich import print -from docetl.builder import Optimizer from docetl.runner import DSLRunner from docetl.schemas import ( + ClusterOp, Dataset, EquijoinOp, FilterOp, GatherOp, MapOp, - ReduceOp, - ResolveOp, - SplitOp, - UnnestOp, - ClusterOp, - SampleOp, OpType, ParallelMapOp, ParsingTool, PipelineOutput, PipelineStep, + ReduceOp, + ResolveOp, + SampleOp, + SplitOp, + UnnestOp, ) @@ -166,9 +165,10 @@ def __init__( self._load_env() def _load_env(self): - from dotenv import load_dotenv import os + from dotenv import load_dotenv + # Get the current working directory cwd = os.getcwd() diff --git a/docetl/base_schemas.py b/docetl/base_schemas.py index 88547a3c..0e678fd4 100644 --- a/docetl/base_schemas.py +++ b/docetl/base_schemas.py @@ -1,10 +1,11 @@ from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel # from ..operations import map # MapOp = map.MapOperation.schema + class ToolFunction(BaseModel): name: str description: str @@ -96,7 +97,8 @@ class PipelineStep(BaseModel): name: str operations: List[Union[Dict[str, Any], str]] input: Optional[str] = None - + + class PipelineOutput(BaseModel): """ Represents the output configuration for a pipeline. diff --git a/docetl/builder.py b/docetl/builder.py deleted file mode 100644 index d99ffcb4..00000000 --- a/docetl/builder.py +++ /dev/null @@ -1,1572 +0,0 @@ -import copy -import hashlib -import json -import math -import os -import random -from collections import Counter, defaultdict -from typing import Any, Dict, List, Optional, Tuple, Union - -import yaml -from docetl.utils import CapturedOutput -from rich.console import Console -from rich.panel import Panel -from rich.status import Status -from rich.traceback import install - -from docetl.dataset import Dataset, create_parsing_tool_map -from docetl.operations import get_operation -from docetl.operations.base import BaseOperation -from docetl.operations.utils import flush_cache -from docetl.optimizers.join_optimizer import JoinOptimizer -from docetl.optimizers.map_optimizer import MapOptimizer -from docetl.optimizers.reduce_optimizer import ReduceOptimizer -from docetl.optimizers.utils import LLMClient -from docetl.config_wrapper import ConfigWrapper - -install(show_locals=True) - -SUPPORTED_OPS = ["map", "resolve", "reduce", "equijoin", "filter"] -NUM_OPTIMIZER_RETRIES = 1 - -SAMPLE_SIZE_MAP = { - "reduce": 40, - "map": 5, - "resolve": 100, - "equijoin": 100, - "filter": 5, - "split": 100, - "gather": 100, -} - - -class DatasetOnDisk(dict): - def __init__(self, dir: str, console: Console): - self.dir = dir - self.console = console - - def __setitem__(self, key, value): - self._save_to_disk(key, value) - - def __getitem__(self, key): - with open(f"{self.dir}/{key}", "r") as f: - self.console.log(f"Loading dataset from disk... {key}") - return json.load(f) - - def _save_to_disk(self, save_suffix: str, value: Any): - with open(f"{self.dir}/{save_suffix}", "w") as f: - json.dump(value, f) - self.console.log( - f"[green]Saved intermediate results to disk at {self.dir}/{save_suffix}[/green]" - ) - - def __len__(self): - return len(os.listdir(self.dir)) - - def __iter__(self): - return iter(os.listdir(self.dir)) - - def __contains__(self, item): - return item in os.listdir(self.dir) - - def keys(self): - return os.listdir(self.dir) - - def values(self): - return [self[key] for key in self.keys()] - - def items(self): - return [(key, self[key]) for key in self.keys()] - - -class Optimizer: - - def __init__( - self, - runner: "DSLRunner", - model: str = "gpt-4o", - resume: bool = False, - timeout: int = 60, - ): - """ - Initialize the Optimizer class. - - This method sets up the optimizer with the given configuration file and parameters. - It loads the configuration, initializes the console for output, sets up the LLM client, - and prepares various attributes for optimization. - - Args: - yaml_file (str): Path to the YAML configuration file. - max_threads (Optional[int]): Maximum number of threads to use for parallel processing. - If None, it will be set to (number of CPUs * 4). - model (str): The name of the language model to use. Defaults to "gpt-4o". - resume (bool): Whether to resume optimization from a previous run. Defaults to False. - timeout (int): Timeout in seconds for operations. Defaults to 60. - - Attributes: - yaml_file_path (str): Stores the path to the YAML file. - config (Dict): Stores the loaded configuration from the YAML file. - console (Console): Rich console for formatted output. - optimized_config (Dict): A copy of the original config to be optimized. - llm_client (LLMClient): Client for interacting with the language model. - max_threads (int): Maximum number of threads for parallel processing. - operations_cost (float): Tracks the total cost of operations. - timeout (int): Timeout for operations in seconds. - selectivities (defaultdict): Stores selectivity information for operations. - Selectivity is the ratio of output size to input size for an operation. - It's used to estimate how much data will flow through the pipeline after - each operation, which helps in optimizing subsequent operations and - determining appropriate sample sizes. For example, a selectivity of 0.5 - means an operation halves the size of its input data. - datasets (Dict): Stores loaded datasets. - - The method also calls print_optimizer_config() to display the initial configuration. - """ - self.config = runner.config - self.console = runner.console - self.max_threads = runner.max_threads - - self.base_name = runner.base_name - self.yaml_file_suffix = runner.yaml_file_suffix - self.config = runner.config - self.runner = runner - self.status = runner.status - - self.optimized_config = copy.deepcopy(self.config) - self.llm_client = LLMClient(model) - self.operations_cost = 0 - self.timeout = timeout - self.selectivities = defaultdict(dict) - self.samples_taken = defaultdict(dict) - self.resume = resume - self.captured_output = CapturedOutput() - - home_dir = os.environ.get("DOCETL_HOME_DIR", os.path.expanduser("~")) - cache_dir = os.path.join(home_dir, f".docetl/cache/{runner.yaml_file_suffix}") - os.makedirs(cache_dir, exist_ok=True) - self.datasets = DatasetOnDisk(dir=cache_dir, console=self.console) - self.optimized_ops_path = f"{cache_dir}/optimized_ops" - - # Update sample size map - self.sample_size_map = SAMPLE_SIZE_MAP - if self.config.get("optimizer_config", {}).get("sample_sizes", {}): - self.sample_size_map.update(self.config["optimizer_config"]["sample_sizes"]) - - self.step_op_to_optimized_ops = {} - - self.print_optimizer_config() - - def find_operation(self, op_name: str, config: Optional[Dict] = None) -> Dict: - if not config: - config = self.config - for operation_config in config["operations"]: - if operation_config["name"] == op_name: - return operation_config - raise ValueError(f"Operation '{op_name}' not found in configuration.") - - def syntax_check(self): - """ - Perform a syntax check on all operations defined in the configuration. - - This method validates each operation by attempting to instantiate it. - If any operation fails to instantiate, a ValueError is raised. - - Raises: - ValueError: If any operation fails the syntax check. - """ - for operation_config in self.config["operations"]: - operation = operation_config["name"] - operation_type = operation_config["type"] - - try: - operation_class = get_operation(operation_type) - operation_class( - self.runner, - operation_config, - self.config.get("default_model", "gpt-4o-mini"), - self.max_threads, - console=self.console, - ) - except Exception as e: - raise ValueError( - f"Syntax check failed for operation '{operation}': {str(e)}" - ) - - self.console.log("[green]Syntax check passed for all operations.[/green]") - - def print_optimizer_config(self): - """ - Print the current configuration of the optimizer. - - This method uses the Rich console to display a formatted output of the optimizer's - configuration. It includes details such as the YAML file path, sample sizes for - different operation types, maximum number of threads, the language model being used, - and the timeout setting. - - The output is color-coded and formatted for easy readability, with a header and - separator lines to clearly delineate the configuration information. - """ - self.console.print(Panel.fit( - "[bold cyan]Optimizer Configuration[/bold cyan]\n" - f"[yellow]Sample Size:[/yellow] {self.sample_size_map}\n" - f"[yellow]Max Threads:[/yellow] {self.max_threads}\n" - f"[yellow]Model:[/yellow] {self.llm_client.model}\n" - f"[yellow]Timeout:[/yellow] {self.timeout} seconds", - title="Optimizer Configuration" - )) - - def compute_sample_size( - self, - step_name: str, - step_ops: List[str], - op_config: Dict[str, Any], - ) -> int: - """ - Compute the sample size necessary for optimizing given operation based on upstream operations. - - This method calculates an appropriate sample size for an operation, taking into - account the selectivities of upstream operations in the same step. It uses a - predefined sample size map (SAMPLE_SIZE_MAP) as a starting point. - - For example, if we have a 'map' operation with a default sample size of 10, - and one upstream operation with a selectivity of 0.5, the computed sample size for the upstream operation would be: - 10 / 0.5 = 20 - - This ensures that after applying the selectivity of the upstream operation, - we still have a representative sample size for the current operation. - - Args: - step_name (str): The name of the current step in the pipeline. - step_ops (List[str]): A list of all operations in the current step. - op_config (Dict[str, Any]): The configuration dictionary for the current operation. - - Returns: - int: The computed sample size for the operation. - - The method works as follows: - 1. If there are no upstream operations, it returns the default sample size for the operation type. - 2. Otherwise, it starts with the default sample size and adjusts it based on the selectivities - of upstream operations. - 3. It iterates through upstream operations in reverse order, dividing the sample size by - each operation's selectivity. - 4. The final result is rounded to the nearest integer. - - Raises: - ValueError: If the selectivity for any upstream operation is not found. - - Note: - - The method assumes that selectivities for all upstream operations have been - previously computed and stored in self.selectivities. - - The sample size is always at least 1, even after all adjustments. - """ - # If an equijoin, load the default. Equijoins are always first - if op_config.get("type") == "equijoin": - return SAMPLE_SIZE_MAP.get(op_config.get("type")) - - # If there are no upstream operations, use the default sample_size - upstream_ops = [] - for step_op in step_ops: - if step_op != op_config.get("name"): - if step_op in self.step_op_to_optimized_ops: - upstream_ops.extend(self.step_op_to_optimized_ops[step_op]) - else: - upstream_ops.append(step_op) - else: - break - - if len(upstream_ops) == 0: - return self.sample_size_map.get(op_config.get("type"), float("inf")) - - # Otherwise, compute the sample size based on the upstream operations - sample_size = self.sample_size_map.get(op_config.get("type"), 100) - for op in reversed(upstream_ops): - # Use the selectivity of the upstream operation to compute the sample size - if op not in self.selectivities[step_name]: - raise ValueError( - f"Selectivity for operation {op} not found in selectivities. Other ops are {self.selectivities[step_name]}" - ) - - sample_size = sample_size / self.selectivities[step_name].get(op) - - return int(math.ceil(sample_size)) - - def _insert_empty_resolve_operations(self): - """ - Determines whether to insert resolve operations in the pipeline. - - This method iterates through each step in the pipeline and checks if there's a reduce - operation that follows a map operation with no resolver in between. If such a case is - found, it synthesizes an empty resolver operation and inserts it into the pipeline. - - The method modifies the pipeline configuration in-place. - - Returns: - None - - Side effects: - - Modifies self.config["pipeline"]["steps"] by potentially inserting new resolve operations. - - Adds new resolve operations to self.config["operations"] if necessary. - """ - for i, step in enumerate(self.config["pipeline"]["steps"]): - operations = step.get("operations", []) - has_map = False - has_reduce = False - has_resolve = False - map_op = None - reduce_op = None - - for op in operations: - if isinstance(op, dict): - op = list(op.keys())[0] - op_config = self.find_operation(op) - op_type = op_config["type"] - if op_type == "map": - has_map = True - map_op = op - elif op_type == "reduce" and op_config.get("synthesize_resolve", True): - reduce_key = op_config.get("reduce_key", "_all") - if isinstance(reduce_key, str): - reduce_key = [reduce_key] - if "_all" not in reduce_key: - has_reduce = True - reduce_op = op - elif op_type == "resolve": - has_resolve = True - - if has_map and has_reduce and not has_resolve: - # Synthesize an empty resolver - self.console.log( - "[yellow]Synthesizing empty resolver operation:[/yellow]" - ) - self.console.log( - f" • [cyan]Reduce operation:[/cyan] [bold]{reduce_op}[/bold]" - ) - self.console.log(f" • [cyan]Step:[/cyan] [bold]{step['name']}[/bold]") - - new_resolve_op = f"synthesized_resolve_{i}" - reduce_key = self.find_operation(reduce_op).get("reduce_key") - if isinstance(reduce_key, str): - reduce_key = [reduce_key] - self.config["operations"].append( - { - "name": new_resolve_op, - "type": "resolve", - "empty": True, - "optimize": True, - "embedding_model": "text-embedding-3-small", - "resolution_model": self.config.get( - "default_model", "gpt-4o-mini" - ), - "comparison_model": self.config.get( - "default_model", "gpt-4o-mini" - ), - "_intermediates": { - "map_prompt": self.find_operation(map_op).get("prompt"), - "reduce_key": reduce_key, - }, - } - ) - - # Insert the new resolve operation before the reduce operation - reduce_index = next( - i - for i, op in enumerate(operations) - if self.find_operation(op).get("type") == "reduce" - ) - operations.insert(reduce_index, new_resolve_op) - - has_resolve = True - - self.config["pipeline"]["steps"][i]["operations"] = operations - - # Update the pipeline configuration - self.config["pipeline"]["steps"] = self.config["pipeline"]["steps"] - - def _add_map_prompts_to_reduce_operations(self): - """ - Add relevant map prompts to reduce operations based on their reduce keys. - - This method iterates through all map operations to create a dictionary mapping - output schema keys to map prompts. It then loops through reduce operations, - adding the relevant map prompts based on the reduce keys and output schema. - - Side effects: - - Modifies reduce operations in self.config["operations"] by adding map prompts. - """ - # Create a dictionary mapping output schema keys to map prompts - output_key_to_prompt = {} - for op_config in self.config["operations"]: - if op_config.get("type") == "map": - output_schema = op_config.get("output", {}).get("schema", {}) - prompt = op_config.get("prompt", "") - for key in output_schema.keys(): - output_key_to_prompt[key] = prompt - - # Add relevant map prompts to reduce operations - for op_config in self.config["operations"]: - if op_config.get("type") == "reduce": - reduce_keys = op_config.get("reduce_key", []) - if isinstance(reduce_keys, str): - reduce_keys = [reduce_keys] - - relevant_prompts = [] - for key in reduce_keys: - if key in output_key_to_prompt: - relevant_prompts.append(output_key_to_prompt[key]) - - if relevant_prompts: - op_config["_intermediates"] = op_config.get("_intermediates", {}) - op_config["_intermediates"]["last_map_prompt"] = relevant_prompts[ - -1 - ] - - def _load_optimized_ops(self): - """ - Load the optimized operations from disk. - """ - if os.path.exists(self.optimized_ops_path): - for filename in os.listdir(self.optimized_ops_path): - if filename.endswith(".json"): - original_op_name = filename[:-5] # Remove '.json' from the filename - with open( - os.path.join(self.optimized_ops_path, filename), "r" - ) as f: - optimized_ops = json.load(f) - - # Update the config with the optimized operations - if original_op_name in [ - op["name"] for op in self.config["operations"] - ]: - # Update the config with the optimized operations - # First, remove all operations that are already in the config with the same name - self.config["operations"] = [ - op - for op in self.config["operations"] - if op["name"] != original_op_name - ] - - for op in optimized_ops: - op["optimize"] = False - self.config["operations"].append(op) - - # Update the step operations - for step in self.config["pipeline"]["steps"]: - if original_op_name in step["operations"]: - index = step["operations"].index(original_op_name) - step["operations"] = ( - step["operations"][:index] - + [op["name"] for op in optimized_ops] - + step["operations"][index + 1 :] - ) - - self.console.log( - f"Loaded optimized operations for {original_op_name}" - ) - - self.console.log("[green]Finished loading optimized operations[/green]") - - # Print out the operations for each step - step_operations = [] - for step in self.config["pipeline"]["steps"]: - step_name = step.get("name") - operations = step.get("operations", []) - step_info = f"[cyan]Step: {step_name}[/cyan]\n" - - for op in operations: - if isinstance(op, dict): - op_name = list(op.keys())[0] - op_details = op[op_name] - step_info += f" - {op_name}: {op_details}\n" - else: - step_info += f" - {op}\n" - step_operations.append(step_info) - - self.console.print(Panel.fit( - "\n".join(step_operations), - title="[bold blue]Operations for each step" - )) - else: - self.console.log("[yellow]No optimized operations found[/yellow]") - - def should_optimize(self, step_name: str, op_name: str) -> Tuple[str, List[Dict[str, Any]], List[Dict[str, Any]], float]: - """ - Determine if an operation should be optimized. - We do this by running the operations on a sample of the input data and checking if the output is correct. - """ - self.console.rule("[bold cyan]Beginning Pipeline Optimization[/bold cyan]") - self.syntax_check() - - self._insert_empty_resolve_operations() - - for step in self.config["pipeline"]["steps"]: - self.captured_output.set_step(step.get("name")) - # Go through each operation in the step until we find the one we want to optimize - ops_run = [] - op_name_to_object = {name: self.find_operation(name) for name in step["operations"]} - for op_idx, operation in enumerate(step["operations"]): - if isinstance(operation, dict): - operation_name = list(operation.keys())[0] - operation_config = operation[operation_name] - else: - operation_name = operation - operation_config = {} - - op_object = self.find_operation(operation_name).copy() - op_object.update(operation_config) - op_object["name"] = operation_name - - # Run the pipeline - sample_size = self.compute_sample_size( - step.get("name"), step.get("operations"), op_object - ) - input_data = self._run_partial_step( - step, ops_run, sample_size, op_name_to_object - ) - output_data = input_data - - # If this is not the operation we want to optimize, just execute it and add to selectivities - if f"{step.get('name')}/{op_name}" != f"{step_name}/{op_name}" and op_object.get("empty", False): - output_data = self._run_operation(op_object, input_data, is_build=True) - self.selectivities[step.get("name")][op_name] = len(output_data) / len(input_data) - ops_run.append(op_name) - - # if this is the operation we want to optimize, invoke the optimizer's should_optimize method - else: - if op_object.get("type") == "map" or op_object.get("type") == "filter": - # Create instance of map optimizer - map_optimizer = MapOptimizer( - self, - self.config, - self.console, - self.llm_client, - self.max_threads, - self._run_operation, - timeout=self.timeout, - is_filter=op_object.get("type") == "filter", - ) - should_optimize_output, input_data, output_data = map_optimizer.should_optimize(op_object, input_data) - elif op_object.get("type") == "reduce": - reduce_optimizer = ReduceOptimizer( - self.runner, - self.config, - self.console, - self.llm_client, - self.max_threads, - self._run_operation, - ) - should_optimize_output, input_data, output_data = reduce_optimizer.should_optimize(op_object, input_data) - elif op_object.get("type") == "resolve": - resolve_optimizer = JoinOptimizer( - self.runner, - self.config, - op_object, - self.console, - self.llm_client, - self.max_threads, - target_recall=self.config.get("optimizer_config", {}) - .get("resolve", {}) - .get("target_recall", 0.95), - ) - _, should_optimize_output = resolve_optimizer.should_optimize(input_data) - - # if should_optimize_output is empty, then we should move to the reduce operation - if should_optimize_output == "": - continue - - # Return the string and operation cost - return should_optimize_output, input_data, output_data, self.operations_cost + self.llm_client.total_cost - - # Should not get here - raise ValueError("No operation to optimize found") - - def optimize(self) -> float: - """ - Optimize the entire pipeline defined in the configuration. - - This method is the main entry point for the optimization process. It iterates through - each step in the pipeline, optimizing from upstream to downstream, and constructs an - optimized version of the configuration. - - The optimization process includes: - 1. Iterating through each step in the pipeline, from upstream to downstream. - 2. Optimizing each step using the _optimize_step method. - 3. Updating the optimized configuration with the new operations and steps. - 4. Saving the optimized configuration to a file. - 5. Logging the total costs (agent cost, operations cost, and total cost). - - Returns: - None - - Side effects: - - Modifies self.optimized_config with the optimized pipeline and operations. - - Updates self.datasets with the results of each step. - - Calls _save_optimized_config to save the optimized configuration to a file. - - Logs cost information to the console. - - Raises: - ValueError: If a step in the pipeline does not have a name. - - Note: - - This method assumes that all necessary data and configurations are already - loaded and initialized in the Optimizer instance. - - The optimization process is performed step by step, from upstream to downstream, - with each step potentially depending on the results of previous steps. - """ - self.console.rule("[bold cyan]Beginning Pipeline Optimization[/bold cyan]") - - self.syntax_check() - - self._insert_empty_resolve_operations() - - # If resume is True, load the optimized operations from disk - if self.resume: - self._load_optimized_ops() - - for step in self.config["pipeline"]["steps"]: - step_name = step.get("name") - if not step_name: - raise ValueError( - "Step does not have a name. Each step must have a unique name." - ) - - optimized_step, step_operations, input_data = self._optimize_step(step) - old_op_names = [ - op - for op in step["operations"] - if op not in optimized_step["operations"] - ] - - # Remove all old_op_names from self.optimized_config["operations"] - self.optimized_config["operations"] = [ - op - for op in self.optimized_config["operations"] - if op["name"] not in old_op_names - ] - - for op in optimized_step["operations"]: - changed_op = False - op_name = list(op.keys())[0] if isinstance(op, dict) else op - for i, op_config in enumerate(self.optimized_config["operations"]): - if op_config["name"] == op_name: - self.optimized_config["operations"][i] = step_operations[op_name] - changed_op = True - if not changed_op: - self.optimized_config["operations"].append(step_operations[op_name]) - - self.optimized_config["pipeline"]["steps"] = [ - step - for step in self.optimized_config["pipeline"]["steps"] - if step["name"] != step_name - ] + [optimized_step] - - self.step_op_to_optimized_ops[step_name] = optimized_step["operations"] - - step_hash = ( - hashlib.md5( - json.dumps( - { - "step": [ - s - for s in self.optimized_config["pipeline"]["steps"] - if s["name"] == step_name - ][0], - "operations": [ - self.find_operation(list(op.keys())[0] if isinstance(op, dict) else op, self.optimized_config) - for op in optimized_step["operations"] - ], - } - ).encode() - ).hexdigest() - + ".json" - ) - # If the dataset already exists, skip the step - if step_hash in self.datasets: - continue - - flush_cache(self.console) - - if step_name in self.config.get("optimizer_config", {}).get( - "run_full_step", [] - ): - # Run the entire step - input_data = self._run_partial_step( - step, - step_operations, - float("inf"), # TODO: FIX THIS - ) - self.datasets[step_hash] = copy.deepcopy(input_data) - else: - self.datasets[step_hash] = copy.deepcopy(input_data) - - self.console.log( - f"[bold]Total agent cost: ${self.llm_client.total_cost:.2f}[/bold]" - ) - self.console.log( - f"[bold]Total operations cost: ${self.operations_cost:.2f}[/bold]" - ) - self.console.log( - f"[bold]Total cost: ${self.llm_client.total_cost + self.operations_cost:.2f}[/bold]" - ) - - return self.llm_client.total_cost + self.operations_cost - - def _run_partial_step( - self, - step: Dict[str, Any], - ops_to_run: List[str], - sample_size: int, - optimized_operations: Dict[str, Dict[str, Any]], - ) -> List[Dict[str, Any]]: - """ - Execute a partial step of the pipeline on a sample of the input data. - - This internal method runs a subset of operations for a given step on a sample - of the input data. It's used as part of the optimization process to evaluate - and optimize individual operations within a step. - - Args: - step (Dict[str, Any]): The step configuration dictionary. - ops_to_run (List[str]): List of operation names to execute in this partial step. - sample_size (int): The number of items to include in the input sample. - optimized_operations (Dict[str, Dict[str, Any]]): Dictionary of optimized operations. - - Returns: - List[Dict[str, Any]]: The output data after running the specified operations. - - The method performs the following steps: - 1. Retrieves a sample of the input data using _get_sample_data. - 2. For equijoin operations, it loads both left and right datasets. - 3. Iterates through the specified operations, running each on the input sample. - 4. Returns the final output after all specified operations have been applied. - - Note: - - The method handles both regular steps and equijoin steps differently. - - Raises: - Any exceptions raised by _get_sample_data or _run_operation methods. - """ - # Take the input data and run the operations in ops_to_run - # Return the output data - input_sample = self._get_sample_data(step.get("input"), None, sample_size) - - if step.get("input") is None: - join_op_name = list(step.get("operations")[0].keys())[0] - # this is an equijoin step, load left and right datasets - left_data = self._get_sample_data( - step.get("operations")[0][join_op_name].get("left"), None, sample_size - ) - right_data = self._get_sample_data( - step.get("operations")[0][join_op_name].get("right"), None, sample_size - ) - input_sample = {"left": left_data, "right": right_data} - - for op in ops_to_run: - op_object = optimized_operations[op] - if "name" not in op_object: - op_object["name"] = op - - input_sample = self._run_operation(op_object, input_sample) - return input_sample - - def _optimize_step( - self, step: Dict[str, Any] - ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: - """ - Optimize a single step in the pipeline. - - This method takes a step configuration and optimizes each operation within it. - It handles different types of operations, including those that require optimization - and those that don't. - - Args: - step (Dict[str, Any]): The configuration dictionary for the step to be optimized. - - Returns: - Tuple[Dict[str, Any], List[Dict[str, Any]], List[Dict[str, Any]]]: - - The optimized step configuration. - - A list of optimized operations. - - The output data after running all operations in the step. - - The method performs the following for each operation in the step: - 1. Extracts the operation configuration. - 2. Computes the appropriate sample size for the operation. - 3. Runs the operation on a sample of the input data. - 4. If the operation is optimizable and of a supported type, it calls the appropriate - optimization method (e.g., _optimize_map, _optimize_reduce). - 5. If not optimizable or not supported, it runs the operation as-is. - 6. Calculates and stores the selectivity of each operation. - 7. Updates the list of optimized operations and their configurations. - - The method uses rich console to provide status updates during the optimization process. - - Note: - - This method is a key part of the overall optimization process, focusing on - individual steps in the pipeline. - - It relies on several helper methods like _run_partial_step, compute_sample_size, - and various _optimize_* methods for specific operation types. - - When optimizing an operation in the step, all previous operations are run on the - sample size needed for the current operation. This ensures that the input to the - operation being optimized is representative of what it would receive in the full pipeline. - - Raises: - ValueError: If an unsupported operation type is encountered. - """ - self.captured_output.set_step(step.get("name")) - optimized_operations = {} - optimized_operation_names = [] - replacement_operations = {} # List from old op name to new ops - - for op_idx, operation in enumerate(step["operations"]): - if isinstance(operation, dict): - operation_name = list(operation.keys())[0] - operation_config = operation[operation_name] - else: - operation_name = operation - operation_config = {} - - op_object = self.find_operation(operation_name).copy() - op_object.update(operation_config) - op_object["name"] = operation_name - - # Run the pipeline - step_ops = [] - for step_op in step.get("operations"): - step_op_name = list(step_op.keys())[0] if isinstance(step_op, dict) else step_op - - if step_op_name in replacement_operations: - step_ops.extend(replacement_operations[step_op_name]) - else: - step_ops.append(step_op) - - # TODO: incorporate this into the optimizer to not run the most downstream operations - downstream_ops_exist = op_idx < len(step["operations"]) - 1 - - sample_size = self.compute_sample_size( - step.get("name"), step_ops, op_object - ) - input_data = self._run_partial_step( - step, optimized_operation_names, sample_size, optimized_operations - ) - - if ( - not op_object.get("optimize", False) # Default don't optimize - or op_object.get("type") not in SUPPORTED_OPS - ): - # If optimize is False or operation type is not supported, just use the operation without optimization - output_data = self._run_operation(op_object, input_data) - optimized_operations[operation_name] = op_object - if op_object.get("type") == "equijoin": - optimized_operation_names.append(operation) - else: - optimized_operation_names.append(operation_name) - - selectivity = len(output_data) / len(input_data) - self.selectivities[step.get("name")][operation_name] = selectivity - self.samples_taken[step.get("name")][operation_name] = sample_size - else: - sample_info = [] - if op_object.get("type") == "equijoin": - sample_info.extend([ - f"[yellow]Sample size (left): {len(input_data['left'])}", - f"[yellow]Sample size (right): {len(input_data['right'])}" - ]) - else: - sample_info.append(f"[yellow]Sample size: {len(input_data)}") - - # Get optimizer config for this operation type if it exists - optimizer_config = self.config.get("optimizer_config", {}).get(op_object["type"], {}) - - panel_content = "\n".join(sample_info) - if optimizer_config: - panel_content += "\n\n[cyan]Optimizer Config:[/cyan]" - for key, value in optimizer_config.items(): - panel_content += f"\n[cyan]{key}:[/cyan] {value}" - - self.console.print(Panel.fit( - panel_content, - title=f"[yellow]Optimizing {operation_name} (Type: {op_object['type']})" - )) - - # Use rich console status to indicate optimization of the operation - with self.console.status( - f"[bold blue]Optimizing operation: {operation_name} (Type: {op_object['type']})[/bold blue]" - ) as status: - self.status = status - - # Print the number of elements in input_data - self.console.rule( - f"[yellow]Optimizing operation {operation_name} (Type: {op_object['type']})[/yellow]" - ) - if op_object.get("type") == "equijoin": - self.console.log( - f"[yellow] Sample size (left): {len(input_data['left'])}[/yellow]" - ) - self.console.log( - f"[yellow] Sample size (right): {len(input_data['right'])}[/yellow]" - ) - else: - self.console.log( - f"[yellow] Sample size: {len(input_data)}[/yellow]" - ) - - # Run optimization - for retry in range( - self.config.get("optimizer_config", {}).get( - "num_retries", NUM_OPTIMIZER_RETRIES - ) - ): - try: - if op_object.get("type") == "map": - optimized_ops = self._optimize_map( - op_object, input_data - ) - elif op_object.get("type") == "filter": - optimized_ops = self._optimize_map( - op_object, input_data, is_filter=True - ) - elif op_object.get("type") == "reduce": - optimized_ops = self._optimize_reduce( - op_object, input_data, status - ) - elif op_object.get("type") == "resolve": - optimized_ops = self._optimize_resolve( - op_object, input_data - ) - elif op_object.get("type") == "equijoin": - ( - optimized_ops, - input_data, - new_left_name, - new_right_name, - ) = self._optimize_equijoin( - op_object, - next(iter(operation.values()))["left"], - next(iter(operation.values()))["right"], - input_data["left"], - input_data["right"], - status, - ) - else: - raise ValueError( - f"Unsupported operation type: {op_object['type']}" - ) - break # If successful, break out of the retry loop - except Exception as e: - if ( - retry - == self.config.get("optimizer_config", {}).get( - "num_retries", NUM_OPTIMIZER_RETRIES - ) - - 1 - ): - raise # If this was the last retry, re-raise the exception - self.console.log( - f"Optimization attempt {retry + 1} failed. Retrying..." - ) - - if self.status: - self.status.update( - f"[bold blue]Running optimized operation to estimate selectivities: {operation_name}[/bold blue]" - ) - - for op in optimized_ops: - op_name = op["name"] - optimized_operations[op_name] = op - if op.get("type") == "equijoin": - optimized_operation_names.append( - { - op_name: { - "left": new_left_name, - "right": new_right_name, - } - } - ) - else: - optimized_operation_names.append(op_name) - - old_input_data_size = len(input_data) - input_data = self._run_operation(op, input_data) - new_input_data_size = len(input_data) - selectivity = new_input_data_size / old_input_data_size - self.selectivities[step.get("name")][op_name] = selectivity - self.samples_taken[step.get("name")][op_name] = sample_size - - # Set replacement_operations - replacement_operations[op_object["name"]] = [ - o["name"] for o in optimized_ops - ] - - # Print new operator configs - if optimized_ops: - config_content = "[bold green]New op configurations:[/bold green]\n" - for op_name, op_config in optimized_operations.items(): - if op_name in [o["name"] for o in optimized_ops]: - config_content += f"[cyan]{op_name}:[/cyan] {json.dumps(op_config, indent=2)}\n" - - self.console.print(Panel.fit( - config_content, - title="Optimized Operations" - )) - - # Save the optimized operations to disk - os.makedirs(self.optimized_ops_path, exist_ok=True) - - for original_op, replacement_ops in replacement_operations.items(): - optimized_ops_list = [ - ( - optimized_operations[op_name] - if isinstance(op_name, str) - else { - list(op_name.keys())[0]: optimized_operations[ - list(op_name.keys())[0] - ] - } - ) - for op_name in replacement_ops - ] - - # Save to disk - optimized_op_file = os.path.join( - self.optimized_ops_path, f"{original_op}.json" - ) - with open(optimized_op_file, "w") as f: - json.dump(optimized_ops_list, f, indent=2) - - self.console.log( - f"[green]Saved optimized operations to {self.optimized_ops_path}[/green]" - ) - self.status = None - output_data = input_data - - optimized_step = step.copy() - - optimized_step["operations"] = optimized_operation_names - return optimized_step, optimized_operations, output_data - - def _get_sample_data( - self, dataset_name: str, op_config: Optional[Dict[str, Any]], sample_size: int - ) -> List[Dict[str, Any]]: - """ - Retrieve a sample of data from a specified dataset. - - This method loads data from either a previously processed dataset or from a file, - and returns a sample of the data based on the given sample size and operation configuration. - - Args: - dataset_name (str): The name of the dataset to sample from. - op_config (Optional[Dict[str, Any]]): The configuration of the operation to be performed. - This is used to determine if special sampling is needed. - sample_size (int): The desired size of the sample. If set to float('inf'), all data is returned. - - Returns: - List[Dict[str, Any]]: A list of dictionaries representing the sampled data. - - Raises: - ValueError: If the dataset is not found or if the dataset type is unsupported. - """ - if dataset_name is None: - return [] - - if any( - s["name"] == dataset_name - for s in self.optimized_config["pipeline"]["steps"] - ): - step = [ - s - for s in self.optimized_config["pipeline"]["steps"] - if s["name"] == dataset_name - ][0] - name_hash = ( - hashlib.md5( - json.dumps( - { - "step": step, - "operations": [ - self.find_operation(list(op.keys())[0] if isinstance(op, dict) else op, self.optimized_config) for op in step["operations"] - ], - } - ).encode() - ).hexdigest() - + ".json" - ) - else: - name_hash = None - - if name_hash and name_hash in self.datasets: - data = self.datasets[name_hash] - else: - dataset_config = self.config["datasets"].get(dataset_name) - if dataset_config is None: - raise ValueError( - f"Dataset '{dataset_name}' not found in config or previous steps." - ) - dataset = Dataset( - runner=self, - type=dataset_config["type"], - path_or_data=dataset_config["path"], - parsing=dataset_config.get("parsing", []), - user_defined_parsing_tool_map=self.runner.parsing_tool_map, - ) - data = dataset.load() - - if sample_size == float("inf"): - return data - - if op_config: - if op_config.get("type") == "reduce": - return self._get_reduce_sample( - data, op_config.get("reduce_key"), sample_size - ) - - if not self.config.get("optimizer_config", {}).get("random_sample", False): - return data[:sample_size] - - # Take the random 500 examples or all if less than 500 - initial_data = random.sample(data, min(500, len(data))) - - # Calculate counts for each example - char_counts = [len(str(item)) for item in initial_data] - total_counts = sum(char_counts) - - # Calculate weights based on word counts - weights = [count / total_counts for count in char_counts] - - # Perform weighted random sampling - return random.choices( - initial_data, weights=weights, k=min(sample_size, len(initial_data)) - ) - - def _get_reduce_sample( - self, data: List[Dict[str, Any]], reduce_key: str, sample_size: int - ) -> List[Dict[str, Any]]: - """ - Get a representative sample for a reduce operation. - - This method creates a sample that preserves the distribution of groups in the data, - focusing on the top 5 largest groups. It also generates and prints a histogram of group sizes. - - Args: - data (List[Dict[str, Any]]): The full dataset to sample from. - reduce_key (str): The key used for grouping in the reduce operation. - sample_size (int): The desired size of the sample. - - Returns: - List[Dict[str, Any]]: A list of dictionaries representing the sampled data. - """ - # Group data by reduce key - grouped_data = defaultdict(list) - for item in data: - grouped_data[item[reduce_key]].append(item) - - # Sort groups by size in descending order - sorted_groups = sorted( - grouped_data.items(), key=lambda x: len(x[1]), reverse=True - ) - - sample = [] - - # Take the top 5 groups - top_5_groups = sorted_groups[:5] - - # Calculate the total count of items in the top 5 groups - total_count = sum(len(items) for _, items in top_5_groups) - - sample = [] - for _, items in top_5_groups: - # Calculate the proportion of items to sample from this group - group_proportion = len(items) / total_count - group_sample_size = int(sample_size * group_proportion) - - # Sample from the group - if not self.config.get("optimizer_config", {}).get("random_sample", False): - group_sample = items[:group_sample_size] - else: - group_sample = random.sample( - items, min(group_sample_size, len(items)) - ) - - sample.extend(group_sample) - - # If we haven't reached the desired sample size, add more items randomly - if len(sample) < sample_size: - remaining_items = [ - item - for _, items in top_5_groups - for item in items - if item not in sample - ] - additional_sample = random.sample( - remaining_items, - min( - sample_size - len(sample), len(remaining_items) - ), - ) if self.config.get("optimizer_config", {}).get("random_sample", False) else remaining_items[:sample_size - len(sample)] - sample.extend(additional_sample) - - # Create a histogram of group sizes - group_sizes = [len(items) for _, items in grouped_data.items()] - size_counts = Counter(group_sizes) - - # Sort the sizes for a more readable output - sorted_sizes = sorted(size_counts.items()) - - # Replace the histogram logging with a panel - histogram_content = "[bold]Histogram of Group Sizes:[/bold]\n" - max_bar_width, max_count = 2, max(size_counts.values()) - for size, count in sorted_sizes[:5]: - normalized_count = int(count / max_count * max_bar_width) - bar = "█" * normalized_count - histogram_content += f"{size:3d}: {bar} ({count})\n" - - self.console.print(Panel.fit(histogram_content, title="Group Size Distribution")) - - return sample - - def _optimize_reduce( - self, - op_config: Dict[str, Any], - input_data: List[Dict[str, Any]], - status: Status, - ) -> List[Dict[str, Any]]: - """ - Optimize a reduce operation. - - This method creates a ReduceOptimizer instance and uses it to optimize the reduce operation. - - Args: - op_config (Dict[str, Any]): The configuration of the reduce operation. - input_data (List[Dict[str, Any]]): The input data for the reduce operation. - status (Status): The status object to update with the progress of the optimization. - - Returns: - List[Dict[str, Any]]: The optimized operation configuration. - """ - reduce_optimizer = ReduceOptimizer( - self.runner, - self.config, - self.console, - self.llm_client, - self.max_threads, - self._run_operation, - status=status, - ) - optimized_ops, _, cost = reduce_optimizer.optimize(op_config, input_data) - self.operations_cost += cost - return optimized_ops - - def _optimize_equijoin( - self, - op_config: Dict[str, Any], - left_name: str, - right_name: str, - left_data: List[Dict[str, Any]], - right_data: List[Dict[str, Any]], - status: Status, - ) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]], str, str]: - """ - Optimize an equijoin operation. - - This method creates a JoinOptimizer instance and uses it to optimize the equijoin operation. - It updates the operation cost and runs the optimized operation. - If the LLM suggests a map transformation, it will optimize the map operation as its own step, and then go back to optimize the equijoin operation. - - Args: - op_config (Dict[str, Any]): The configuration of the equijoin operation. - left_name (str): The name of the left dataset. - right_name (str): The name of the right dataset. - left_data (List[Dict[str, Any]]): The left dataset for the join. - right_data (List[Dict[str, Any]]): The right dataset for the join. - - Returns: - Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]], str, str]: The optimized operation configuration, the new left and right datasets, and the new left and right names. - """ - max_iterations = 2 - new_left_name = left_name - new_right_name = right_name - for _ in range(max_iterations): - join_optimizer = JoinOptimizer( - self.runner, - self.config, - op_config, - self.console, - self.llm_client, - self.max_threads, - target_recall=self.config.get("optimizer_config", {}) - .get("equijoin", {}) - .get("target_recall", 0.95), - estimated_selectivity=self.config.get("optimizer_config", {}) - .get("equijoin", {}) - .get("estimated_selectivity", None), - status=status, - ) - optimized_config, cost, agent_results = join_optimizer.optimize_equijoin( - left_data, right_data - ) - self.operations_cost += cost - # Update the operation config with the optimized values - op_config.update(optimized_config) - - if not agent_results.get("optimize_map", False): - break # Exit the loop if no more map optimizations are necessary - - # Update the status to indicate we're optimizing a map operation - output_key = agent_results["output_key"] - if self.status: - self.status.update( - f"Optimizing map operation for {output_key} extraction to help with the equijoin" - ) - map_prompt = agent_results["map_prompt"] - dataset_to_transform = ( - left_data - if agent_results["dataset_to_transform"] == "left" - else right_data - ) - - # Create a new step for the map operation - map_operation = { - "name": f"synthesized_{output_key}_extraction", - "type": "map", - "prompt": map_prompt, - "model": self.config.get("default_model", "gpt-4o-mini"), - "output": {"schema": {output_key: "string"}}, - "optimize": False, - } - - # Optimize the map operation - if map_operation["optimize"]: - dataset_to_transform_sample = random.sample( - dataset_to_transform, self.sample_size_map.get("map") - ) if self.config.get("optimizer_config", {}).get("random_sample", False) else dataset_to_transform[:self.sample_size_map.get("map")] - optimized_map_operations = self._optimize_map( - map_operation, dataset_to_transform_sample - ) - else: - optimized_map_operations = [map_operation] - - new_step = { - "name": f"synthesized_{output_key}_extraction", - "input": ( - left_name - if agent_results["dataset_to_transform"] == "left" - else right_name - ), - "operations": [mo["name"] for mo in optimized_map_operations], - } - if agent_results["dataset_to_transform"] == "left": - new_left_name = new_step["name"] - else: - new_right_name = new_step["name"] - - for optimized_map_op in optimized_map_operations: - self.optimized_config["operations"].append(optimized_map_op) - - self.optimized_config["pipeline"]["steps"].append(new_step) - - # Now run the optimized map operation on the entire dataset_to_transform - for op in optimized_map_operations: - dataset_to_transform = self._run_operation(op, dataset_to_transform) - - # Update the appropriate dataset for the next iteration - if agent_results["dataset_to_transform"] == "left": - left_data = dataset_to_transform - else: - right_data = dataset_to_transform - - if self.status: - self.status.update( - f"Optimizing equijoin operation with {output_key} extraction" - ) - - # Pop off "left" and "right" from the op_config - op_config.pop("left") - op_config.pop("right") - return ( - [op_config], - {"left": left_data, "right": right_data}, - new_left_name, - new_right_name, - ) - - def _optimize_map( - self, - op_config: Dict[str, Any], - input_data: List[Dict[str, Any]], - is_filter: bool = False, - ) -> List[Dict[str, Any]]: - """ - Optimize a map operation. - - This method creates a MapOptimizer instance and uses it to optimize the map operation. - - Args: - op_config (Dict[str, Any]): The configuration of the map operation. - input_data (List[Dict[str, Any]]): The input data for the map operation. - is_filter (bool, optional): If True, the operation is a filter operation. Defaults to False. - - Returns: - List[Dict[str, Any]]: The optimized operation configuration. - """ - - map_optimizer = MapOptimizer( - self, - self.config, - self.console, - self.llm_client, - self.max_threads, - self._run_operation, - timeout=self.timeout, - is_filter=is_filter, - ) - - optimized_ops, _, cost = map_optimizer.optimize(op_config, input_data, self.config.get("optimizer_config", {}).get("map", {}).get("plan_types", ["chunk", "proj_synthesis", "glean"])) - self.operations_cost += cost - return optimized_ops - - def _optimize_resolve( - self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: - """ - Optimize a resolve operation. - - This method creates a JoinOptimizer instance and uses it to optimize the resolve operation. - It updates the operation cost and runs the optimized operation. - - Args: - op_config (Dict[str, Any]): The configuration of the resolve operation. - input_data (List[Dict[str, Any]]): The input data for the resolve operation. - - Returns: - List[Dict[str, Any]]: The optimized operation configuration. - """ - optimized_config, cost = JoinOptimizer( - self.runner, - self.config, - op_config, - self.console, - self.llm_client, - self.max_threads, - target_recall=self.config.get("optimizer_config", {}) - .get("resolve", {}) - .get("target_recall", 0.95), - ).optimize_resolve(input_data) - - if optimized_config.get("empty", False): - # Remove this operation from the pipeline and just return input data - return [], input_data - - self.operations_cost += cost - - # Update the operation config with the optimized values - op_config.update(optimized_config) - - return [op_config] - - def _run_operation( - self, - op_config: Dict[str, Any], - input_data: List[Dict[str, Any]], - return_instance: bool = False, - is_build: bool = False, - ) -> Union[List[Dict[str, Any]], Tuple[List[Dict[str, Any]], BaseOperation]]: - """ - Run a single operation based on its configuration. - - This method creates an instance of the appropriate operation class and executes it. - It also updates the total operation cost. - - Args: - op_config (Dict[str, Any]): The configuration of the operation to run. - input_data (List[Dict[str, Any]]): The input data for the operation. - return_instance (bool, optional): If True, return the operation instance along with the output data. - - Returns: - Union[List[Dict[str, Any]], Tuple[List[Dict[str, Any]], BaseOperation]]: - If return_instance is False, returns the output data. - If return_instance is True, returns a tuple of the output data and the operation instance. - """ - operation_class = get_operation(op_config["type"]) - - oc_kwargs = { - "runner": self.runner, - "config": op_config, - "default_model": self.config["default_model"], - "max_threads": self.max_threads, - "console": self.console, - "status": self.status, - } - operation_instance = operation_class(**oc_kwargs) - if op_config["type"] == "equijoin": - left_data = input_data["left"] - right_data = input_data["right"] - output_data, cost = operation_instance.execute(left_data, right_data) - elif op_config["type"] == "filter": - output_data, cost = operation_instance.execute(input_data, is_build) - else: - output_data, cost = operation_instance.execute(input_data) - self.operations_cost += cost - if return_instance: - return output_data, operation_instance - else: - return output_data - - # Recursively resolve all anchors and aliases - @staticmethod - def resolve_anchors(data): - """ - Recursively resolve all anchors and aliases in a nested data structure. - - This static method traverses through dictionaries and lists, resolving any YAML anchors and aliases. - - Args: - data: The data structure to resolve. Can be a dictionary, list, or any other type. - - Returns: - The resolved data structure with all anchors and aliases replaced by their actual values. - """ - if isinstance(data, dict): - return {k: Optimizer.resolve_anchors(v) for k, v in data.items()} - elif isinstance(data, list): - return [Optimizer.resolve_anchors(item) for item in data] - else: - return data - - def clean_optimized_config(self): - """ - Remove _intermediates from each operation in the optimized config - """ - # Create a copy of the optimized config to modify - config_to_save = self.optimized_config.copy() - - resolved_config = Optimizer.resolve_anchors(config_to_save) - - # Remove _intermediates from each operation in resolved_config - if "operations" in resolved_config: - for op_config in resolved_config["operations"]: - if "_intermediates" in op_config: - del op_config["_intermediates"] - if "recursively_optimize" in op_config: - del op_config["recursively_optimize"] - if "optimize" in op_config: - del op_config["optimize"] - - return resolved_config - - def save_optimized_config(self, optimized_config_path: str): - """ - Save the optimized configuration to a YAML file. - - This method creates a copy of the optimized configuration, resolves all anchors and aliases, - and saves it to a new YAML file. The new file name is based on the original file name with '_opt' appended. - """ - resolved_config = self.clean_optimized_config() - - with open(optimized_config_path, "w") as f: - yaml.safe_dump(resolved_config, f, default_flow_style=False, width=80) - self.console.log( - f"[green italic]💾 Optimized config saved to {optimized_config_path}[/green italic]" - ) - - -if __name__ == "__main__": - optimizer = Optimizer("workloads/medical/map.yaml", model="gpt-4o-mini") - optimizer.optimize() \ No newline at end of file diff --git a/docetl/cli.py b/docetl/cli.py index f5dcac92..8619eb98 100644 --- a/docetl/cli.py +++ b/docetl/cli.py @@ -1,16 +1,16 @@ +import os from pathlib import Path from typing import Optional -import os import typer +from dotenv import load_dotenv from docetl.operations.utils import clear_cache as cc from docetl.runner import DSLRunner -from dotenv import load_dotenv - app = typer.Typer() + @app.command() def build( yaml_file: Path = typer.Argument( diff --git a/docetl/config_wrapper.py b/docetl/config_wrapper.py index b9d20ccc..2a2c68b8 100644 --- a/docetl/config_wrapper.py +++ b/docetl/config_wrapper.py @@ -1,14 +1,16 @@ import datetime +import math import os -from docetl.console import get_console -from docetl.utils import decrypt, load_config -from typing import Any, Dict, List, Optional, Tuple, Union -from docetl.operations.utils import APIWrapper -import pyrate_limiter from inspect import isawaitable -import math +from typing import Dict, Optional + +import pyrate_limiter from rich.console import Console +from docetl.console import get_console +from docetl.operations.utils import APIWrapper +from docetl.utils import decrypt, load_config + class BucketCollection(pyrate_limiter.BucketFactory): def __init__(self, **buckets): @@ -108,6 +110,6 @@ def __init__( self.rate_limiter = pyrate_limiter.Limiter(bucket_factory, max_delay=math.inf) self.api = APIWrapper(self) - + def reset_env(self): os.environ = self._original_env diff --git a/docetl/console.py b/docetl/console.py index da1d6aa8..a80aff07 100644 --- a/docetl/console.py +++ b/docetl/console.py @@ -1,12 +1,16 @@ import os +import threading import time -from typing import Any, Optional, Tuple -from rich.console import Console from io import StringIO -import threading -import queue +from typing import Tuple + +from rich.console import Console, RenderableType +from rich.status import Status +from rich.style import StyleType + from docetl.utils import StageType, get_stage_description + class ThreadSafeConsole(Console): def __init__(self, *args, **kwargs): self.buffer = StringIO() @@ -17,28 +21,36 @@ def __init__(self, *args, **kwargs): self.optimizer_statuses = [] self.optimizer_rationale = None + def get_output(self): + # return self.export_text(styles=True) + value = self.buffer.getvalue() + self.buffer.truncate(0) + self.buffer.seek(0) + return value + def status( self, status: "RenderableType", *, spinner: str = "dots", spinner_style: "StyleType" = "status.spinner", - speed: float = 1.0, - refresh_per_second: float = 12.5, + speed: float = 0.1, # Much slower speed + refresh_per_second: float = 0.5, # Much slower refresh rate (every 2 seconds) ) -> "Status": - from rich.status import Status status_renderable = Status( status, - console=None, + console=self, spinner=spinner, spinner_style=spinner_style, speed=speed, refresh_per_second=refresh_per_second, ) return status_renderable - - def post_optimizer_rationale(self, should_optimize: bool, rationale: str, validator_prompt: str): + + def post_optimizer_rationale( + self, should_optimize: bool, rationale: str, validator_prompt: str + ): self.optimizer_rationale = (should_optimize, rationale, validator_prompt) def post_optimizer_status(self, stage: StageType): @@ -47,8 +59,11 @@ def post_optimizer_status(self, stage: StageType): def get_optimizer_progress(self) -> Tuple[str, float]: if len(self.optimizer_statuses) == 0: return ("Optimization starting...", 0) - - if len(self.optimizer_statuses) > 0 and self.optimizer_statuses[-1][0] == StageType.END: + + if ( + len(self.optimizer_statuses) > 0 + and self.optimizer_statuses[-1][0] == StageType.END + ): return (get_stage_description(StageType.END), 1) num_stages = len(StageType) - 1 @@ -84,11 +99,16 @@ def get_console(): if os.environ.get("USE_FRONTEND") == "true": return ThreadSafeConsole( force_terminal=True, - width=80, soft_wrap=True, highlight=False, + log_path=False, + color_system="truecolor", + width=120, + style="bright_white on black", + record=True, ) else: + class NoOpConsole(Console): def post_optimizer_status(self, *args, **kwargs): pass @@ -96,7 +116,8 @@ def post_optimizer_status(self, *args, **kwargs): def post_optimizer_rationale(self, *args, **kwargs): pass - return NoOpConsole() + return NoOpConsole(log_path=False) +# Create the console first DOCETL_CONSOLE = get_console() diff --git a/docetl/containers.py b/docetl/containers.py new file mode 100644 index 00000000..f36db994 --- /dev/null +++ b/docetl/containers.py @@ -0,0 +1,584 @@ +""" +This module contains the container classes used by the DSLRunner for pipeline execution. +The containers implement a pull-based execution model where operations are lazily evaluated +only when their outputs are needed by parent nodes. +""" + +import json +import math +import os +from typing import TYPE_CHECKING, Dict, List, Tuple + +from rich.panel import Panel + +from docetl.dataset import Dataset +from docetl.operations import get_operation +from docetl.operations.utils import flush_cache +from docetl.optimizers import JoinOptimizer, MapOptimizer, ReduceOptimizer +from docetl.utils import smart_sample + +if TYPE_CHECKING: + from docetl.runner import DSLRunner + +SUPPORTED_OPS = ["map", "resolve", "reduce", "equijoin", "filter"] +NUM_OPTIMIZER_RETRIES = 1 + + +class OpContainer: + """ + OpContainer implements a pull-based execution model for pipeline operations. Each container + represents a node in the execution DAG and lazily evaluates its operation only when its + output is requested by a parent node. + + Key features: + - Lazy evaluation: Operations only execute when their output is needed + - Transparent caching: Results can be cached and reused across pipeline runs + - Cost tracking: Each operation's execution cost is tracked and aggregated + + The pull-based model means that execution flows backwards through the DAG - when the final + node is asked for data, it recursively requests data from its children until reaching leaf + nodes (typically scan operations that load initial datasets). + """ + + def __init__(self, name: str, runner: "DSLRunner", config: Dict, **kwargs): + self.name = name + self.config = config + self.children = [] + self.parent = None + self.is_equijoin = config.get("type") == "equijoin" + self.runner = runner + self.selectivity = kwargs.get("selectivity", None) + if not self.selectivity: + # If it's a map or resolve or gather operation, we know the selectivity is 1 + if self.config.get("type") in [ + "map", + "parallel_map", + "code_map", + "resolve", + "gather", + ]: + self.selectivity = 1 + self.is_optimized = False + self.kwargs = kwargs + + def to_string(self) -> str: + return json.dumps(self.config, indent=2) + + def add_child(self, child: "OpContainer") -> None: + self.children.append(child) + child.parent = self + + def optimize(self): + """ + Optimize the next operation, to get a sample of size sample_size. + Along the way, we will replace this op container with the optimized op container. + + We do the following: + 1. Optimize the children + 2. Run the children to get the input data for optimizing this operation + 3. Optimize this operation and replace it with the optimized op containers + """ + # Return early if already optimized + if self.is_optimized: + return + + # optimize the children + for child in self.children: + child.optimize() + + # Figure out the sample size needed for this operation from the sample size map + # It may be None if the operation is not in the sample size map, which means we will get all the data + sample_size_needed = self.runner.optimizer.sample_size_map.get( + self.config["type"] + ) + # run the children to get the input data for optimizing this operation + input_data = [] + for child in self.children: + input_data.append( + child.next(is_build=True, sample_size_needed=sample_size_needed)[0] + ) + + # Optimize this operation if it's eligible for optimization + new_head_pointer = self + if self.config.get("optimize", False): + if self.config["type"] not in SUPPORTED_OPS: + self.runner.console.log( + f"[red]Operation {self.name} is not supported for optimization. Proceeding without optimizing it.[/red]" + ) + else: + # If this is a build operation, set the captured output + self.runner.optimizer.captured_output.set_step(self.name.split("/")[0]) + + # Print statistics for optimizing this operation + sample_info = [] + if self.config["type"] == "equijoin": + sample_info.extend( + [ + f"[yellow]Sample size (left): {len(input_data[0])}", + f"[yellow]Sample size (right): {len(input_data[1])}", + ] + ) + else: + sample_info.append(f"[yellow]Sample size: {len(input_data[0])}") + + # Get optimizer config for this operation type if it exists + optimizer_config = self.runner.config.get("optimizer_config", {}).get( + self.config["type"], {} + ) + + panel_content = "\n".join(sample_info) + if optimizer_config: + panel_content += "\n\n[cyan]Optimizer Config:[/cyan]" + for key, value in optimizer_config.items(): + panel_content += f"\n[cyan]{key}:[/cyan] {value}" + + self.runner.console.log( + Panel.fit( + panel_content, + title=f"[yellow]Optimizing {self.name} (Type: {self.config['type']})", + ) + ) + + # Use rich console status to indicate optimization of the operation + with self.runner.console.status( + f"[bold blue]Optimizing operation: {self.name} (Type: {self.config['type']})[/bold blue]" + ) as status: + self.runner.status = status + optimized_ops = [] + + # Run optimization + for retry in range( + self.runner.config.get("optimizer_config", {}).get( + "num_retries", NUM_OPTIMIZER_RETRIES + ) + ): + try: + if self.config.get("type") in ["map", "filter"]: + map_optimizer = MapOptimizer( + self.runner, + self.runner._run_operation, + is_filter=self.config["type"] == "filter", + ) + optimized_ops, _, cost = map_optimizer.optimize( + self.config, input_data[0] + ) + self.runner.total_cost += cost + elif self.config.get("type") == "reduce": + reduce_optimizer = ReduceOptimizer( + self.runner, + self.runner._run_operation, + ) + optimized_ops, _, cost = reduce_optimizer.optimize( + self.config, input_data[0] + ) + self.runner.total_cost += cost + elif self.config.get("type") == "resolve": + optimized_config, cost = JoinOptimizer( + self.runner, + self.config, + target_recall=self.runner.config.get( + "optimizer_config", {} + ) + .get("resolve", {}) + .get("target_recall", 0.95), + estimated_selectivity=self.runner.config.get( + "optimizer_config", {} + ) + .get("resolve", {}) + .get("estimated_selectivity", None), + ).optimize_resolve(input_data[0]) + op_config = self.config.copy() + op_config.update(optimized_config) + optimized_ops = ( + [op_config] + if not optimized_config.get("empty", False) + else [] + ) + self.runner.total_cost += cost + + elif self.config.get("type") == "equijoin": + op_config, new_steps, new_left_name, new_right_name = ( + self.runner.optimizer._optimize_equijoin( + self.config, + self.kwargs["left_name"], + self.kwargs["right_name"], + input_data[0], + input_data[1], + self.runner._run_operation, + ) + ) + # Set this current config to be op_config + self.config = op_config + + # Replace old op map + self.runner.op_container_map = { + k: v + for k, v in self.runner.op_container_map.items() + if k + not in [ + self.children[0].name, + self.children[1].name, + ] + } + + # Set the children to be scans of the new left and right names + curr_step_name = self.name.split("/")[0] + self.children[0].config = { + "type": "scan", + "name": f"scan_{new_left_name}", + "dataset_name": new_left_name, + } + self.children[0].name = ( + f"{curr_step_name}/scan_{new_left_name}" + ) + self.children[1].config = { + "type": "scan", + "name": f"scan_{new_right_name}", + "dataset_name": new_right_name, + } + self.children[1].name = ( + f"{curr_step_name}/scan_{new_right_name}" + ) + + # Replace in the op map + self.runner.op_container_map[ + f"{curr_step_name}/scan_{new_left_name}" + ] = self.children[0] + self.runner.op_container_map[ + f"{curr_step_name}/scan_{new_right_name}" + ] = self.children[1] + + # Find the child dataset name that changed (left or right) + left_changed = new_left_name != self.kwargs["left_name"] + if left_changed: + # Set the left to be the local last op container + local_last_op_container = self.children[0] + else: + # Set the right to be the local last op container + local_last_op_container = self.children[1] + + # Change the kwargs left and right names + self.kwargs["left_name"] = new_left_name + self.kwargs["right_name"] = new_right_name + + # Insert new containers before local_last_op_container's children and local_last_op_container + old_children = local_last_op_container.children + local_last_op_container.children = [] + + # Add the new steps and operations to the query plan + for step_name, step_obj, operations in reversed( + new_steps + ): + # Create the step boundary op container + step_boundary_container = StepBoundary( + f"{step_name}/boundary", + self.runner, + { + "type": "step_boundary", + "name": f"{step_name}/boundary", + }, + ) + self.runner.op_container_map[ + f"{step_name}/boundary" + ] = step_boundary_container + # Point the equijoin op container to this step boundary + local_last_op_container.add_child( + step_boundary_container + ) + + local_last_op_container = step_boundary_container + + # Create new op containers for each operation + for op in operations: + op_container = OpContainer( + f"{step_name}/{op['name']}", self.runner, op + ) + self.runner.op_container_map[ + f"{step_name}/{op['name']}" + ] = op_container + local_last_op_container.add_child(op_container) + local_last_op_container = op_container + + # Add a scan operation based on the input for the step op + scan_op_container = OpContainer( + f"{step_name}/scan_{step_obj['input']}", + self.runner, + { + "type": "scan", + "name": f"scan_{step_obj['input']}", + "dataset_name": step_obj["input"], + }, + ) + self.runner.op_container_map[ + f"{step_name}/scan_{step_obj['input']}" + ] = scan_op_container + local_last_op_container.add_child(scan_op_container) + local_last_op_container = scan_op_container + + # Set the local_last_op_container's children to the old children + for child in old_children: + local_last_op_container.add_child(child) + + else: + raise ValueError( + f"Unsupported operation type: {self.config['type']}" + ) + break # If successful, break out of the retry loop + except Exception as e: + if ( + retry + == self.runner.config.get("optimizer_config", {}).get( + "num_retries", NUM_OPTIMIZER_RETRIES + ) + - 1 + ): + raise # If this was the last retry, re-raise the exception + self.runner.console.log( + f"Optimization attempt {retry + 1} failed with error: {e}. Retrying..." + ) + + if len(optimized_ops) > 0: + # Replace this op container with the optimized op containers + # Since this is not an equijoin, we have only one child + old_children = self.children + self.children = [] + local_last_op_container = self.parent + local_last_op_container.children = [] + curr_step_name = self.name.split("/")[0] + + for idx, op in enumerate(list(reversed(optimized_ops))): + op_container = OpContainer( + f"{curr_step_name}/{op['name']}", self.runner, op + ) + if idx == 0: + new_head_pointer = op_container + + self.runner.op_container_map[ + f"{curr_step_name}/{op['name']}" + ] = op_container + local_last_op_container.add_child(op_container) + local_last_op_container = op_container + + for child in old_children: + local_last_op_container.add_child(child) + + # Figure out the sample size needed for this operation from the sample size map + # It may be None if the operation is not in the sample size map, which means we will get all the data + sample_size_needed = self.runner.optimizer.sample_size_map.get( + new_head_pointer.config["type"] + ) + # walk down the new head pointer and set the selectivities + queue = [new_head_pointer] if new_head_pointer.parent else [] + while queue: + curr_op = queue.pop(0) + if not curr_op.selectivity: + # Run the operation to set the selectivity + if len(curr_op.children) == 0: + # Selectivity is 1 because it's a scan + curr_op.selectivity = 1 + else: + # Just run the operation because next will set the selectivity + curr_op.next(is_build=True, sample_size_needed=sample_size_needed) + + # Set the curr op to be optimized + curr_op.is_optimized = True + + queue.extend(curr_op.children) + + # Checkpoint the optimized operations + self.runner.optimizer.checkpoint_optimized_ops() + + def next( + self, is_build: bool = False, sample_size_needed: int = None + ) -> Tuple[List[Dict], float, str]: + """ + Execute this operation and return its results. This is the core method implementing + the pull-based execution model. + + The execution follows these steps: + 1. Check for cached results in checkpoints + 2. If not cached, recursively request input data from child nodes + 3. Apply any configured sampling + 4. Execute the operation on the input data + 5. Cache results if checkpointing is enabled + + Returns: + Tuple[List[Dict], float, str]: A tuple containing: + - The operation's output data + - Total cost of this operation and its children + - Execution logs as a formatted string + """ + # Track cost and logs for this operation and its children + input_data = None + cost = 0.0 + this_op_cost = 0.0 + curr_logs = "" + input_len = None + + # If this is a build operation, check the sample cache first + if is_build: + cache_key = self.name + if cache_key in self.runner.optimizer.sample_cache: + cached_data, cached_sample_size = self.runner.optimizer.sample_cache[ + cache_key + ] + # If we have enough samples cached, use them + if not sample_size_needed or cached_sample_size >= sample_size_needed: + curr_logs += f"[green]✓[/green] Using cached {self.name} (sample size: {cached_sample_size})\n" + # Sample the cached data if needed + if sample_size_needed: + cached_data = smart_sample(cached_data, sample_size_needed) + + return cached_data, 0, curr_logs + + # Try to load from checkpoint if available + if not is_build: + attempted_input_data = self.runner._load_from_checkpoint_if_exists( + self.name.split("/")[0], self.name.split("/")[-1] + ) + if attempted_input_data is not None: + curr_logs += f"[green]✓[/green] Using cached {self.name}\n" + return attempted_input_data, 0, curr_logs + + # If there's a selectivity estimate, we need to take a sample of size sample_size_needed / selectivity + if self.selectivity and sample_size_needed: + input_sample_size_needed = int( + math.ceil(sample_size_needed / self.selectivity) + ) + else: + input_sample_size_needed = sample_size_needed + + # Clear any existing checkpoint before running + if self.runner.intermediate_dir: + checkpoint_path = os.path.join( + self.runner.intermediate_dir, + self.name.split("/")[0], + f"{self.name.split('/')[-1]}.json", + ) + if os.path.exists(checkpoint_path): + os.remove(checkpoint_path) + + # Handle equijoin operations which have two input streams + if self.is_equijoin: + assert ( + len(self.children) == 2 + ), "Equijoin should have left and right children" + left_data, left_cost, left_logs = self.children[0].next( + is_build, input_sample_size_needed + ) + right_data, right_cost, right_logs = self.children[1].next( + is_build, input_sample_size_needed + ) + cost += left_cost + right_cost + curr_logs += left_logs + right_logs + input_len = max(len(left_data), len(right_data)) + input_data = {"left_data": left_data, "right_data": right_data} + # Handle standard operations with single input + elif len(self.children) > 0: + input_data, input_cost, input_logs = self.children[0].next( + is_build, input_sample_size_needed + ) + cost += input_cost + curr_logs += input_logs + input_len = len(input_data) + + # Apply sampling if configured + if input_data and "sample" in self.config and not is_build: + input_data = input_data[: self.config["sample"]] + + # Execute the operation + with self.runner.console.status(f"Running {self.name}") as status: + self.runner.status = status + + cost_before_execution = self.runner.total_cost + + # Execute operation with appropriate inputs + output_data = self.runner._run_operation( + self.config, input_data, is_build=is_build + ) + + # Track costs and log execution + this_op_cost = self.runner.total_cost - cost_before_execution + cost += this_op_cost + if this_op_cost > 0: + build_indicator = "[yellow](build)[/yellow] " if is_build else "" + curr_logs += f"[green]✓[/green] {build_indicator}{self.name} (Cost: [green]${this_op_cost:.2f}[/green])\n" + else: + build_indicator = "[yellow](build)[/yellow] " if is_build else "" + curr_logs += f"[green]✓[/green] {build_indicator}{self.name}\n" + + # Save selectivity estimate + output_size = len(output_data) + self.selectivity = output_size / input_len if input_len else 1 + + # Cache the results if this is a build operation + if is_build: + self.runner.optimizer.sample_cache[self.name] = ( + output_data, + len(output_data), + ) + + # Truncate output data to the sample size needed + if sample_size_needed: + output_data = smart_sample(output_data, sample_size_needed) + + # Save checkpoint if enabled + if ( + not is_build + and self.runner.intermediate_dir + and self.name.split("/")[1] + in self.runner.step_op_hashes[self.name.split("/")[0]] + ): + self.runner._save_checkpoint( + self.name.split("/")[0], self.name.split("/")[-1], output_data + ) + + return output_data, cost, curr_logs + + def syntax_check(self) -> str: + operation = self.config["name"] + operation_type = self.config["type"] + + operation_class = get_operation(operation_type) + obj = operation_class( + self.runner, + self.config, + self.runner.default_model, + self.runner.max_threads, + self.runner.console, + self.runner.status, + ) + + # Do syntax check + obj.syntax_check() + + return f"[green]✓[/green] Operation '{operation}' ({operation_type})" + + +class StepBoundary(OpContainer): + def next( + self, is_build: bool = False, sample_size_needed: int = None + ) -> Tuple[List[Dict], float, str]: + + output_data, step_cost, step_logs = self.children[0].next( + is_build, sample_size_needed + ) + + # Print step logs only if not building + self.runner.datasets[self.name.split("/")[0]] = Dataset( + self, "memory", output_data + ) + if not is_build: + flush_cache(self.runner.console) + self.runner.console.log( + Panel.fit( + step_logs + + f"Step [cyan]{self.name}[/cyan] completed. Cost: [green]${step_cost:.2f}[/green]", + title=f"[bold blue]Step Execution: {self.name}[/bold blue]", + ) + ) + + return output_data, 0, "" + + def syntax_check(self) -> str: + return "" diff --git a/docetl/dataset.py b/docetl/dataset.py index 0e1e9e59..7826433b 100644 --- a/docetl/dataset.py +++ b/docetl/dataset.py @@ -1,10 +1,11 @@ import os from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Callable, Dict, List, Optional, Union, Literal +from typing import Any, Callable, Dict, List, Literal, Optional, Union + from pydantic import BaseModel -from docetl.parsing_tools import get_parser, get_parsing_tools from docetl.base_schemas import ParsingTool +from docetl.parsing_tools import get_parser, get_parsing_tools def create_parsing_tool_map( @@ -73,7 +74,7 @@ class schema(BaseModel): path: str source: str = "local" parsing: Optional[List[Dict[str, str]]] = None - + def __init__( self, runner, diff --git a/docetl/operations/__init__.py b/docetl/operations/__init__.py index badcc9af..651e3eee 100644 --- a/docetl/operations/__init__.py +++ b/docetl/operations/__init__.py @@ -10,7 +10,7 @@ from docetl.operations.split import SplitOperation from docetl.operations.sample import SampleOperation from docetl.operations.unnest import UnnestOperation - +from docetl.operations.scan import ScanOperation mapping = { "cluster": ClusterOperation, @@ -26,6 +26,7 @@ "split": SplitOperation, "sample": SampleOperation, "unnest": UnnestOperation, + "scan": ScanOperation } def get_operation(operation_type: str): diff --git a/docetl/operations/base.py b/docetl/operations/base.py index 761c4434..059487cb 100644 --- a/docetl/operations/base.py +++ b/docetl/operations/base.py @@ -5,12 +5,12 @@ from abc import ABC, ABCMeta, abstractmethod from typing import Dict, List, Optional, Tuple -from docetl.operations.utils import APIWrapper -from docetl.console import DOCETL_CONSOLE -from rich.console import Console -from rich.status import Status import jsonschema from pydantic import BaseModel +from rich.console import Console +from rich.status import Status + +from docetl.console import DOCETL_CONSOLE # FIXME: This should probably live in some utils module? @@ -32,7 +32,7 @@ def __new__(cls, *arg, **kw): class BaseOperation(ABC, metaclass=BaseOperationMeta): def __init__( self, - runner: "ConfigWrapper", + runner, config: Dict, default_model: str, max_threads: int, diff --git a/docetl/operations/cluster.py b/docetl/operations/cluster.py index 76f57cc4..3a2f1839 100644 --- a/docetl/operations/cluster.py +++ b/docetl/operations/cluster.py @@ -1,11 +1,12 @@ -import numpy as np -from jinja2 import Environment, Template from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple + +import numpy as np +from jinja2 import Template + from .base import BaseOperation -from .utils import RichLoopBar, strict_render from .clustering_utils import get_embeddings_for_clustering - +from .utils import RichLoopBar, strict_render class ClusterOperation(BaseOperation): @@ -101,9 +102,9 @@ def execute( ) tree = self.agglomerative_cluster_of_embeddings(input_data, embeddings) - + if "collapse" in self.config: - tree = self.collapse_tree(tree, collapse = self.config["collapse"]) + tree = self.collapse_tree(tree, collapse=self.config["collapse"]) self.prompt_template = Template(self.config["summary_prompt"]) cost += self.annotate_clustering_tree(tree) @@ -127,7 +128,7 @@ def build_tree(i): # res["embedding"] = list(embeddings[i]) return res return { - "children": [ + "children": [ build_tree(cl.children_[i - nsamples, 0]), build_tree(cl.children_[i - nsamples, 1]), ], @@ -139,37 +140,54 @@ def build_tree(i): def get_tree_distances(self, t): res = set() if "distance" in t: - res.update(set([t["distance"] - child["distance"] for child in t["children"] if "distance" in child])) + res.update( + set( + [ + t["distance"] - child["distance"] + for child in t["children"] + if "distance" in child + ] + ) + ) if "children" in t: for child in t["children"]: res.update(self.get_tree_distances(child)) return res - - def _collapse_tree(self, t, parent_dist = None, collapse = None): + + def _collapse_tree(self, t, parent_dist=None, collapse=None): if "children" in t: - if ( "distance" in t + if ( + "distance" in t and parent_dist is not None and collapse is not None - and parent_dist - t["distance"] < collapse): - return [grandchild - for child in t["children"] - for grandchild in self._collapse_tree(child, parent_dist=parent_dist, collapse=collapse)] + and parent_dist - t["distance"] < collapse + ): + return [ + grandchild + for child in t["children"] + for grandchild in self._collapse_tree( + child, parent_dist=parent_dist, collapse=collapse + ) + ] else: res = dict(t) - res["children"] = [grandchild - for idx, child in enumerate(t["children"]) - for grandchild in self._collapse_tree(child, parent_dist=t["distance"], collapse=collapse)] + res["children"] = [ + grandchild + for idx, child in enumerate(t["children"]) + for grandchild in self._collapse_tree( + child, parent_dist=t["distance"], collapse=collapse + ) + ] return [res] else: return [t] - - def collapse_tree(self, tree, collapse = None): + + def collapse_tree(self, tree, collapse=None): if collapse is not None: tree_distances = np.array(sorted(self.get_tree_distances(tree))) collapse = tree_distances[int(len(tree_distances) * collapse)] return self._collapse_tree(tree, collapse=collapse)[0] - def annotate_clustering_tree(self, t): if "children" in t: with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor: @@ -218,7 +236,9 @@ def validation_fn(response: Dict[str, Any]): else None ), verbose=self.config.get("verbose", False), - litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), + litellm_completion_kwargs=self.config.get( + "litellm_completion_kwargs", {} + ), ) total_cost += response.total_cost if response.validated: diff --git a/docetl/operations/code_operations.py b/docetl/operations/code_operations.py index 040275c9..2267135a 100644 --- a/docetl/operations/code_operations.py +++ b/docetl/operations/code_operations.py @@ -1,9 +1,11 @@ import os -from typing import Any, Dict, List, Optional, Tuple from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Tuple + from docetl.operations.base import BaseOperation from docetl.operations.utils import RichLoopBar + class CodeMapOperation(BaseOperation): class schema(BaseOperation.schema): type: str = "code_map" @@ -29,7 +31,9 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: transform_fn = namespace["transform"] results = [] - with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor: + with ThreadPoolExecutor( + max_workers=self.config.get("concurrent_thread_count", os.cpu_count()) + ) as executor: futures = [executor.submit(transform_fn, doc) for doc in input_data] pbar = RichLoopBar( range(len(futures)), @@ -40,15 +44,17 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: result = futures[i].result() if self.config.get("drop_keys"): result = { - k: v for k, v in result.items() + k: v + for k, v in result.items() if k not in self.config["drop_keys"] } doc = input_data[i] merged_result = {**doc, **result} results.append(merged_result) - + return results, 0.0 + class CodeReduceOperation(BaseOperation): class schema(BaseOperation.schema): type: str = "code_reduce" @@ -79,6 +85,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: if reduce_keys == ["_all"] or reduce_keys == "_all": grouped_data = [("_all", input_data)] else: + def get_group_key(item): return tuple(item[key] for key in reduce_keys) @@ -92,7 +99,9 @@ def get_group_key(item): grouped_data = list(grouped_data.items()) results = [] - with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor: + with ThreadPoolExecutor( + max_workers=self.config.get("concurrent_thread_count", os.cpu_count()) + ) as executor: futures = [executor.submit(reduce_fn, group) for _, group in grouped_data] pbar = RichLoopBar( range(len(futures)), @@ -101,7 +110,7 @@ def get_group_key(item): ) for i, (key, group) in zip(pbar, grouped_data): result = futures[i].result() - + # Apply pass-through at the group level if self.config.get("pass_through", False) and group: for k, v in group[0].items(): @@ -120,6 +129,7 @@ def get_group_key(item): return results, 0.0 + class CodeFilterOperation(BaseOperation): class schema(BaseOperation.schema): type: str = "code_filter" @@ -144,7 +154,9 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: filter_fn = namespace["transform"] results = [] - with ThreadPoolExecutor(max_workers=self.config.get('concurrent_thread_count', os.cpu_count())) as executor: + with ThreadPoolExecutor( + max_workers=self.config.get("concurrent_thread_count", os.cpu_count()) + ) as executor: futures = [executor.submit(filter_fn, doc) for doc in input_data] pbar = RichLoopBar( range(len(futures)), @@ -155,4 +167,4 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: should_keep = futures[i].result() if should_keep: results.append(input_data[i]) - return results, 0.0 \ No newline at end of file + return results, 0.0 diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index 274d30df..9047e1f9 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -7,21 +7,18 @@ from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from multiprocessing import Pool, cpu_count -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple -from docetl import console -from docetl.operations.utils import strict_render -from docetl.operations.utils.progress import RichLoopBar import numpy as np -from jinja2 import Template from litellm import model_cost -from rich.prompt import Confirm from rich.console import Console +from rich.prompt import Confirm from docetl.operations.base import BaseOperation +from docetl.operations.utils import strict_render +from docetl.operations.utils.progress import RichLoopBar from docetl.utils import completion_cost - # Global variables to store shared data _right_data = None _blocking_conditions = None @@ -39,6 +36,7 @@ def is_match(left_item: Dict[str, Any], right_item: Dict[str, Any]) -> bool: for condition in _blocking_conditions ) + # LLM-based comparison for blocked pairs def get_hashable_key(item: Dict) -> str: return json.dumps(item, sort_keys=True) @@ -73,7 +71,7 @@ class schema(BaseOperation.schema): blocking_keys: Optional[Dict[str, List[str]]] = None timeout: Optional[int] = None litellm_completion_kwargs: Dict[str, Any] = {} - + def compare_pair( self, comparison_prompt: str, @@ -101,7 +99,7 @@ def compare_pair( try: prompt = strict_render(comparison_prompt, {"left": item1, "right": item2}) except Exception as e: - self.console.print(f"[red]Error rendering prompt: {e}[/red]") + self.console.log(f"[red]Error rendering prompt: {e}[/red]") return False, 0 response = self.runner.api.call_llm( model, @@ -120,7 +118,7 @@ def compare_pair( response.response, {"is_match": "bool"} )[0] except Exception as e: - self.console.print(f"[red]Error parsing LLM response: {e}[/red]") + self.console.log(f"[red]Error parsing LLM response: {e}[/red]") return False, cost return output["is_match"], cost @@ -254,10 +252,7 @@ def execute( if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons: # Sample pairs based on cardinality and length sampled_pairs = stratified_length_sample( - blocked_pairs, - limit_comparisons, - sample_size=1000, - console=self.console + blocked_pairs, limit_comparisons, sample_size=1000, console=self.console ) # Calculate number of dropped pairs @@ -442,7 +437,7 @@ def get_embeddings( pbar = RichLoopBar( range(len(future_to_pair)), - desc=f"Comparing pairs", + desc="Comparing pairs", console=self.console, ) @@ -499,20 +494,20 @@ def estimate_length(items: List[Dict], sample_size: int = 1000) -> float: """ Estimates average document length in the relation. Returns a normalized score (0-1) representing relative document size. - + Args: items: List of dictionary items to analyze sample_size: Number of items to sample for estimation - + Returns: float: Normalized score based on average document length """ if not items: return 0.0 - + sample_size = min(len(items), sample_size) sample = random.sample(items, sample_size) - + def get_doc_length(doc: Dict) -> int: """Calculate total length of all string values in document""" total_len = 0 @@ -523,20 +518,20 @@ def get_doc_length(doc: Dict) -> int: # For nested structures, use their string representation total_len += len(str(value)) return total_len - + lengths = [get_doc_length(item) for item in sample] if not lengths: return 0.0 - + avg_length = sum(lengths) / len(lengths) return avg_length def stratified_length_sample( - blocked_pairs: List[Tuple[Dict, Dict]], + blocked_pairs: List[Tuple[Dict, Dict]], limit_comparisons: int, sample_size: int = 1000, - console: Console = None + console: Console = None, ) -> List[Tuple[Dict, Dict]]: """ Samples pairs stratified by the smaller cardinality relation, @@ -545,44 +540,45 @@ def stratified_length_sample( # Extract left and right items left_items = [left for left, _ in blocked_pairs] right_items = [right for _, right in blocked_pairs] - + # Estimate length for both relations left_length = estimate_length(left_items, sample_size) right_length = estimate_length(right_items, sample_size) - + # Group by the relation with estimated lower length use_left_as_key = left_length > right_length if console: longer_length = max(left_length, right_length) longer_side = "left" if left_length > right_length else "right" - console.log(f"Longer length is {longer_length:.2f} ({longer_side} side). Using {longer_side} to sample matches.") + console.log( + f"Longer length is {longer_length:.2f} ({longer_side} side). Using {longer_side} to sample matches." + ) groups = defaultdict(list) - + for left, right in blocked_pairs: key = get_hashable_key(left if use_left_as_key else right) value = (left, right) groups[key].append(value) - + # Sort each group by length of the other relation's item for key in groups: groups[key].sort( key=lambda x: len(x[1 if use_left_as_key else 0]), - reverse=True # Prioritize longer matches + reverse=True, # Prioritize longer matches ) - + # Calculate samples per group n_groups = len(groups) base_samples_per_group = limit_comparisons // n_groups extra_samples = limit_comparisons % n_groups - + # Sample from each group sampled_pairs = [] for i, (key, pairs) in enumerate(groups.items()): # Add one extra sample to early groups if we have remainder group_sample_size = min( - len(pairs), - base_samples_per_group + (1 if i < extra_samples else 0) + len(pairs), base_samples_per_group + (1 if i < extra_samples else 0) ) sampled_pairs.extend(pairs[:group_sample_size]) - - return sampled_pairs \ No newline at end of file + + return sampled_pairs diff --git a/docetl/operations/filter.py b/docetl/operations/filter.py index e48490a8..1f2ac62d 100644 --- a/docetl/operations/filter.py +++ b/docetl/operations/filter.py @@ -1,13 +1,10 @@ """The `FilterOperation` class is a subclass of `BaseOperation` that implements a filtering operation on input data using a language model.""" -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Tuple -from pydantic import Field - -from jinja2 import Template +from typing import Dict, List, Tuple from docetl.operations.map import MapOperation + class FilterOperation(MapOperation): class schema(MapOperation.schema): type: str = "filter" @@ -113,6 +110,7 @@ def execute( results, total_cost = super().execute(input_data) # Drop records with filter_key values that are False - results = [result for result in results if result[filter_key]] + if not is_build: + results = [result for result in results if result[filter_key]] return results, total_cost diff --git a/docetl/operations/gather.py b/docetl/operations/gather.py index ba6e0242..ea9e753b 100644 --- a/docetl/operations/gather.py +++ b/docetl/operations/gather.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple from docetl.operations.base import BaseOperation @@ -22,7 +22,7 @@ class schema(BaseOperation.schema): order_key: str peripheral_chunks: Dict[str, Any] doc_header_key: Optional[str] = None - + def __init__(self, *args: Any, **kwargs: Any) -> None: """ Initialize the GatherOperation. diff --git a/docetl/operations/link_resolve.py b/docetl/operations/link_resolve.py index 22ceded0..e5b6c2b8 100644 --- a/docetl/operations/link_resolve.py +++ b/docetl/operations/link_resolve.py @@ -1,23 +1,20 @@ -import random -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Tuple -import jinja2 from jinja2 import Template from rich.prompt import Confirm +from sklearn.metrics.pairwise import cosine_similarity from docetl.operations.base import BaseOperation -from docetl.operations.utils import RichLoopBar, rich_as_completed -from docetl.utils import completion_cost, extract_jinja_variables -from docetl.operations.utils import strict_render +from docetl.operations.utils import RichLoopBar, strict_render + from .clustering_utils import get_embeddings_for_clustering -from sklearn.metrics.pairwise import cosine_similarity -import numpy as np + class LinkResolveOperation(BaseOperation): def syntax_check(self) -> None: pass - + def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: """ Executes the resolve links operation on the provided dataset. @@ -33,7 +30,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: return [], 0 self.prompt_template = Template(self.config["comparison_prompt"]) - + id_key = self.config.get("id_key", "title") link_key = self.config.get("link_key", "related_to") blocking_threshold = self.config.get("blocking_threshold") @@ -42,9 +39,8 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: # Note: We don't want to use text-embedding-3-small as it has bad performance on short texts... embedding_model = self.config.get("embedding_model", "text-embedding-ada-002") - item_by_id = {item[id_key]: item - for item in input_data} - + item_by_id = {item[id_key]: item for item in input_data} + id_values = set([item[id_key] for item in input_data]) link_values = set() @@ -53,33 +49,26 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: to_resolve = list(link_values - id_values) id_values = list(id_values) - + if not blocking_threshold and not blocking_conditions: # Prompt the user for confirmation if not Confirm.ask( - f"[yellow]Warning: No blocking keys or conditions specified. " - f"This may result in a large number of comparisons. " - f"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. " - f"Do you want to continue without blocking?[/yellow]", + "[yellow]Warning: No blocking keys or conditions specified. " + "This may result in a large number of comparisons. " + "We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. " + "Do you want to continue without blocking?[/yellow]", ): raise ValueError("Operation cancelled by user.") - id_embeddings, id_embedding_cost = get_embeddings_for_clustering( [{"key": value} for value in id_values], - { - "embedding_model": embedding_model, - "embedding_keys": ["key"] - }, - self.runner.api + {"embedding_model": embedding_model, "embedding_keys": ["key"]}, + self.runner.api, ) link_embeddings, link_embedding_cost = get_embeddings_for_clustering( [{"key": value} for value in to_resolve], - { - "embedding_model": embedding_model, - "embedding_keys": ["key"] - }, - self.runner.api + {"embedding_model": embedding_model, "embedding_keys": ["key"]}, + self.runner.api, ) similarity_matrix = cosine_similarity(link_embeddings, id_embeddings) @@ -88,12 +77,12 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: total_possible_comparisons = acceptable.shape[0] * acceptable.shape[1] comparisons_saved = total_possible_comparisons - acceptable.sum().sum() - + self.console.log( f"[green]Comparisons saved by blocking: {comparisons_saved} " f"({(comparisons_saved / total_possible_comparisons) * 100:.2f}%)[/green]" ) - + self.replacements = {} batch_size = self.config.get("compare_batch_size", 100) @@ -102,7 +91,8 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: futures = [] for link_idx in range(acceptable.shape[0]): for id_idx in range(acceptable.shape[1]): - if not acceptable[link_idx, id_idx]: continue + if not acceptable[link_idx, id_idx]: + continue id_value = id_values[id_idx] link_value = to_resolve[link_idx] @@ -111,11 +101,13 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: futures.append( executor.submit( self.compare, - link_idx = link_idx, - id_idx = id_idx, - link_value = link_value, - id_value = id_value, - item = item)) + link_idx=link_idx, + id_idx=id_idx, + link_value=link_value, + id_value=id_value, + item=item, + ) + ) total_cost = 0 pbar = RichLoopBar( @@ -127,24 +119,23 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: total_cost += futures[i].result() pbar.update(i) - self.console.log( f"[green]Number of replacements found: {len(self.replacements)} " f"({(len(self.replacements) / total_possible_comparisons) * 100:.2f}% of all comparisons)[/green]" ) - + for item in input_data: - item[link_key] = [self.replacements.get(value, value) - for value in item[link_key]] + item[link_key] = [ + self.replacements.get(value, value) for value in item[link_key] + ] return input_data, total_cost - + def compare(self, link_idx, id_idx, link_value, id_value, item): - prompt = strict_render(self.prompt_template, { - "link_value": link_value, - "id_value": id_value, - "item": item - }) + prompt = strict_render( + self.prompt_template, + {"link_value": link_value, "id_value": id_value, "item": item}, + ) schema = {"is_same": "bool"} @@ -178,7 +169,7 @@ def validation_fn(response: Dict[str, Any]): verbose=self.config.get("verbose", False), litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), ) - + if response.validated: output = self.runner.api.parse_llm_response( response.response, @@ -189,8 +180,3 @@ def validation_fn(response: Dict[str, Any]): self.replacements[link_value] = id_value return response.total_cost - - - - - diff --git a/docetl/operations/map.py b/docetl/operations/map.py index afb1256d..d3878031 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -5,18 +5,14 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Tuple, Union -from docetl.operations.utils import strict_render -from jinja2 import Environment, Template +from jinja2 import Template +from litellm.utils import ModelResponse +from pydantic import Field, field_validator from tqdm import tqdm -from docetl.operations.base import BaseOperation -from docetl.operations.utils import RichLoopBar from docetl.base_schemas import Tool, ToolFunction -from docetl.utils import completion_cost -from pydantic import Field, field_validator -from litellm.utils import ModelResponse - - +from docetl.operations.base import BaseOperation +from docetl.operations.utils import RichLoopBar, strict_render class MapOperation(BaseOperation): @@ -41,6 +37,7 @@ class schema(BaseOperation.schema): clustering_method: Optional[str] = None batch_prompt: Optional[str] = None litellm_completion_kwargs: Dict[str, Any] = {} + @field_validator("drop_keys") def validate_drop_keys(cls, v): if isinstance(v, str): @@ -75,7 +72,7 @@ def syntax_check(self) -> None: raise ValueError( "If 'drop_keys' is not specified, both 'prompt' and 'output' must be present in the configuration" ) - + if config.batch_prompt: try: template = Template(config.batch_prompt) @@ -111,7 +108,7 @@ def syntax_check(self) -> None: for tool in config.tools: try: tool_obj = Tool(**tool) - except Exception as e: + except Exception: raise TypeError("Tool must be a dictionary") if not (tool_obj.code and tool_obj.function): @@ -129,7 +126,6 @@ def syntax_check(self) -> None: ) self.gleaning_check() - def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: """ @@ -164,17 +160,23 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: if self.status: self.status.stop() - def _process_map_item(item: Dict, initial_result: Optional[Dict] = None) -> Tuple[Optional[Dict], float]: + def _process_map_item( + item: Dict, initial_result: Optional[Dict] = None + ) -> Tuple[Optional[Dict], float]: prompt = strict_render(self.config["prompt"], {"input": item}) def validation_fn(response: Union[Dict[str, Any], ModelResponse]): - output = self.runner.api.parse_llm_response( - response, - schema=self.config["output"]["schema"], - tools=self.config.get("tools", None), - manually_fix_errors=self.manually_fix_errors, - )[0] if isinstance(response, ModelResponse) else response + output = ( + self.runner.api.parse_llm_response( + response, + schema=self.config["output"]["schema"], + tools=self.config.get("tools", None), + manually_fix_errors=self.manually_fix_errors, + )[0] + if isinstance(response, ModelResponse) + else response + ) for key, value in item.items(): if key not in self.config["output"]["schema"]: output[key] = value @@ -205,7 +207,9 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]): verbose=self.config.get("verbose", False), bypass_cache=self.config.get("bypass_cache", False), initial_result=initial_result, - litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), + litellm_completion_kwargs=self.config.get( + "litellm_completion_kwargs", {} + ), ) if llm_result.validated: @@ -220,7 +224,6 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]): else: output = llm_result.response - # Augment the output with the original item output = {**item, **output} if self.config.get("enable_observability", False): @@ -228,12 +231,14 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]): return output, llm_result.total_cost return None, llm_result.total_cost - - # If there's a batch prompt, let's use that + + # If there's a batch prompt, let's use that def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]: total_cost = 0 if len(items) > 1 and self.config.get("batch_prompt", None): - batch_prompt = strict_render(self.config["batch_prompt"], {"inputs": items}) + batch_prompt = strict_render( + self.config["batch_prompt"], {"inputs": items} + ) # Issue the batch call llm_result = self.runner.api.call_llm_batch( @@ -243,23 +248,39 @@ def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]: self.config["output"]["schema"], verbose=self.config.get("verbose", False), timeout_seconds=self.config.get("timeout", 120), - max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), bypass_cache=self.config.get("bypass_cache", False), - litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), + litellm_completion_kwargs=self.config.get( + "litellm_completion_kwargs", {} + ), ) total_cost += llm_result.total_cost # Parse the LLM response - parsed_output = self.runner.api.parse_llm_response(llm_result.response, self.config["output"]["schema"])[0].get("results", []) - items_and_outputs = [(item, parsed_output[idx] if idx < len(parsed_output) else None) for idx, item in enumerate(items)] + parsed_output = self.runner.api.parse_llm_response( + llm_result.response, self.config["output"]["schema"] + )[0].get("results", []) + items_and_outputs = [ + (item, parsed_output[idx] if idx < len(parsed_output) else None) + for idx, item in enumerate(items) + ] else: items_and_outputs = [(item, None) for item in items] - # Run _process_map_item for each item + # Run _process_map_item for each item all_results = [] if len(items_and_outputs) > 1: with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor: - futures = [executor.submit(_process_map_item, items_and_outputs[i][0], items_and_outputs[i][1]) for i in range(len(items_and_outputs))] + futures = [ + executor.submit( + _process_map_item, + items_and_outputs[i][0], + items_and_outputs[i][1], + ) + for i in range(len(items_and_outputs)) + ] for i in range(len(futures)): try: result, item_cost = futures[i].result() @@ -268,19 +289,25 @@ def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]: total_cost += item_cost except Exception as e: if self.config.get("skip_on_error", False): - self.console.log(f"[bold red]Error in map operation {self.config['name']}, skipping item:[/bold red] {e}") - continue + self.console.log( + f"[bold red]Error in map operation {self.config['name']}, skipping item:[/bold red] {e}" + ) + continue else: raise e else: try: - result, item_cost = _process_map_item(items_and_outputs[0][0], items_and_outputs[0][1]) + result, item_cost = _process_map_item( + items_and_outputs[0][0], items_and_outputs[0][1] + ) if result is not None: all_results.append(result) total_cost += item_cost except Exception as e: if self.config.get("skip_on_error", False): - self.console.log(f"[bold red]Error in map operation {self.config['name']}, skipping item:[/bold red] {e}") + self.console.log( + f"[bold red]Error in map operation {self.config['name']}, skipping item:[/bold red] {e}" + ) else: raise e @@ -291,7 +318,7 @@ def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]: batch_size = self.max_batch_size if self.max_batch_size is not None else 1 futures = [] for i in range(0, len(input_data), batch_size): - batch = input_data[i:i + batch_size] + batch = input_data[i : i + batch_size] futures.append(executor.submit(_process_map_batch, batch)) results = [] total_cost = 0 @@ -304,11 +331,14 @@ def _process_map_batch(items: List[Dict]) -> Tuple[List[Dict], float]: result_list, item_cost = futures[i].result() if result_list: if "drop_keys" in self.config: - result_list = [{ - k: v - for k, v in result.items() - if k not in self.config["drop_keys"] - } for result in result_list] + result_list = [ + { + k: v + for k, v in result.items() + if k not in self.config["drop_keys"] + } + for result in result_list + ] results.extend(result_list) total_cost += item_cost @@ -470,7 +500,9 @@ def process_prompt(item, prompt_config): timeout_seconds=self.config.get("timeout", 120), max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), bypass_cache=self.config.get("bypass_cache", False), - litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), + litellm_completion_kwargs=self.config.get( + "litellm_completion_kwargs", {} + ), ) output = self.runner.api.parse_llm_response( response.response, @@ -513,7 +545,9 @@ def process_prompt(item, prompt_config): if self.config.get("enable_observability", False): if f"_observability_{self.config['name']}" not in item_result: item_result[f"_observability_{self.config['name']}"] = {} - item_result[f"_observability_{self.config['name']}"].update({f"prompt_{prompt_index}": prompt}) + item_result[f"_observability_{self.config['name']}"].update( + {f"prompt_{prompt_index}": prompt} + ) # Update the item_result with the output item_result.update(output) diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 08178900..b727006d 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -17,16 +17,15 @@ import jinja2 import numpy as np from jinja2 import Template +from pydantic import Field from docetl.operations.base import BaseOperation -from docetl.operations.utils import strict_render from docetl.operations.clustering_utils import ( cluster_documents, get_embeddings_for_clustering, ) -from docetl.operations.utils import rich_as_completed +from docetl.operations.utils import rich_as_completed, strict_render from docetl.utils import completion_cost -from pydantic import Field class ReduceOperation(BaseOperation): @@ -332,7 +331,9 @@ def get_group_key(item): value = item[key] # Special handling for list-type values if isinstance(value, list): - key_values.append(tuple(sorted(value))) # Convert list to sorted tuple + key_values.append( + tuple(sorted(value)) + ) # Convert list to sorted tuple else: key_values.append(value) return tuple(key_values) @@ -389,10 +390,9 @@ def process_group( # Only execute merge-based plans if associative = True if "merge_prompt" in self.config and self.config.get("associative", True): result, prompts, cost = self._parallel_fold_and_merge(key, group_list) - elif ( - self.config.get("fold_batch_size", None) - and self.config.get("fold_batch_size") >= len(group_list) - ): + elif self.config.get("fold_batch_size", None) and self.config.get( + "fold_batch_size" + ) >= len(group_list): # If the fold batch size is greater than or equal to the number of items in the group, # we can just run a single fold operation result, prompt, cost = self._batch_reduce(key, group_list) @@ -410,9 +410,7 @@ def process_group( if self.config.get("enable_observability", False): # Add the _observability_{self.config['name']} key to the result - result[f"_observability_{self.config['name']}"] = { - "prompts": prompts - } + result[f"_observability_{self.config['name']}"] = {"prompts": prompts} # Apply pass-through at the group level if ( @@ -510,8 +508,10 @@ def _semantic_similarity_sampling( self, key: Tuple, group_list: List[Dict], value_sampling: Dict, sample_size: int ) -> Tuple[List[Dict], float]: embedding_model = value_sampling["embedding_model"] - query_text = strict_render(value_sampling["query_text"], {"reduce_key": dict(zip(self.config["reduce_key"], key))}) - + query_text = strict_render( + value_sampling["query_text"], + {"reduce_key": dict(zip(self.config["reduce_key"], key))}, + ) embeddings, cost = get_embeddings_for_clustering( group_list, value_sampling, self.runner.api @@ -557,6 +557,7 @@ def _parallel_fold_and_merge( merge_batch_size = self.config["merge_batch_size"] total_cost = 0 prompts = [] + def calculate_num_parallel_folds(): fold_time, fold_default = self.get_fold_time() merge_time, merge_default = self.get_merge_time() @@ -687,7 +688,11 @@ def calculate_num_parallel_folds(): fold_results = new_results - return (fold_results[0], prompts, total_cost) if fold_results else (None, prompts, total_cost) + return ( + (fold_results[0], prompts, total_cost) + if fold_results + else (None, prompts, total_cost) + ) def _incremental_reduce( self, key: Tuple, group_list: List[Dict] @@ -793,11 +798,14 @@ def _increment_fold( return self._batch_reduce(key, batch, scratchpad) start_time = time.time() - fold_prompt = strict_render(self.config["fold_prompt"], { - "inputs": batch, - "output": current_output, - "reduce_key": dict(zip(self.config["reduce_key"], key)) - }) + fold_prompt = strict_render( + self.config["fold_prompt"], + { + "inputs": batch, + "output": current_output, + "reduce_key": dict(zip(self.config["reduce_key"], key)), + }, + ) response = self.runner.api.call_llm( self.config.get("model", self.default_model), @@ -855,10 +863,13 @@ def _merge_results( the prompt used, and the cost of the merge operation. """ start_time = time.time() - merge_prompt = strict_render(self.config["merge_prompt"], { - "outputs": outputs, - "reduce_key": dict(zip(self.config["reduce_key"], key)) - }) + merge_prompt = strict_render( + self.config["merge_prompt"], + { + "outputs": outputs, + "reduce_key": dict(zip(self.config["reduce_key"], key)), + }, + ) response = self.runner.api.call_llm( self.config.get("model", self.default_model), "merge", @@ -961,10 +972,13 @@ def _batch_reduce( Tuple[Optional[Dict], str, float]: A tuple containing the reduced output (or None if processing failed), the prompt used, and the cost of the reduce operation. """ - prompt = strict_render(self.config["prompt"], { - "reduce_key": dict(zip(self.config["reduce_key"], key)), - "inputs": group_list - }) + prompt = strict_render( + self.config["prompt"], + { + "reduce_key": dict(zip(self.config["reduce_key"], key)), + "inputs": group_list, + }, + ) item_cost = 0 response = self.runner.api.call_llm( diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index f146bc03..e2b2881e 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -3,24 +3,18 @@ """ import random -import time from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Dict, List, Tuple, Optional, Union -import json -from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple -from docetl.operations.utils import strict_render import jinja2 from jinja2 import Template +from pydantic import Field from rich.prompt import Confirm -import math from docetl.operations.base import BaseOperation -from docetl.operations.utils import RichLoopBar, rich_as_completed +from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render from docetl.utils import completion_cost, extract_jinja_variables -from pydantic import Field - def find_cluster(item, cluster_map): while item != cluster_map[item]: @@ -81,11 +75,7 @@ def compare_pair( ): return True, 0, "" - - prompt = strict_render(comparison_prompt, { - "input1": item1, - "input2": item2 - }) + prompt = strict_render(comparison_prompt, {"input1": item1, "input2": item2}) response = self.runner.api.call_llm( model, "compare", @@ -247,7 +237,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: if observability_key not in item: item[observability_key] = { "comparison_prompts": [], - "resolution_prompt": None + "resolution_prompt": None, } blocking_keys = self.config.get("blocking_keys", []) @@ -259,10 +249,10 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: if not blocking_threshold and not blocking_conditions: # Prompt the user for confirmation if not Confirm.ask( - f"[yellow]Warning: No blocking keys or conditions specified. " - f"This may result in a large number of comparisons. " - f"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. " - f"Do you want to continue without blocking?[/yellow]", + "[yellow]Warning: No blocking keys or conditions specified. " + "This may result in a large number of comparisons. " + "We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. " + "Do you want to continue without blocking?[/yellow]", console=self.runner.console, ): raise ValueError("Operation cancelled by user.") @@ -320,7 +310,9 @@ def get_embeddings_batch( total_cost += sum(costs) # Generate all pairs to compare, ensuring no duplicate comparisons - def get_unique_comparison_pairs() -> Tuple[List[Tuple[int, int]], Dict[Tuple[str, ...], List[int]]]: + def get_unique_comparison_pairs() -> ( + Tuple[List[Tuple[int, int]], Dict[Tuple[str, ...], List[int]]] + ): # Create a mapping of values to their indices value_to_indices: Dict[Tuple[str, ...], List[int]] = {} for i, item in enumerate(input_data): @@ -354,7 +346,11 @@ def meets_blocking_conditions(pair: Tuple[int, int]) -> bool: is_match(input_data[i], input_data[j]) if blocking_conditions else False ) - blocked_pairs = list(filter(meets_blocking_conditions, comparison_pairs)) if blocking_conditions else comparison_pairs + blocked_pairs = ( + list(filter(meets_blocking_conditions, comparison_pairs)) + if blocking_conditions + else comparison_pairs + ) # Apply limit_comparisons to blocked pairs if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons: @@ -447,21 +443,21 @@ def merge_clusters(item1: int, item2: int) -> None: def auto_batch() -> int: # Maximum batch size limit for 4o-mini model M = 500 - + n = len(input_data) m = len(blocked_pairs) - + # https://www.wolframalpha.com/input?i=k%28k-1%29%2F2+%2B+%28n-k%29%28k-1%29+%3D+m%2C+solve+for+k # Two possible solutions for k: # k = -1/2 sqrt((1 - 2n)^2 - 8m) + n + 1/2 # k = 1/2 (sqrt((1 - 2n)^2 - 8m) + 2n + 1) - - discriminant = (1 - 2*n)**2 - 8*m - sqrt_discriminant = discriminant ** 0.5 - + + discriminant = (1 - 2 * n) ** 2 - 8 * m + sqrt_discriminant = discriminant**0.5 + k1 = -0.5 * sqrt_discriminant + n + 0.5 - k2 = 0.5 * (sqrt_discriminant + 2*n + 1) - + k2 = 0.5 * (sqrt_discriminant + 2 * n + 1) + # Take the maximum viable solution k = max(k1, k2) return M if k < 0 else min(int(k), M) @@ -479,11 +475,13 @@ def auto_batch() -> int: last_processed = 0 for i in pbar: batch_end = last_processed + batch_size - batch = blocked_pairs[last_processed : batch_end] + batch = blocked_pairs[last_processed:batch_end] # Filter pairs for the initial batch better_batch = [ - pair for pair in batch - if find_cluster(pair[0], cluster_map) == pair[0] and find_cluster(pair[1], cluster_map) == pair[1] + pair + for pair in batch + if find_cluster(pair[0], cluster_map) == pair[0] + and find_cluster(pair[1], cluster_map) == pair[1] ] # Expand better_batch if it doesn’t reach batch_size @@ -493,8 +491,10 @@ def auto_batch() -> int: next_batch = blocked_pairs[batch_end:next_end] better_batch.extend( - pair for pair in next_batch - if find_cluster(pair[0], cluster_map) == pair[0] and find_cluster(pair[1], cluster_map) == pair[1] + pair + for pair in next_batch + if find_cluster(pair[0], cluster_map) == pair[0] + and find_cluster(pair[1], cluster_map) == pair[1] ) # Update batch_end to prevent overlapping in the next loop @@ -524,18 +524,19 @@ def auto_batch() -> int: pair_costs += cost if is_match_result: merge_clusters(pair[0], pair[1]) - + if self.config.get("enable_observability", False): observability_key = f"_observability_{self.config['name']}" for idx in (pair[0], pair[1]): if observability_key not in input_data[idx]: input_data[idx][observability_key] = { "comparison_prompts": [], - "resolution_prompt": None + "resolution_prompt": None, } - input_data[idx][observability_key]["comparison_prompts"].append(prompt) + input_data[idx][observability_key][ + "comparison_prompts" + ].append(prompt) - pbar.update(last_processed//batch_size) total_cost += pair_costs # Collect final clusters @@ -553,10 +554,9 @@ def process_cluster(cluster): for item in cluster_items ] - - resolution_prompt = strict_render(self.config["resolution_prompt"], { - "inputs": cluster_items - }) + resolution_prompt = strict_render( + self.config["resolution_prompt"], {"inputs": cluster_items} + ) reduction_response = self.runner.api.call_llm( self.config.get("resolution_model", self.default_model), "reduce", @@ -575,7 +575,9 @@ def process_cluster(cluster): if self.config.get("validate", None) else None ), - litellm_completion_kwargs=self.config.get("litellm_completion_kwargs", {}), + litellm_completion_kwargs=self.config.get( + "litellm_completion_kwargs", {} + ), ) reduction_cost = reduction_response.total_cost @@ -585,7 +587,7 @@ def process_cluster(cluster): if observability_key not in item: item[observability_key] = { "comparison_prompts": [], - "resolution_prompt": None + "resolution_prompt": None, } item[observability_key]["resolution_prompt"] = resolution_prompt diff --git a/docetl/operations/sample.py b/docetl/operations/sample.py index 669b4de4..67052ab7 100644 --- a/docetl/operations/sample.py +++ b/docetl/operations/sample.py @@ -1,5 +1,7 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Tuple + import numpy as np + from docetl.operations.base import BaseOperation from docetl.operations.clustering_utils import get_embeddings_for_clustering diff --git a/docetl/operations/scan.py b/docetl/operations/scan.py new file mode 100644 index 00000000..8043e6a2 --- /dev/null +++ b/docetl/operations/scan.py @@ -0,0 +1,32 @@ +from typing import Dict, List, Tuple + +from docetl.operations.base import BaseOperation + + +class ScanOperation(BaseOperation): + class schema(BaseOperation.schema): + dataset_name: str + + def syntax_check(self) -> None: + """Validate the scan operation configuration.""" + super().syntax_check() + + def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: + """ + Execute the scan operation to load data from the configured source. + + Args: + input_data: Not used in scan operation + + Returns: + Tuple[List[Dict], float]: Loaded data and cost (0 for scan) + """ + + # Look in the runner.datasets objects + if self.config["dataset_name"] not in self.runner.datasets: + raise ValueError(f"Dataset {self.config['dataset_name']} not found") + + return ( + self.runner.datasets[self.config["dataset_name"]].load(), + 0.0, + ) # Scan has no LLM cost diff --git a/docetl/operations/split.py b/docetl/operations/split.py index c32407ea..01dbff5b 100644 --- a/docetl/operations/split.py +++ b/docetl/operations/split.py @@ -1,5 +1,5 @@ import uuid -from typing import Dict, List, Tuple, Any, Optional +from typing import Any, Dict, List, Optional, Tuple import tiktoken @@ -18,6 +18,7 @@ class SplitOperation(BaseOperation): - {name}_id: A unique identifier for each original document. - {name}_chunk_num: The sequential number of the chunk within its original document. """ + class schema(BaseOperation.schema): type: str = "split" split_key: str diff --git a/docetl/operations/unnest.py b/docetl/operations/unnest.py index a8bcbb38..d2b57640 100644 --- a/docetl/operations/unnest.py +++ b/docetl/operations/unnest.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Optional, Tuple from docetl.operations.base import BaseOperation @@ -63,7 +63,7 @@ class schema(BaseOperation.schema): expand_fields: Optional[List[str]] = None recursive: Optional[bool] = None depth: Optional[int] = None - + def syntax_check(self) -> None: """ Checks if the required configuration key is present in the operation's config. diff --git a/docetl/operations/utils/api.py b/docetl/operations/utils/api.py index 5943990e..d82d9bcf 100644 --- a/docetl/operations/utils/api.py +++ b/docetl/operations/utils/api.py @@ -4,18 +4,25 @@ import time from typing import Any, Dict, List, Optional -from litellm import completion, embedding, RateLimitError, ModelResponse +from litellm import ModelResponse, RateLimitError, completion, embedding +from rich import print as rprint from rich.console import Console -from .cache import cache, cache_key, freezeargs -from .llm import LLMResult, InvalidOutputError, timeout, truncate_messages -from .validation import safe_eval, convert_dict_schema_to_list_schema, get_user_input_for_schema, convert_val, strict_render from docetl.utils import completion_cost -from rich import print as rprint +from .cache import cache, cache_key, freezeargs +from .llm import InvalidOutputError, LLMResult, timeout, truncate_messages +from .validation import ( + convert_dict_schema_to_list_schema, + convert_val, + get_user_input_for_schema, + safe_eval, + strict_render, +) BASIC_MODELS = ["gpt-4o-mini", "gpt-4o"] + class APIWrapper(object): def __init__(self, runner): self.runner = runner @@ -61,7 +68,7 @@ def gen_embedding(self, model: str, input: List[str]) -> List[float]: c.set(key, result) return result - + def call_llm_batch( self, model: str, @@ -76,10 +83,19 @@ def call_llm_batch( ) -> LLMResult: # Turn the output schema into a list of schemas output_schema = convert_dict_schema_to_list_schema(output_schema) - + # Invoke the LLM call - return self.call_llm(model, op_type,messages, output_schema, verbose=verbose, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=bypass_cache, litellm_completion_kwargs=litellm_completion_kwargs) - + return self.call_llm( + model, + op_type, + messages, + output_schema, + verbose=verbose, + timeout_seconds=timeout_seconds, + max_retries_per_timeout=max_retries_per_timeout, + bypass_cache=bypass_cache, + litellm_completion_kwargs=litellm_completion_kwargs, + ) def _cached_call_llm( self, @@ -129,7 +145,13 @@ def _cached_call_llm( else: if not initial_result: response = self._call_llm_with_cache( - model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs + model, + op_type, + messages, + output_schema, + tools, + scratchpad, + litellm_completion_kwargs, ) total_cost += completion_cost(response) else: @@ -139,9 +161,11 @@ def _cached_call_llm( # Retry gleaning prompt + regular LLM num_gleaning_rounds = gleaning_config.get("num_rounds", 2) - parsed_output = self.parse_llm_response( - response, output_schema, tools - )[0] if isinstance(response, ModelResponse) else response + parsed_output = ( + self.parse_llm_response(response, output_schema, tools)[0] + if isinstance(response, ModelResponse) + else response + ) validator_messages = ( [ @@ -156,7 +180,10 @@ def _cached_call_llm( for rnd in range(num_gleaning_rounds): # Prepare validator prompt - validator_prompt = strict_render(gleaning_config["validation_prompt"], {"output": parsed_output}) + validator_prompt = strict_render( + gleaning_config["validation_prompt"], + {"output": parsed_output}, + ) self.runner.rate_limiter.try_acquire("llm_call", weight=1) # Get params for should refine @@ -197,7 +224,9 @@ def _cached_call_llm( # Parse the validator response suggestion = json.loads( - validator_response.choices[0].message.tool_calls[0].function.arguments + validator_response.choices[0] + .message.tool_calls[0] + .function.arguments ) if not suggestion["should_refine"]: break @@ -219,7 +248,13 @@ def _cached_call_llm( # Call LLM again response = self._call_llm_with_cache( - model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs + model, + op_type, + messages, + output_schema, + tools, + scratchpad, + litellm_completion_kwargs, ) parsed_output = self.parse_llm_response( response, output_schema, tools @@ -269,7 +304,13 @@ def _cached_call_llm( i += 1 response = self._call_llm_with_cache( - model, op_type, messages, output_schema, tools, scratchpad, litellm_completion_kwargs + model, + op_type, + messages, + output_schema, + tools, + scratchpad, + litellm_completion_kwargs, ) total_cost += completion_cost(response) @@ -323,7 +364,14 @@ def call_llm( Raises: TimeoutError: If the call times out after retrying. """ - key = cache_key(model, op_type, messages, output_schema, scratchpad, self.runner.config.get("system_prompt", {})) + key = cache_key( + model, + op_type, + messages, + output_schema, + scratchpad, + self.runner.config.get("system_prompt", {}), + ) max_retries = max_retries_per_timeout attempt = 0 @@ -444,9 +492,15 @@ def _call_llm_with_cache( tools = None tool_choice = None - persona = self.runner.config.get("system_prompt", {}).get("persona", "a helpful assistant") - dataset_description = self.runner.config.get("system_prompt", {}).get("dataset_description", "a collection of unstructured documents") - parethetical_op_instructions = "many inputs:one output" if op_type == "reduce" else "one input:one output" + persona = self.runner.config.get("system_prompt", {}).get( + "persona", "a helpful assistant" + ) + dataset_description = self.runner.config.get("system_prompt", {}).get( + "dataset_description", "a collection of unstructured documents" + ) + parethetical_op_instructions = ( + "many inputs:one output" if op_type == "reduce" else "one input:one output" + ) system_prompt = f"You are a {persona}, helping the user make sense of their data. The dataset description is: {dataset_description}. You will be performing a {op_type} operation ({parethetical_op_instructions}). You will perform the specified task on the provided data, as precisely and exhaustively (i.e., high recall) as possible. The result should be a structured output that you will send back to the user, with the `send_output` function. Do not influence your answers too much based on the `send_output` function parameter names; just use them to send the result back to the user." if scratchpad: @@ -481,7 +535,6 @@ def _call_llm_with_cache( Your main result must be sent via send_output. The updated_scratchpad is only for tracking state between batches, and should be null unless you specifically need to track frequencies.""" - # Truncate messages if they exceed the model's context length messages = truncate_messages(messages, model) @@ -504,8 +557,10 @@ def _call_llm_with_cache( except Exception as e: # Check that there's a prefix for the model name if it's not a basic model if model not in BASIC_MODELS: - if not "/" in model: - raise ValueError(f"Note: You may also need to prefix your model name with the provider, e.g. 'openai/gpt-4o-mini' or 'gemini/gemini-1.5-flash' to conform to LiteLLM API standards. Original error: {e}") + if "/" not in model: + raise ValueError( + f"Note: You may also need to prefix your model name with the provider, e.g. 'openai/gpt-4o-mini' or 'gemini/gemini-1.5-flash' to conform to LiteLLM API standards. Original error: {e}" + ) raise e else: try: @@ -523,11 +578,12 @@ def _call_llm_with_cache( except Exception as e: # Check that there's a prefix for the model name if it's not a basic model if model not in BASIC_MODELS: - if not "/" in model: - raise ValueError(f"Note: You may also need to prefix your model name with the provider, e.g. 'openai/gpt-4o-mini' or 'gemini/gemini-1.5-flash' to conform to LiteLLM API standards. Original error: {e}") + if "/" not in model: + raise ValueError( + f"Note: You may also need to prefix your model name with the provider, e.g. 'openai/gpt-4o-mini' or 'gemini/gemini-1.5-flash' to conform to LiteLLM API standards. Original error: {e}" + ) raise e - return response def parse_llm_response( @@ -643,13 +699,13 @@ def _parse_llm_response_helper( continue try: output_dict[key] = ast.literal_eval(value) - except: + except Exception: try: if value.startswith("["): output_dict[key] = ast.literal_eval(value + "]") else: output_dict[key] = value - except: + except Exception: pass outputs.append(output_dict) except json.JSONDecodeError: @@ -698,4 +754,4 @@ def validate_output(self, operation: Dict, output: Dict, console: Console) -> bo console.log(f"[bold red]Validation error:[/bold red] {str(e)}") console.log(f"[yellow]Output:[/yellow] {output}") return False - return True \ No newline at end of file + return True diff --git a/docetl/operations/utils/cache.py b/docetl/operations/utils/cache.py index 51625e5a..2dc3cb9f 100644 --- a/docetl/operations/utils/cache.py +++ b/docetl/operations/utils/cache.py @@ -3,26 +3,31 @@ import json import os import shutil -from typing import Any, Dict, List -from frozendict import frozendict +from typing import Dict, List + from diskcache import Cache -from rich.console import Console from dotenv import load_dotenv +from frozendict import frozendict +from rich.console import Console from docetl.console import DOCETL_CONSOLE load_dotenv() -DOCETL_HOME_DIR = os.environ.get("DOCETL_HOME_DIR", os.path.expanduser("~"))+"/.cache/docetl" +DOCETL_HOME_DIR = ( + os.environ.get("DOCETL_HOME_DIR", os.path.expanduser("~")) + "/.cache/docetl" +) CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "general") LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm") cache = Cache(LLM_CACHE_DIR) cache.close() + def freezeargs(func): """ Decorator to convert mutable dictionary arguments into immutable. """ + @functools.wraps(func) def wrapped(*args, **kwargs): args = tuple( @@ -42,14 +47,17 @@ def wrapped(*args, **kwargs): for k, v in kwargs.items() } return func(*args, **kwargs) + return wrapped + def flush_cache(console: Console = DOCETL_CONSOLE): """Flush the cache to disk.""" console.log("[bold green]Flushing cache to disk...[/bold green]") cache.close() console.log("[bold green]Cache flushed to disk.[/bold green]") + def clear_cache(console: Console = DOCETL_CONSOLE): """Clear the LLM cache stored on disk.""" console.log("[bold yellow]Clearing LLM cache...[/bold yellow]") @@ -67,14 +75,17 @@ def clear_cache(console: Console = DOCETL_CONSOLE): elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as e: - console.log(f"[bold red]Error deleting {file_path}: {str(e)}[/bold red]") + console.log( + f"[bold red]Error deleting {file_path}: {str(e)}[/bold red]" + ) console.log("[bold green]Cache cleared successfully.[/bold green]") except Exception as e: console.log(f"[bold red]Error clearing cache: {str(e)}[/bold red]") + def cache_key( model: str, - op_type: str, + op_type: str, messages: List[Dict[str, str]], output_schema: Dict[str, str], scratchpad: str = None, @@ -89,4 +100,4 @@ def cache_key( "scratchpad": scratchpad, "system_prompt": json.dumps(system_prompt, sort_keys=True), } - return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode()).hexdigest() \ No newline at end of file + return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode()).hexdigest() diff --git a/docetl/operations/utils/llm.py b/docetl/operations/utils/llm.py index e1fcd3fd..484d0981 100644 --- a/docetl/operations/utils/llm.py +++ b/docetl/operations/utils/llm.py @@ -1,23 +1,24 @@ -import ast import json import threading -import time from typing import Any, Dict, List, Optional + import tiktoken -from jinja2 import Template from litellm import model_cost from pydantic import BaseModel from rich import print as rprint -from docetl.utils import completion_cost, count_tokens +from docetl.utils import count_tokens + class LLMResult(BaseModel): response: Any total_cost: float validated: bool + class InvalidOutputError(Exception): """Custom exception raised when the LLM output is invalid or cannot be parsed.""" + def __init__( self, message: str, @@ -42,6 +43,7 @@ def __str__(self): f"Tool calls generated by LLM: {self.tools}" ) + def timeout(seconds): def decorator(func): def wrapper(*args, **kwargs): @@ -61,12 +63,12 @@ def target(): return result[0] return wrapper + return decorator + def truncate_messages( - messages: List[Dict[str, str]], - model: str, - from_agent: bool = False + messages: List[Dict[str, str]], model: str, from_agent: bool = False ) -> List[Dict[str, str]]: """Truncate messages to fit within model's context length.""" model_input_context_length = model_cost.get(model.split("/")[-1], {}).get( @@ -86,7 +88,7 @@ def truncate_messages( encoder = tiktoken.encoding_for_model(model.split("/")[-1]) except Exception: encoder = tiktoken.encoding_for_model("gpt-4o") - + encoded_content = encoder.encode(content) tokens_to_remove = min(len(encoded_content), excess_tokens) mid_point = len(encoded_content) // 2 @@ -104,4 +106,4 @@ def truncate_messages( ) longest_message["content"] = truncated_content - return truncated_messages \ No newline at end of file + return truncated_messages diff --git a/docetl/operations/utils/progress.py b/docetl/operations/utils/progress.py index b78f62fc..1723c734 100644 --- a/docetl/operations/utils/progress.py +++ b/docetl/operations/utils/progress.py @@ -1,10 +1,12 @@ -from typing import Iterable, Optional, Union from concurrent.futures import as_completed +from typing import Iterable, Optional, Union + from tqdm import tqdm + class RichLoopBar: """A progress bar class that integrates with Rich console.""" - + def __init__( self, iterable: Optional[Union[Iterable, range]] = None, @@ -58,6 +60,7 @@ def update(self, n=1): if self.tqdm: self.tqdm.update(n) + def rich_as_completed(futures, total=None, desc=None, leave=True, console=None): """Yield completed futures with a Rich progress bar.""" if console is None: @@ -66,4 +69,4 @@ def rich_as_completed(futures, total=None, desc=None, leave=True, console=None): with RichLoopBar(total=total, desc=desc, leave=leave, console=console) as pbar: for future in as_completed(futures): yield future - pbar.update() \ No newline at end of file + pbar.update() diff --git a/docetl/operations/utils/validation.py b/docetl/operations/utils/validation.py index 5bbc3be6..b89f5388 100644 --- a/docetl/operations/utils/validation.py +++ b/docetl/operations/utils/validation.py @@ -1,47 +1,46 @@ -import ast import json -from typing import Union, Dict, Any -from asteval import Interpreter -from rich import print as rprint -from rich.prompt import Prompt +from typing import Any, Dict, Union +from asteval import Interpreter from jinja2 import Environment, StrictUndefined, Template from jinja2.exceptions import UndefinedError - +from rich import print as rprint +from rich.prompt import Prompt aeval = Interpreter() + def strict_render(template: Union[Template, str], context: Dict[str, Any]) -> str: """ Renders a Jinja template with strict undefined checking. - + Args: template: Either a Jinja2 Template object or a template string context: Dictionary containing the template variables - + Returns: The rendered template string - + Raises: UndefinedError: When any undefined variable, attribute or index is accessed ValueError: When template is invalid """ # Create strict environment env = Environment(undefined=StrictUndefined) - + # Convert string to Template if needed if isinstance(template, str): # # If "inputs" in the context, make sure they are not accessing some attribute of inputs # if "inputs" in context and "{{ inputs." in template: # raise UndefinedError("The inputs variable is a list, so you cannot access attributes of inputs. Use inputs[index].key instead.") - + try: template = env.from_string(template) except Exception as e: raise ValueError(f"Invalid template: {str(e)}") - - try: + + try: return template.render(context) except UndefinedError as e: # Get the available context keys for better error reporting @@ -53,8 +52,12 @@ def strict_render(template: Union[Template, str], context: Dict[str, Any]) -> st if isinstance(context[var], dict): var_attributes[var] = list(context[var].keys()) elif isinstance(context[var], list) and len(context[var]) > 0: - var_attributes[var] = [f"inputs[i].{k}" for k in context[var][0].keys() if "_observability" not in k] - + var_attributes[var] = [ + f"inputs[i].{k}" + for k in context[var][0].keys() + if "_observability" not in k + ] + raise UndefinedError( f"{str(e)}\n" f"Your prompt can include the following variables: {available_vars}\n" @@ -74,6 +77,7 @@ def safe_eval(expression: str, output: Dict) -> bool: except Exception: return False + def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: """Convert a string representation of a type to a dictionary representation.""" value = value.strip().lower() @@ -110,11 +114,13 @@ def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: else: raise ValueError(f"Unsupported value type: {value}") + def convert_dict_schema_to_list_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Convert a dictionary schema to a list schema.""" schema_str = "{" + ", ".join([f"{k}: {v}" for k, v in schema.items()]) + "}" return {"results": f"list[{schema_str}]"} + def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Prompt the user for input for each key in the schema.""" user_input = {} @@ -128,11 +134,15 @@ def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: if isinstance(parsed_value, eval(value_type)): user_input[key] = parsed_value else: - rprint(f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}.") + rprint( + f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}." + ) return get_user_input_for_schema(schema) except json.JSONDecodeError: - rprint(f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again.") + rprint( + f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again." + ) return get_user_input_for_schema(schema) - return user_input \ No newline at end of file + return user_input diff --git a/docetl/optimizer.py b/docetl/optimizer.py new file mode 100644 index 00000000..49fc2719 --- /dev/null +++ b/docetl/optimizer.py @@ -0,0 +1,722 @@ +""" +The Optimizer module implements a pipeline optimization system that works with DocETL's pull-based execution model. +It analyzes operations marked for optimization and rewrites them into more efficient sub-pipelines while preserving +the lazy evaluation semantics of the container system. + +The architecture follows these key principles: +- Integration with the container-based lazy evaluation model +- Specialized optimizers for different operation types (map, reduce, join) +- Sample-based optimization to handle large datasets efficiently +- Cost tracking and caching of intermediate results +""" + +import copy +import hashlib +import os +import random +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple + +import yaml +from rich.panel import Panel +from rich.traceback import install + +from docetl.containers import OpContainer, StepBoundary +from docetl.operations.utils import flush_cache +from docetl.optimizers.join_optimizer import JoinOptimizer +from docetl.optimizers.map_optimizer import MapOptimizer +from docetl.optimizers.reduce_optimizer import ReduceOptimizer +from docetl.optimizers.utils import LLMClient +from docetl.utils import CapturedOutput + +if TYPE_CHECKING: + from docetl.runner import DSLRunner + +install(show_locals=True) + +SAMPLE_SIZE_MAP = { + "reduce": 40, + "map": 5, + "resolve": 100, + "equijoin": 100, + "filter": 5, + "split": 10, + "gather": 10, + "unnest": 10, +} + + +class Optimizer: + """ + Orchestrates the optimization of a DocETL pipeline by analyzing and potentially rewriting + operations marked for optimization. Works with the runner's pull-based execution model + to maintain lazy evaluation while improving pipeline efficiency. + """ + + def __init__( + self, + runner: "DSLRunner", + model: str = "gpt-4o", + resume: bool = False, + timeout: int = 60, + ): + """ + Initialize the optimizer with a runner instance and configuration. + Sets up optimization parameters, caching, and cost tracking. + + Args: + yaml_file (str): Path to the YAML configuration file. + model (str): The name of the language model to use. Defaults to "gpt-4o". + resume (bool): Whether to resume optimization from a previous run. Defaults to False. + timeout (int): Timeout in seconds for operations. Defaults to 60. + + Attributes: + config (Dict): Stores the loaded configuration from the YAML file. + console (Console): Rich console for formatted output. + max_threads (int): Maximum number of threads for parallel processing. + base_name (str): Base name used for file paths. + yaml_file_suffix (str): Suffix for YAML configuration files. + runner (DSLRunner): The DSL runner instance. + status: Status tracking for the runner. + optimized_config (Dict): A copy of the original config to be optimized. + llm_client (LLMClient): Client for interacting with the language model. + timeout (int): Timeout for operations in seconds. + resume (bool): Whether to resume from previous optimization. + captured_output (CapturedOutput): Captures output during optimization. + sample_cache (Dict): Maps operation names to tuples of (output_data, sample_size). + optimized_ops_path (str): Path to store optimized operations. + sample_size_map (Dict): Maps operation types to sample sizes. + + The method also calls print_optimizer_config() to display the initial configuration. + """ + self.config = runner.config + self.console = runner.console + self.max_threads = runner.max_threads + + self.base_name = runner.base_name + self.yaml_file_suffix = runner.yaml_file_suffix + self.runner = runner + self.status = runner.status + + self.optimized_config = copy.deepcopy(self.config) + self.llm_client = LLMClient(model) + self.timeout = timeout + self.resume = resume + self.captured_output = CapturedOutput() + + # Add sample cache for build operations + self.sample_cache = {} # Maps operation names to (output_data, sample_size) + + home_dir = os.environ.get("DOCETL_HOME_DIR", os.path.expanduser("~")) + cache_dir = os.path.join(home_dir, f".docetl/cache/{runner.yaml_file_suffix}") + os.makedirs(cache_dir, exist_ok=True) + + # Hash the config to create a unique identifier + config_hash = hashlib.sha256(str(self.config).encode()).hexdigest() + self.optimized_ops_path = f"{cache_dir}/{config_hash}.yaml" + + # Update sample size map + self.sample_size_map = SAMPLE_SIZE_MAP + if self.config.get("optimizer_config", {}).get("sample_sizes", {}): + self.sample_size_map.update(self.config["optimizer_config"]["sample_sizes"]) + + self.print_optimizer_config() + + def print_optimizer_config(self): + """ + Print the current configuration of the optimizer. + + This method uses the Rich console to display a formatted output of the optimizer's + configuration. It includes details such as the YAML file path, sample sizes for + different operation types, maximum number of threads, the language model being used, + and the timeout setting. + + The output is color-coded and formatted for easy readability, with a header and + separator lines to clearly delineate the configuration information. + """ + self.console.log( + Panel.fit( + "[bold cyan]Optimizer Configuration[/bold cyan]\n" + f"[yellow]Sample Size:[/yellow] {self.sample_size_map}\n" + f"[yellow]Max Threads:[/yellow] {self.max_threads}\n" + f"[yellow]Model:[/yellow] {self.llm_client.model}\n" + f"[yellow]Timeout:[/yellow] {self.timeout} seconds", + title="Optimizer Configuration", + ) + ) + + def _insert_empty_resolve_operations(self): + """ + Determines whether to insert resolve operations in the pipeline. + + For each reduce operation in the tree, checks if it has any map operation as a descendant + without a resolve operation in between. If found, inserts an empty resolve operation + right after the reduce operation. + + The method modifies the operation container tree in-place. + + Returns: + None + """ + if not self.runner.last_op_container: + return + + def find_map_without_resolve(container, visited=None): + """Helper to find first map descendant without a resolve operation in between.""" + if visited is None: + visited = set() + + if container.name in visited: + return None + visited.add(container.name) + + if not container.children: + return None + + for child in container.children: + if child.config["type"] == "map": + return child + if child.config["type"] == "resolve": + continue + map_desc = find_map_without_resolve(child, visited) + if map_desc: + return map_desc + return None + + # Walk down the operation container tree + containers_to_check = [self.runner.last_op_container] + while containers_to_check: + current = containers_to_check.pop(0) + + # Skip if this is a boundary or has no children + if isinstance(current, StepBoundary) or not current.children: + containers_to_check.extend(current.children) + continue + + # Get the step name from the container's name + step_name = current.name.split("/")[0] + + # Check if current container is a reduce operation + if current.config["type"] == "reduce" and current.config.get( + "synthesize_resolve", True + ): + reduce_key = current.config.get("reduce_key", "_all") + if isinstance(reduce_key, str): + reduce_key = [reduce_key] + + if "_all" not in reduce_key: + # Find map descendant without resolve + map_desc = find_map_without_resolve(current) + if map_desc: + # Synthesize an empty resolver + self.console.log( + "[yellow]Synthesizing empty resolver operation:[/yellow]" + ) + self.console.log( + f" • [cyan]Reduce operation:[/cyan] [bold]{current.name}[/bold]" + ) + self.console.log( + f" • [cyan]Step:[/cyan] [bold]{step_name}[/bold]" + ) + + # Create new resolve operation config + new_resolve_name = ( + f"synthesized_resolve_{len(self.config['operations'])}" + ) + new_resolve_config = { + "name": new_resolve_name, + "type": "resolve", + "empty": True, + "optimize": True, + "embedding_model": "text-embedding-3-small", + "resolution_model": self.config.get( + "default_model", "gpt-4o-mini" + ), + "comparison_model": self.config.get( + "default_model", "gpt-4o-mini" + ), + "_intermediates": { + "map_prompt": map_desc.config.get("prompt"), + "reduce_key": reduce_key, + }, + } + + # Add to operations list + self.config["operations"].append(new_resolve_config) + + # Create new resolve container + new_resolve_container = OpContainer( + f"{step_name}/{new_resolve_name}", + self.runner, + new_resolve_config, + ) + + # Insert the new container between reduce and its children + new_resolve_container.children = current.children + for child in new_resolve_container.children: + child.parent = new_resolve_container + current.children = [new_resolve_container] + new_resolve_container.parent = current + + # Add to container map + self.runner.op_container_map[ + f"{step_name}/{new_resolve_name}" + ] = new_resolve_container + + # Add children to the queue + containers_to_check.extend(new_resolve_container.children) + + def _add_map_prompts_to_reduce_operations(self): + """ + Add relevant map prompts to reduce operations based on their reduce keys. + + This method walks the operation container tree to find map operations and their + output schemas, then associates those with reduce operations that use those keys. + When a reduce operation is found, it looks through its descendants to find the + relevant map operations and adds their prompts. + + The method modifies the operation container tree in-place. + """ + if not self.runner.last_op_container: + return + + def find_map_prompts_for_keys(container, keys, visited=None): + """Helper to find map prompts for given keys in the container's descendants.""" + if visited is None: + visited = set() + + if container.name in visited: + return [] + visited.add(container.name) + + prompts = [] + if container.config["type"] == "map": + output_schema = container.config.get("output", {}).get("schema", {}) + if any(key in output_schema for key in keys): + prompts.append(container.config.get("prompt", "")) + + for child in container.children: + prompts.extend(find_map_prompts_for_keys(child, keys, visited)) + + return prompts + + # Walk down the operation container tree + containers_to_check = [self.runner.last_op_container] + while containers_to_check: + current = containers_to_check.pop(0) + + # Skip if this is a boundary or has no children + if isinstance(current, StepBoundary) or not current.children: + containers_to_check.extend(current.children) + continue + + # If this is a reduce operation, find relevant map prompts + if current.config["type"] == "reduce": + reduce_keys = current.config.get("reduce_key", []) + if isinstance(reduce_keys, str): + reduce_keys = [reduce_keys] + + # Find map prompts in descendants + relevant_prompts = find_map_prompts_for_keys(current, reduce_keys) + + if relevant_prompts: + current.config["_intermediates"] = current.config.get( + "_intermediates", {} + ) + current.config["_intermediates"]["last_map_prompt"] = ( + relevant_prompts[-1] + ) + + # Add children to the queue + containers_to_check.extend(current.children) + + def should_optimize( + self, step_name: str, op_name: str + ) -> Tuple[str, List[Dict[str, Any]], List[Dict[str, Any]], float]: + """ + Analyzes whether an operation should be optimized by running it on a sample of input data + and evaluating potential optimizations. Returns the optimization suggestion and relevant data. + """ + self.console.rule("[bold cyan]Beginning Pipeline Optimization[/bold cyan]") + + self._insert_empty_resolve_operations() + + node_of_interest = self.runner.op_container_map[f"{step_name}/{op_name}"] + + # Run the node_of_interest's children + input_data = [] + for child in node_of_interest.children: + input_data.append( + child.next( + is_build=True, + sample_size_needed=SAMPLE_SIZE_MAP.get(child.config["type"]), + )[0] + ) + + # Set the step + self.captured_output.set_step(step_name) + + # Determine whether we should optimize the node_of_interest + if ( + node_of_interest.config.get("type") == "map" + or node_of_interest.config.get("type") == "filter" + ): + # Create instance of map optimizer + map_optimizer = MapOptimizer( + self.runner, + self.runner._run_operation, + is_filter=node_of_interest.config.get("type") == "filter", + ) + should_optimize_output, input_data, output_data = ( + map_optimizer.should_optimize(node_of_interest.config, input_data[0]) + ) + elif node_of_interest.config.get("type") == "reduce": + reduce_optimizer = ReduceOptimizer( + self.runner, + self._run_operation, + ) + should_optimize_output, input_data, output_data = ( + reduce_optimizer.should_optimize(node_of_interest.config, input_data[0]) + ) + elif node_of_interest.config.get("type") == "resolve": + resolve_optimizer = JoinOptimizer( + self.runner, + node_of_interest.config, + target_recall=self.config.get("optimizer_config", {}) + .get("resolve", {}) + .get("target_recall", 0.95), + ) + _, should_optimize_output = resolve_optimizer.should_optimize(input_data[0]) + + # if should_optimize_output is empty, then we should move to the reduce operation + if should_optimize_output == "": + return "", [], [], 0.0 + else: + return "", [], [], 0.0 + + # Return the string and operation cost + return ( + should_optimize_output, + input_data, + output_data, + self.runner.total_cost + self.llm_client.total_cost, + ) + + def optimize(self) -> float: + """ + Optimizes the entire pipeline by walking the operation DAG and applying + operation-specific optimizers where marked. Returns the total optimization cost. + """ + self.console.rule("[bold cyan]Beginning Pipeline Optimization[/bold cyan]") + + # If self.resume is True and there's a checkpoint, load it + if self.resume: + if os.path.exists(self.optimized_ops_path): + # Load the yaml and change the runner with it + with open(self.optimized_ops_path, "r") as f: + partial_optimized_config = yaml.safe_load(f) + self.console.log( + "[yellow]Loading partially optimized pipeline from checkpoint...[/yellow]" + ) + self.runner._build_operation_graph(partial_optimized_config) + else: + self.console.log( + "[yellow]No checkpoint found, starting optimization from scratch...[/yellow]" + ) + + else: + self._insert_empty_resolve_operations() + + # Start with the last operation container and visit each child + self.runner.last_op_container.optimize() + + flush_cache(self.console) + + # Print the query plan + self.console.rule("[bold cyan]Optimized Query Plan[/bold cyan]") + self.runner.print_query_plan() + + return self.llm_client.total_cost + + def _optimize_equijoin( + self, + op_config: Dict[str, Any], + left_name: str, + right_name: str, + left_data: List[Dict[str, Any]], + right_data: List[Dict[str, Any]], + run_operation: Callable[ + [Dict[str, Any], List[Dict[str, Any]]], List[Dict[str, Any]] + ], + ) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]], str, str]: + """ + Optimizes an equijoin operation by analyzing join conditions and potentially inserting + map operations to improve join efficiency. Returns the optimized configuration and updated data. + """ + max_iterations = 2 + new_left_name = left_name + new_right_name = right_name + new_steps = [] + for _ in range(max_iterations): + join_optimizer = JoinOptimizer( + self.runner, + op_config, + target_recall=self.runner.config.get("optimizer_config", {}) + .get("equijoin", {}) + .get("target_recall", 0.95), + estimated_selectivity=self.runner.config.get("optimizer_config", {}) + .get("equijoin", {}) + .get("estimated_selectivity", None), + ) + optimized_config, cost, agent_results = join_optimizer.optimize_equijoin( + left_data, right_data + ) + self.runner.total_cost += cost + # Update the operation config with the optimized values + op_config.update(optimized_config) + + if not agent_results.get("optimize_map", False): + break # Exit the loop if no more map optimizations are necessary + + # Update the status to indicate we're optimizing a map operation + output_key = agent_results["output_key"] + if self.runner.status: + self.runner.status.update( + f"Optimizing map operation for {output_key} extraction to help with the equijoin" + ) + map_prompt = agent_results["map_prompt"] + dataset_to_transform = ( + left_data + if agent_results["dataset_to_transform"] == "left" + else right_data + ) + + # Create a new step for the map operation + map_operation = { + "name": f"synthesized_{output_key}_extraction", + "type": "map", + "prompt": map_prompt, + "model": self.config.get("default_model", "gpt-4o-mini"), + "output": {"schema": {output_key: "string"}}, + "optimize": False, + } + + # Optimize the map operation + if map_operation["optimize"]: + dataset_to_transform_sample = ( + random.sample(dataset_to_transform, self.sample_size_map.get("map")) + if self.config.get("optimizer_config", {}).get( + "random_sample", False + ) + else dataset_to_transform[: self.sample_size_map.get("map")] + ) + optimized_map_operations = self._optimize_map( + map_operation, dataset_to_transform_sample + ) + else: + optimized_map_operations = [map_operation] + + new_step = { + "name": f"synthesized_{output_key}_extraction", + "input": ( + left_name + if agent_results["dataset_to_transform"] == "left" + else right_name + ), + "operations": [mo["name"] for mo in optimized_map_operations], + } + if agent_results["dataset_to_transform"] == "left": + new_left_name = new_step["name"] + else: + new_right_name = new_step["name"] + + new_steps.append((new_step["name"], new_step, optimized_map_operations)) + + # Now run the optimized map operation on the entire dataset_to_transform + for op in optimized_map_operations: + dataset_to_transform = run_operation(op, dataset_to_transform) + + # Update the appropriate dataset for the next iteration + if agent_results["dataset_to_transform"] == "left": + left_data = dataset_to_transform + else: + right_data = dataset_to_transform + + if self.runner.status: + self.runner.status.update( + f"Optimizing equijoin operation with {output_key} extraction" + ) + + return op_config, new_steps, new_left_name, new_right_name + + def checkpoint_optimized_ops(self) -> None: + """ + Generates the clean config and saves it to the self.optimized_ops_path + This is used to resume optimization from a previous run + """ + clean_config = self.clean_optimized_config() + with open(self.optimized_ops_path, "w") as f: + yaml.safe_dump(clean_config, f, default_flow_style=False, width=80) + + # Recursively resolve all anchors and aliases + @staticmethod + def resolve_anchors(data): + """ + Recursively resolve all anchors and aliases in a nested data structure. + + This static method traverses through dictionaries and lists, resolving any YAML anchors and aliases. + + Args: + data: The data structure to resolve. Can be a dictionary, list, or any other type. + + Returns: + The resolved data structure with all anchors and aliases replaced by their actual values. + """ + if isinstance(data, dict): + return {k: Optimizer.resolve_anchors(v) for k, v in data.items()} + elif isinstance(data, list): + return [Optimizer.resolve_anchors(item) for item in data] + else: + return data + + def clean_optimized_config(self) -> Dict: + """ + Creates a clean YAML configuration from the optimized operation containers, + removing internal fields and organizing operations into proper pipeline steps. + """ + if not self.runner.last_op_container: + return self.config + + # Create a clean copy of the config + clean_config = { + "datasets": self.config.get("datasets", {}), + "operations": [], + "pipeline": self.runner.config.get( + "pipeline", {} + ).copy(), # Copy entire pipeline config + } + + # Reset steps to regenerate + clean_config["pipeline"]["steps"] = [] + + # Keep track of operations we've seen to avoid duplicates + seen_operations = set() + + def clean_operation(op_container: OpContainer) -> Dict: + """Remove internal fields from operation config""" + op_config = op_container.config + clean_op = op_config.copy() + + clean_op.pop("_intermediates", None) + + # If op has already been optimized, remove the recursively_optimize and optimize fields + if op_container.is_optimized: + for field in ["recursively_optimize", "optimize"]: + clean_op.pop(field, None) + + return clean_op + + def process_container(container, current_step=None): + """Process an operation container and its dependencies""" + # Skip step boundaries + if isinstance(container, StepBoundary): + if container.children: + return process_container(container.children[0], current_step) + return None, None + + # Get step name from container name + step_name = container.name.split("/")[0] + + # If this is a new step, create it + if not current_step or current_step["name"] != step_name: + current_step = {"name": step_name, "operations": []} + clean_config["pipeline"]["steps"].insert(0, current_step) + + # Skip scan operations but process their dependencies + if container.config["type"] == "scan": + if container.children: + return process_container(container.children[0], current_step) + return None, current_step + + # Handle equijoin operations + if container.is_equijoin: + # Add operation to list if not seen + if container.name not in seen_operations: + op_config = clean_operation(container) + clean_config["operations"].append(op_config) + seen_operations.add(container.name) + + # Add to step operations with left and right inputs + current_step["operations"].insert( + 0, + { + container.config["name"]: { + "left": container.kwargs["left_name"], + "right": container.kwargs["right_name"], + } + }, + ) + + # Process both children + if container.children: + process_container(container.children[0], current_step) + process_container(container.children[1], current_step) + else: + # Add operation to list if not seen + if container.name not in seen_operations: + op_config = clean_operation(container) + clean_config["operations"].append(op_config) + seen_operations.add(container.name) + + # Add to step operations + current_step["operations"].insert(0, container.config["name"]) + + # Process children + if container.children: + for child in container.children: + process_container(child, current_step) + + return container, current_step + + # Start processing from the last container + process_container(self.runner.last_op_container) + + # Add inputs to steps based on their first operation + for step in clean_config["pipeline"]["steps"]: + first_op = step["operations"][0] + if isinstance(first_op, dict): # This is an equijoin + continue # Equijoin steps don't need an input field + elif len(step["operations"]) > 0: + # Find the first non-scan operation's input by looking at its dependencies + op_container = self.runner.op_container_map.get( + f"{step['name']}/{first_op}" + ) + if op_container and op_container.children: + child = op_container.children[0] + while ( + child + and child.config["type"] == "step_boundary" + and child.children + ): + child = child.children[0] + if child and child.config["type"] == "scan": + step["input"] = child.config["dataset_name"] + + # Preserve all other config key-value pairs from original config + for key, value in self.config.items(): + if key not in ["datasets", "operations", "pipeline"]: + clean_config[key] = value + + return clean_config + + def save_optimized_config(self, optimized_config_path: str): + """ + Saves the optimized configuration to a YAML file after resolving all references + and cleaning up internal optimization artifacts. + """ + resolved_config = self.clean_optimized_config() + + with open(optimized_config_path, "w") as f: + yaml.safe_dump(resolved_config, f, default_flow_style=False, width=80) + self.console.log( + f"[green italic]💾 Optimized config saved to {optimized_config_path}[/green italic]" + ) diff --git a/docetl/optimizers/__init__.py b/docetl/optimizers/__init__.py index e69de29b..f518b319 100644 --- a/docetl/optimizers/__init__.py +++ b/docetl/optimizers/__init__.py @@ -0,0 +1,5 @@ +from docetl.optimizers.join_optimizer import JoinOptimizer +from docetl.optimizers.map_optimizer import MapOptimizer +from docetl.optimizers.reduce_optimizer import ReduceOptimizer + +__all__ = ["JoinOptimizer", "MapOptimizer", "ReduceOptimizer"] \ No newline at end of file diff --git a/docetl/optimizers/join_optimizer.py b/docetl/optimizers/join_optimizer.py index 8da92c5d..c9d1f92b 100644 --- a/docetl/optimizers/join_optimizer.py +++ b/docetl/optimizers/join_optimizer.py @@ -5,44 +5,37 @@ import numpy as np from litellm import model_cost -from rich.console import Console from rich.prompt import Confirm -from rich.status import Status from docetl.operations.equijoin import EquijoinOperation from docetl.operations.resolve import ResolveOperation -from docetl.utils import completion_cost, extract_jinja_variables, StageType +from docetl.utils import completion_cost, extract_jinja_variables class JoinOptimizer: def __init__( self, runner, - config: Dict[str, Any], op_config: Dict[str, Any], - console: Console, - llm_client: Any, - max_threads: int, target_recall: float = 0.95, sample_size: int = 500, sampling_weight: float = 20, agent_max_retries: int = 5, estimated_selectivity: float = None, - status: Status = None, ): self.runner = runner - self.config = config + self.config = runner.config self.op_config = op_config - self.llm_client = llm_client - self.max_threads = max_threads - self.console = console + self.llm_client = runner.optimizer.llm_client + self.max_threads = runner.max_threads + self.console = runner.console self.target_recall = target_recall self.sample_size = sample_size self.sampling_weight = sampling_weight self.agent_max_retries = agent_max_retries self.estimated_selectivity = estimated_selectivity self.console.log(f"Target Recall: {self.target_recall}") - self.status = status + self.status = self.runner.status # if self.estimated_selectivity is not None: # self.console.log( # f"[yellow]Using estimated selectivity of {self.estimated_selectivity}[/yellow]" @@ -376,13 +369,15 @@ def synthesize_resolution_prompt( ) return resolution_prompt - + def should_optimize(self, input_data: List[Dict[str, Any]]) -> Tuple[bool, str]: """ Determine if the given operation configuration should be optimized. """ # If there are no blocking keys or embeddings, then we don't need to optimize - if not self.op_config.get("blocking_conditions") or not self.op_config.get("blocking_threshold"): + if not self.op_config.get("blocking_conditions") or not self.op_config.get( + "blocking_threshold" + ): return True, "" # Check if the operation is marked as empty @@ -401,7 +396,9 @@ def should_optimize(self, input_data: List[Dict[str, Any]]) -> Tuple[bool, str]: if map_prompt: # Analyze the map prompt - analysis, explanation = self._analyze_map_prompt_categorization(map_prompt) + analysis, explanation = self._analyze_map_prompt_categorization( + map_prompt + ) if analysis: dedup = False @@ -431,15 +428,14 @@ def should_optimize(self, input_data: List[Dict[str, Any]]) -> Tuple[bool, str]: if duplicates_found: dedup = True - + return dedup, explanation - + return False, "" def optimize_resolve( self, input_data: List[Dict[str, Any]] ) -> Tuple[Dict[str, Any], float]: - # Check if the operation is marked as empty if self.op_config.get("empty", False): # Extract the map prompt from the intermediates @@ -622,7 +618,7 @@ def optimize_equijoin( # Log the generated blocking keys self.console.log( - f"[bold]Generated blocking keys (for embeddings-based blocking):[/bold]" + "[bold]Generated blocking keys (for embeddings-based blocking):[/bold]" ) self.console.log(f"Left keys: {left_keys}") self.console.log(f"Right keys: {right_keys}") diff --git a/docetl/optimizers/map_optimizer/config_generators.py b/docetl/optimizers/map_optimizer/config_generators.py index 0dbaf0a8..a91d0584 100644 --- a/docetl/optimizers/map_optimizer/config_generators.py +++ b/docetl/optimizers/map_optimizer/config_generators.py @@ -148,7 +148,10 @@ def _get_split_config( result["subprompt_output_schema"].update(op_config["output"]["schema"]) - result["subprompt"] = result["subprompt"] + " Only process the main chunk in --- Begin Main Chunk --- and --- End Main Chunk --- delimiters if they are present." + result["subprompt"] = ( + result["subprompt"] + + " Only process the main chunk in --- Begin Main Chunk --- and --- End Main Chunk --- delimiters if they are present." + ) self.console.log( f"[yellow]Breaking down operation {op_config['name']}[/yellow]" diff --git a/docetl/optimizers/map_optimizer/evaluator.py b/docetl/optimizers/map_optimizer/evaluator.py index 7ca2664c..9074f2d4 100644 --- a/docetl/optimizers/map_optimizer/evaluator.py +++ b/docetl/optimizers/map_optimizer/evaluator.py @@ -171,19 +171,21 @@ def _compare_two_plans( winner = { "plan_1": f"[cyan]{plan1_name}[/cyan]", "plan_2": f"[green]{plan2_name}[/green]", - "tie": "[yellow]Tie[/yellow]" + "tie": "[yellow]Tie[/yellow]", }[comp["better_plan"]] - + comparison_content += ( f"[bold]Sample {i+1}:[/bold]\n" f"Winner: {winner}\n" f"Reason: {comp['reason']}\n\n" ) - self.console.print(Panel.fit( - comparison_content, - title=f"[bold magenta]Pairwise Comparison: {plan1_name} vs {plan2_name}[/bold magenta]" - )) + self.console.log( + Panel.fit( + comparison_content, + title=f"[bold magenta]Pairwise Comparison: {plan1_name} vs {plan2_name}[/bold magenta]", + ) + ) if plan1_wins > plan2_wins: return plan1_name @@ -371,7 +373,7 @@ def _assess_operation( Custom Validator Prompt: {validator_prompt} - Based on the above information, please assess the operation's performance. + Based on the above information, please assess the operation's performance. If it needs improvement, provide specific examples in your assessment. Be very detailed in your reasons for improvements, if any. Provide your assessment in the following format: diff --git a/docetl/optimizers/map_optimizer/optimizer.py b/docetl/optimizers/map_optimizer/optimizer.py index 1367e31d..02597455 100644 --- a/docetl/optimizers/map_optimizer/optimizer.py +++ b/docetl/optimizers/map_optimizer/optimizer.py @@ -7,15 +7,13 @@ from jinja2 import Template from litellm import model_cost -from rich.console import Console from rich.table import Table from docetl.optimizers.map_optimizer.evaluator import Evaluator from docetl.optimizers.map_optimizer.plan_generators import PlanGenerator from docetl.optimizers.map_optimizer.prompt_generators import PromptGenerator from docetl.optimizers.map_optimizer.utils import select_evaluation_samples -from docetl.optimizers.utils import LLMClient -from docetl.utils import count_tokens, CapturedOutput, StageType +from docetl.utils import StageType, count_tokens class MapOptimizer: @@ -40,10 +38,6 @@ class MapOptimizer: def __init__( self, runner, - config: Dict[str, Any], - console: Console, - llm_client: LLMClient, - max_threads: int, run_operation: Callable, timeout: int = 10, is_filter: bool = False, @@ -53,55 +47,87 @@ def __init__( Initialize the MapOptimizer. Args: - config (Dict[str, Any]): The configuration dictionary for the optimizer. - console (Console): A Rich console object for pretty printing. - llm_client (LLMClient): A client for interacting with a language model. - max_threads (int): The maximum number of threads to use for parallel execution. + runner (Runner): The runner object. run_operation (Callable): A function to execute operations. timeout (int, optional): The timeout in seconds for operation execution. Defaults to 10. is_filter (bool, optional): If True, the operation is a filter operation. Defaults to False. """ self.runner = runner - self.config = config - self.console = console - self.llm_client = llm_client + self.config = runner.config + self.console = runner.console + self.llm_client = runner.optimizer.llm_client self._run_operation = run_operation - self.max_threads = max_threads - self.timeout = timeout + self.max_threads = runner.max_threads + self.timeout = runner.optimizer.timeout self._num_plans_to_evaluate_in_parallel = 5 self.is_filter = is_filter self.k_to_pairwise_compare = 6 self.plan_generator = PlanGenerator( - runner, llm_client, console, config, run_operation, max_threads, is_filter, depth + runner, + self.llm_client, + self.console, + self.config, + run_operation, + self.max_threads, + is_filter, + depth, ) self.evaluator = Evaluator( - llm_client, - console, - run_operation, - timeout, + self.llm_client, + self.console, + self._run_operation, + self.timeout, self._num_plans_to_evaluate_in_parallel, - is_filter, + self.is_filter, ) self.prompt_generator = PromptGenerator( - runner, llm_client, console, config, max_threads, is_filter + self.runner, + self.llm_client, + self.console, + self.config, + self.max_threads, + self.is_filter, ) - def should_optimize(self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]], List[Dict[str, Any]]]: + def should_optimize( + self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]] + ) -> Tuple[str, List[Dict[str, Any]], List[Dict[str, Any]]]: """ Determine if the given operation configuration should be optimized. """ - input_data, output_data, _, _, validator_prompt, assessment, data_exceeds_limit = self._should_optimize_helper(op_config, input_data) + ( + input_data, + output_data, + _, + _, + validator_prompt, + assessment, + data_exceeds_limit, + ) = self._should_optimize_helper(op_config, input_data) if data_exceeds_limit or assessment.get("needs_improvement", True): - assessment_str = "\n".join(assessment.get("reasons", [])) + "\n\nHere are some improvements that may help:\n" + "\n".join(assessment.get("improvements", [])) + assessment_str = ( + "\n".join(assessment.get("reasons", [])) + + "\n\nHere are some improvements that may help:\n" + + "\n".join(assessment.get("improvements", [])) + ) if data_exceeds_limit: assessment_str += "\nAlso, the input data exceeds the token limit." return assessment_str, input_data, output_data else: return "", input_data, output_data - - def _should_optimize_helper(self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], int, float, str, Dict[str, Any], bool]: + def _should_optimize_helper( + self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]] + ) -> Tuple[ + List[Dict[str, Any]], + List[Dict[str, Any]], + int, + float, + str, + Dict[str, Any], + bool, + ]: """ Determine if the given operation configuration should be optimized. Create a custom validator prompt and assess the operation's performance @@ -150,7 +176,7 @@ def _should_optimize_helper(self, op_config: Dict[str, Any], input_data: List[Di no_change_runtime = time.time() - no_change_start # Capture output for the sample run - self.runner.captured_output.save_optimizer_output( + self.runner.optimizer.captured_output.save_optimizer_output( stage_type=StageType.SAMPLE_RUN, output={ "operation_config": op_config, @@ -159,7 +185,6 @@ def _should_optimize_helper(self, op_config: Dict[str, Any], input_data: List[Di }, ) - # Generate custom validator prompt self.console.post_optimizer_status(StageType.SHOULD_OPTIMIZE) validator_prompt = self.prompt_generator._generate_validator_prompt( @@ -181,12 +206,10 @@ def _should_optimize_helper(self, op_config: Dict[str, Any], input_data: List[Di f"[bold]Assessment for whether we should improve operation {op_config['name']}:[/bold]" ) for key, value in assessment.items(): - self.console.print( - f"[bold cyan]{key}:[/bold cyan] [yellow]{value}[/yellow]" - ) + self.console.log(f"[bold cyan]{key}:[/bold cyan] [yellow]{value}[/yellow]") self.console.log("\n") # Add a newline for better readability - self.runner.captured_output.save_optimizer_output( + self.runner.optimizer.captured_output.save_optimizer_output( stage_type=StageType.SHOULD_OPTIMIZE, output={ "validator_prompt": validator_prompt, @@ -203,11 +226,21 @@ def _should_optimize_helper(self, op_config: Dict[str, Any], input_data: List[Di validator_prompt, ) - return input_data, output_data, model_input_context_length, no_change_runtime, validator_prompt, assessment, data_exceeds_limit - + return ( + input_data, + output_data, + model_input_context_length, + no_change_runtime, + validator_prompt, + assessment, + data_exceeds_limit, + ) def optimize( - self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]], plan_types: Optional[List[str]] = ["chunk", "proj_synthesis", "glean"] + self, + op_config: Dict[str, Any], + input_data: List[Dict[str, Any]], + plan_types: Optional[List[str]] = ["chunk", "proj_synthesis", "glean"], ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], float]: """ Optimize the given operation configuration for the input data. @@ -256,10 +289,19 @@ def optimize( # Verify that the plan types are valid for plan_type in plan_types: if plan_type not in ["chunk", "proj_synthesis", "glean"]: - raise ValueError(f"Invalid plan type: {plan_type}. Valid plan types are: chunk, proj_synthesis, glean.") + raise ValueError( + f"Invalid plan type: {plan_type}. Valid plan types are: chunk, proj_synthesis, glean." + ) - - input_data, output_data, model_input_context_length, no_change_runtime, validator_prompt, assessment, data_exceeds_limit = self._should_optimize_helper(op_config, input_data) + ( + input_data, + output_data, + model_input_context_length, + no_change_runtime, + validator_prompt, + assessment, + data_exceeds_limit, + ) = self._should_optimize_helper(op_config, input_data) # Check if improvement is needed based on the assessment if not self.config.get("optimizer_config", {}).get("force_decompose", False): @@ -267,7 +309,11 @@ def optimize( self.console.log( f"[green]No improvement needed for operation {op_config['name']}[/green]" ) - return [op_config], output_data, self.plan_generator.subplan_optimizer_cost + return ( + [op_config], + output_data, + self.plan_generator.subplan_optimizer_cost, + ) candidate_plans = {} @@ -282,7 +328,9 @@ def optimize( # Generate chunk size plans self.console.post_optimizer_status(StageType.CANDIDATE_PLANS) if "chunk" in plan_types: - self.console.log("[bold magenta]Generating chunking plans...[/bold magenta]") + self.console.log( + "[bold magenta]Generating chunking plans...[/bold magenta]" + ) chunk_size_plans = self.plan_generator._generate_chunk_size_plans( op_config, input_data, validator_prompt, model_input_context_length ) @@ -330,7 +378,7 @@ def optimize( plans_list = list(candidate_plans.items()) # Capture candidate plans - self.runner.captured_output.save_optimizer_output( + self.runner.optimizer.captured_output.save_optimizer_output( stage_type=StageType.CANDIDATE_PLANS, output=candidate_plans, ) @@ -458,7 +506,7 @@ def optimize( ratings = {k: v[0] for k, v in results.items()} runtime = {k: v[1] for k, v in results.items()} sample_outputs = {k: v[2] for k, v in results.items()} - self.runner.captured_output.save_optimizer_output( + self.runner.optimizer.captured_output.save_optimizer_output( stage_type=StageType.EVALUATION_RESULTS, output={ "input_data": evaluation_samples, diff --git a/docetl/optimizers/map_optimizer/plan_generators.py b/docetl/optimizers/map_optimizer/plan_generators.py index bb026110..b0b904b9 100644 --- a/docetl/optimizers/map_optimizer/plan_generators.py +++ b/docetl/optimizers/map_optimizer/plan_generators.py @@ -1,8 +1,7 @@ import copy import json -import random from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Tuple from rich.console import Console @@ -130,7 +129,10 @@ def determine_metadata_with_retry(): f"Metadata prompt and output schema: {metadata_info.get('metadata_prompt', 'N/A')}; {metadata_info.get('output_schema', 'N/A')}" ) self.console.log(f"Reason: {metadata_info.get('reason', 'N/A')}") - split_subprompt = "Given the following metadata about the document:\n{{ input.metadata }}\n\n" + split_subprompt + split_subprompt = ( + "Given the following metadata about the document:\n{{ input.metadata }}\n\n" + + split_subprompt + ) # Create header extraction prompt header_extraction_prompt, header_output_schema = ( @@ -255,7 +257,7 @@ def determine_metadata_with_retry(): map_op, sample_map_input, "shared_submap", - plan_types=["proj_synthesis", "glean"] + plan_types=["proj_synthesis", "glean"], ) self.subplan_optimizer_cost += cost except Exception as e: @@ -269,19 +271,18 @@ def determine_metadata_with_retry(): try: optimized_reduce_ops, _, cost = ReduceOptimizer( self.runner, - self.config, - self.console, - self.llm_client, - self.max_threads, self._run_operation, ).optimize(reduce_op, sample_output) self.subplan_optimizer_cost += cost except Exception as e: - import traceback + import traceback + self.console.log( f"[yellow]Warning: Failed to recursively optimize reduce operation: {e}. Using original reduce operation.[/yellow]" ) - self.console.log(f"[yellow]Traceback:[/yellow]\n{traceback.format_exc()}") + self.console.log( + f"[yellow]Traceback:[/yellow]\n{traceback.format_exc()}" + ) # Create plans for each chunk size plans = {} @@ -315,7 +316,7 @@ def task(): header_extraction_prompt, header_output_schema, ) - + # Create the plan by combining all operations plan = copy.deepcopy(base_operations) plan.extend(smg_ops + optimized_map_ops + optimized_reduce_ops) @@ -729,7 +730,7 @@ def _generate_chain_plans( ) -> Dict[str, List[Dict[str, Any]]]: """ Generate chain decomposition plans for the given operation. - + If recursively_optimize is True in the op_config, each subtask in the chain will be recursively optimized using a new MapOptimizer instance. """ @@ -855,10 +856,10 @@ def _generate_chain_plans( if op_config.get("recursively_optimize", False): try: optimized_subtask_plan, cost = self._recursively_optimize_subtask( - subtask_config, + subtask_config, input_data, f"chain_subtask_{idx+1}", - plan_types=["proj_synthesis", "glean"] + plan_types=["proj_synthesis", "glean"], ) self.subplan_optimizer_cost += cost chain_plan.extend(optimized_subtask_plan) @@ -893,7 +894,7 @@ def _recursively_optimize_subtask( subtask_config: Dict[str, Any], input_data: List[Dict[str, Any]], subtask_name: str, - plan_types: List[str] + plan_types: List[str], ) -> Tuple[List[Dict[str, Any]], float]: """ Recursively optimize a subtask using a new MapOptimizer instance. @@ -906,24 +907,20 @@ def _recursively_optimize_subtask( from docetl.optimizers.map_optimizer.optimizer import MapOptimizer - self.console.log(f"[cyan]Recursively optimizing {subtask_name} (depth {self.depth})...[/cyan]") + self.console.log( + f"[cyan]Recursively optimizing {subtask_name} (depth {self.depth})...[/cyan]" + ) subtask_optimizer = MapOptimizer( self.runner, - self.config, - self.console, - self.llm_client, - self.max_threads, self._run_operation, is_filter=self.is_filter, - depth=self.depth + 1 + depth=self.depth + 1, ) try: optimized_plan, _, cost = subtask_optimizer.optimize( - subtask_config, - input_data, - plan_types + subtask_config, input_data, plan_types ) return optimized_plan, cost diff --git a/docetl/optimizers/map_optimizer/prompt_generators.py b/docetl/optimizers/map_optimizer/prompt_generators.py index 94e778fa..b0fc261b 100644 --- a/docetl/optimizers/map_optimizer/prompt_generators.py +++ b/docetl/optimizers/map_optimizer/prompt_generators.py @@ -4,7 +4,6 @@ from litellm import model_cost from rich.console import Console -from rich.prompt import Prompt from docetl.optimizers.map_optimizer.utils import generate_and_validate_prompt from docetl.optimizers.utils import LLMClient @@ -14,7 +13,7 @@ class PromptGenerator: def __init__( self, - runner: "DSLRunner", + runner, llm_client: LLMClient, console: Console, config: Dict[str, Any], diff --git a/docetl/optimizers/reduce_optimizer.py b/docetl/optimizers/reduce_optimizer.py index 2fac0cae..7681d836 100644 --- a/docetl/optimizers/reduce_optimizer.py +++ b/docetl/optimizers/reduce_optimizer.py @@ -3,19 +3,16 @@ import random from concurrent.futures import ThreadPoolExecutor, as_completed from statistics import mean -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Union from jinja2 import Template from litellm import model_cost -from rich.console import Console from rich.prompt import Confirm -from rich.status import Status from docetl.operations.base import BaseOperation from docetl.operations.utils import truncate_messages from docetl.optimizers.join_optimizer import JoinOptimizer -from docetl.optimizers.utils import LLMClient -from docetl.utils import count_tokens, extract_jinja_variables, StageType +from docetl.utils import StageType, count_tokens, extract_jinja_variables class ReduceOptimizer: @@ -38,14 +35,9 @@ class ReduceOptimizer: def __init__( self, runner, - config: Dict[str, Any], - console: Console, - llm_client: LLMClient, - max_threads: int, run_operation: Callable, num_fold_prompts: int = 1, num_samples_in_validation: int = 10, - status: Optional[Status] = None, ): """ Initialize the ReduceOptimizer. @@ -60,14 +52,14 @@ def __init__( num_samples_in_validation (int, optional): Number of samples to use in validation. Defaults to 10. """ self.runner = runner - self.config = config - self.console = console - self.llm_client = llm_client + self.config = self.runner.config + self.console = self.runner.console + self.llm_client = self.runner.optimizer.llm_client self._run_operation = run_operation - self.max_threads = max_threads + self.max_threads = self.runner.max_threads self.num_fold_prompts = num_fold_prompts self.num_samples_in_validation = num_samples_in_validation - self.status = status + self.status = self.runner.status def should_optimize_helper( self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]] @@ -122,20 +114,44 @@ def should_optimize_helper( op_config, validator_inputs, original_output, validator_prompt ) - return validation_results, prompt_tokens, model_input_context_length, model, validator_prompt, original_output - - def should_optimize(self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]]) -> Tuple[str, List[Dict[str, Any]], List[Dict[str, Any]]]: - validation_results, prompt_tokens, model_input_context_length, model, validator_prompt, original_output = self.should_optimize_helper(op_config, input_data) + return ( + validation_results, + prompt_tokens, + model_input_context_length, + model, + validator_prompt, + original_output, + ) + + def should_optimize( + self, op_config: Dict[str, Any], input_data: List[Dict[str, Any]] + ) -> Tuple[str, List[Dict[str, Any]], List[Dict[str, Any]]]: + ( + validation_results, + prompt_tokens, + model_input_context_length, + model, + validator_prompt, + original_output, + ) = self.should_optimize_helper(op_config, input_data) if prompt_tokens * 1.5 > model_input_context_length: - return "The reduce prompt is likely to exceed the token limit for model {model}.", input_data, original_output + return ( + "The reduce prompt is likely to exceed the token limit for model {model}.", + input_data, + original_output, + ) if validation_results.get("needs_improvement", False): - return "\n".join( - [ - f"Issues: {result['issues']} Suggestions: {result['suggestions']}" - for result in validation_results["validation_results"] - ] - ), input_data, original_output + return ( + "\n".join( + [ + f"Issues: {result['issues']} Suggestions: {result['suggestions']}" + for result in validation_results["validation_results"] + ] + ), + input_data, + original_output, + ) else: return "", input_data, original_output @@ -166,11 +182,18 @@ def optimize( Tuple[List[Dict[str, Any]], List[Dict[str, Any]], float]: A tuple containing the list of optimized configurations and the list of outputs from the optimized operation(s), and the cost of the operation due to synthesizing any resolve operations. """ - validation_results, prompt_tokens, model_input_context_length, model, validator_prompt, original_output = self.should_optimize_helper(op_config, input_data) - - add_map_op = False + ( + validation_results, + prompt_tokens, + model_input_context_length, + model, + validator_prompt, + original_output, + ) = self.should_optimize_helper(op_config, input_data) + + # add_map_op = False if prompt_tokens * 2 > model_input_context_length: - add_map_op = True + # add_map_op = True self.console.log( f"[yellow]Warning: The reduce prompt exceeds the token limit for model {model}. " f"Token count: {prompt_tokens}, Limit: {model_input_context_length}. " @@ -196,13 +219,14 @@ def optimize( # # Return unoptimized map and reduce operations # return [map_prompt, op_config], input_data, 0.0 - # Print the validation results self.console.log("[bold]Validation Results on Initial Sample:[/bold]") - if validation_results["needs_improvement"] or self.config.get("optimizer_config", {}).get("force_decompose", False): + if validation_results["needs_improvement"] or self.config.get( + "optimizer_config", {} + ).get("force_decompose", False): self.console.post_optimizer_rationale( should_optimize=True, - rationale= "\n".join( + rationale="\n".join( [ f"Issues: {result['issues']} Suggestions: {result['suggestions']}" for result in validation_results["validation_results"] @@ -299,7 +323,7 @@ def _should_use_map( should_preprocess = preprocessing_result["preprocessing_needed"] preprocessing_rationale = preprocessing_result["rationale"] - self.console.log(f"[bold]Map-Reduce Decomposition Analysis:[/bold]") + self.console.log("[bold]Map-Reduce Decomposition Analysis:[/bold]") self.console.log(f"Should write a map operation: {should_preprocess}") self.console.log(f"Rationale: {preprocessing_rationale}") @@ -1076,6 +1100,8 @@ def _generate_validator_prompt( 4. How well does the output adhere to any specific formatting requirements mentioned in the original prompt, such as character limits for summaries or specific data types for aggregated values? Note that the output may reflect more than just the input provided, since we only provide a one-item sample input. Provide your response as a single string containing the custom validator prompt. The prompt should be tailored to the task and avoid generic criteria. The prompt should not reference a specific value in the sample input, but rather a general property. + + Your prompt should not have any placeholders like {{ reduce_key }} or {{ input_key }}. It should just be a string. """ parameters = { @@ -1119,7 +1145,10 @@ def _validate_reduce_output( with ThreadPoolExecutor(max_workers=self.max_threads) as executor: futures = [] for reduce_key, inputs in validation_inputs.items(): - if op_config["reduce_key"] == ["_all"] or op_config["reduce_key"] == "_all": + if ( + op_config["reduce_key"] == ["_all"] + or op_config["reduce_key"] == "_all" + ): sample_output = output_data[0] elif isinstance(op_config["reduce_key"], list): sample_output = next( @@ -1490,8 +1519,10 @@ def _synthesize_fold_prompts( def get_random_examples(): reduce_key = op_config["reduce_key"] - reduce_key = list(reduce_key) if not isinstance(reduce_key, list) else reduce_key - + reduce_key = ( + list(reduce_key) if not isinstance(reduce_key, list) else reduce_key + ) + if reduce_key == ["_all"]: # For _all case, just pick random input and output examples input_example = random.choice(sample_input) @@ -1581,11 +1612,15 @@ def generate_single_prompt(): # Run the operation with the fold prompt try: - self._run_operation(temp_plan, sample_input[: temp_plan["fold_batch_size"]]) + self._run_operation( + temp_plan, sample_input[: temp_plan["fold_batch_size"]] + ) return fold_prompt except Exception as e: - self.console.log(f"[red]Error in agent-generated fold prompt: {e}[/red]") + self.console.log( + f"[red]Error in agent-generated fold prompt: {e}[/red]" + ) # Create a default fold prompt that instructs folding new data into existing output fold_prompt = f"""Analyze this batch of data using the following instructions: @@ -1594,11 +1629,11 @@ def generate_single_prompt(): However, instead of starting fresh, fold your analysis into the existing output that has already been generated. The existing output is provided in the 'output' variable below: -{{{{ output }}}} +{{{{ output }}}} Remember, you must fold the new data into the existing output, do not start fresh.""" return fold_prompt - + with ThreadPoolExecutor(max_workers=self.max_threads) as executor: fold_prompts = list( executor.map(lambda _: generate_single_prompt(), range(num_prompts)) diff --git a/docetl/optimizers/utils.py b/docetl/optimizers/utils.py index 53520c8b..d0305afd 100644 --- a/docetl/optimizers/utils.py +++ b/docetl/optimizers/utils.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from litellm import completion, completion_cost +from litellm import completion from docetl.operations.utils import truncate_messages from docetl.utils import completion_cost diff --git a/docetl/parsing_tools.py b/docetl/parsing_tools.py index a85290ab..b3e44152 100644 --- a/docetl/parsing_tools.py +++ b/docetl/parsing_tools.py @@ -1,8 +1,8 @@ import importlib import io import os -from typing import Dict, List, Optional, Any from functools import wraps +from typing import Any, Dict, List, Optional def with_input_output_key(fn): @@ -61,7 +61,7 @@ def whisper_speech_to_text(filename: str) -> List[str]: Returns: List[str]: Transcribed text. """ - import os + from litellm import transcription file_size = os.path.getsize(filename) @@ -274,7 +274,6 @@ def azure_di_read( Raises: ValueError: If DOCUMENTINTELLIGENCE_API_KEY or DOCUMENTINTELLIGENCE_ENDPOINT environment variables are not set. """ - import os from azure.ai.documentintelligence import DocumentIntelligenceClient from azure.ai.documentintelligence.models import AnalyzeDocumentRequest @@ -385,9 +384,9 @@ def paddleocr_pdf_to_string( Returns: List[str]: Extracted content as a list of formatted strings. """ - from paddleocr import PaddleOCR import fitz import numpy as np + from paddleocr import PaddleOCR ocr = PaddleOCR(use_angle_cls=True, lang=lang) @@ -454,10 +453,10 @@ def gptpdf_to_string( Returns: str: Extracted content as a string. """ - from gptpdf import parse_pdf - import tempfile + from gptpdf import parse_pdf + with tempfile.TemporaryDirectory() as temp_dir: kwargs = { "pdf_path": input_path, @@ -483,7 +482,7 @@ def gptpdf_to_string( def get_parser(name: str): try: entrypoint = importlib.metadata.entry_points(group="docetl.parser")[name] - except KeyError as e: + except KeyError: raise KeyError(f"Unrecognized parser {name}") return entrypoint.load() diff --git a/docetl/runner.py b/docetl/runner.py index e3bb8ed3..61067d15 100644 --- a/docetl/runner.py +++ b/docetl/runner.py @@ -1,24 +1,49 @@ -from collections import defaultdict +""" +The DSLRunner module implements a declarative pipeline execution engine with a pull-based +evaluation model. Key architectural decisions include: + +Design Patterns: +- Pull-based DAG: Operations are lazily evaluated only when their outputs are needed, + enabling efficient resource utilization and caching +- Dependency Injection: Operations receive their dependencies through a standardized interface, + making the system modular and testable +- Builder Pattern: Pipeline construction is separated from execution, allowing validation + and optimization before runtime + +Core Features: +- Transparent Caching: Automatic checkpointing and reuse of intermediate results +- Cost Tracking: Built-in tracking of operation costs for optimization +- Schema Validation: Type checking and schema validation at both build and runtime +- Extensible Operations: New operations can be added by implementing the operation interface + +The architecture prioritizes: +1. Separation of Concerns: Building, validation, and execution are distinct phases +2. Flexibility: Support for both streaming and batch processing patterns +3. Observability: Rich logging and cost tracking throughout execution +4. Performance: Lazy evaluation and caching optimize resource usage +""" + +import functools +import hashlib import json import os import shutil import time -import functools +from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple, Union -from docetl.builder import Optimizer -from docetl.console import get_console -from pydantic import BaseModel from dotenv import load_dotenv -import hashlib -from rich.console import Console -from rich.prompt import Confirm +from pydantic import BaseModel +from rich.markup import escape from rich.panel import Panel +from docetl.config_wrapper import ConfigWrapper +from docetl.containers import OpContainer, StepBoundary from docetl.dataset import Dataset, create_parsing_tool_map from docetl.operations import get_operation, get_operations -from docetl.operations.utils import flush_cache -from docetl.config_wrapper import ConfigWrapper +from docetl.operations.base import BaseOperation +from docetl.optimizer import Optimizer + from . import schemas from .utils import classproperty @@ -27,15 +52,25 @@ class DSLRunner(ConfigWrapper): """ - This class is responsible for running DocETL pipelines. It manages datasets, executes pipeline steps, and tracks - the cost of operations. - - Attributes: - config (Dict): The loaded configuration from the YAML file. - default_model (str): The default language model to use for operations. - max_threads (int): Maximum number of threads for parallel processing. - console (Console): Rich console for output formatting. - datasets (Dict): Storage for loaded datasets. + DSLRunner orchestrates pipeline execution by building and traversing a DAG of OpContainers. + The runner uses a two-phase approach: + + 1. Build Phase: + - Parses YAML config into a DAG of OpContainers + - Each operation becomes a node connected to its dependencies + - Special handling for equijoins which have two parent nodes + - Validates operation syntax and schema compatibility + + 2. Execution Phase: + - Starts from the final operation and pulls data through the DAG + - Handles caching/checkpointing of intermediate results + - Tracks costs and execution metrics + - Manages dataset loading and result persistence + + The separation between build and execution phases allows for: + - Pipeline validation before any execution + - Cost estimation and optimization + - Partial pipeline execution for testing """ @classproperty @@ -79,32 +114,164 @@ def __init__(self, config: Dict, max_threads: int = None, **kwargs): max_threads=max_threads, **kwargs, ) - self.datasets = {} + self.total_cost = 0 + self._initialize_state() + self._setup_parsing_tools() + self._build_operation_graph(config) + self._compute_operation_hashes() + + # Run initial validation + self.syntax_check() + def _initialize_state(self) -> None: + """Initialize basic runner state and datasets""" + self.datasets = {} self.intermediate_dir = ( self.config.get("pipeline", {}).get("output", {}).get("intermediate_dir") ) - # Create parsing tool map + def _setup_parsing_tools(self) -> None: + """Set up parsing tools from configuration""" self.parsing_tool_map = create_parsing_tool_map( self.config.get("parsing_tools", None) ) - self.syntax_check() + def _build_operation_graph(self, config: Dict) -> None: + """Build the DAG of operations from configuration""" + self.config = config + self.op_container_map = {} + self.last_op_container = None - op_map = {op["name"]: op for op in self.config["operations"]} + for step in self.config["pipeline"]["steps"]: + self._validate_step(step) + + if step.get("input"): + self._add_scan_operation(step) + else: + self._add_equijoin_operation(step) + + self._add_step_operations(step) + self._add_step_boundary(step) + + def _validate_step(self, step: Dict) -> None: + """Validate step configuration""" + assert "name" in step.keys(), f"Step {step} does not have a name" + assert "operations" in step.keys(), f"Step {step} does not have `operations`" + + def _add_scan_operation(self, step: Dict) -> None: + """Add a scan operation for input datasets""" + scan_op_container = OpContainer( + f"{step['name']}/scan_{step['input']}", + self, + { + "type": "scan", + "dataset_name": step["input"], + "name": f"scan_{step['input']}", + }, + ) + self.op_container_map[f"{step['name']}/scan_{step['input']}"] = ( + scan_op_container + ) + if self.last_op_container: + scan_op_container.add_child(self.last_op_container) + self.last_op_container = scan_op_container + + def _add_equijoin_operation(self, step: Dict) -> None: + """Add an equijoin operation with its scan operations""" + equijoin_operation_name = list(step["operations"][0].keys())[0] + left_dataset_name = list(step["operations"][0].values())[0]["left"] + right_dataset_name = list(step["operations"][0].values())[0]["right"] + + left_scan_op_container = OpContainer( + f"{step['name']}/scan_{left_dataset_name}", + self, + { + "type": "scan", + "dataset_name": left_dataset_name, + "name": f"scan_{left_dataset_name}", + }, + ) + if self.last_op_container: + left_scan_op_container.add_child(self.last_op_container) + right_scan_op_container = OpContainer( + f"{step['name']}/scan_{right_dataset_name}", + self, + { + "type": "scan", + "dataset_name": right_dataset_name, + "name": f"scan_{right_dataset_name}", + }, + ) + if self.last_op_container: + right_scan_op_container.add_child(self.last_op_container) + equijoin_op_container = OpContainer( + f"{step['name']}/{equijoin_operation_name}", + self, + self.find_operation(equijoin_operation_name), + left_name=left_dataset_name, + right_name=right_dataset_name, + ) + + equijoin_op_container.add_child(left_scan_op_container) + equijoin_op_container.add_child(right_scan_op_container) + + self.last_op_container = equijoin_op_container + self.op_container_map[f"{step['name']}/{equijoin_operation_name}"] = ( + equijoin_op_container + ) + self.op_container_map[f"{step['name']}/scan_{left_dataset_name}"] = ( + left_scan_op_container + ) + self.op_container_map[f"{step['name']}/scan_{right_dataset_name}"] = ( + right_scan_op_container + ) - # Hash each pipeline step/operation - # for each step op, hash the code of each op up until and (including that op) + def _add_step_operations(self, step: Dict) -> None: + """Add operations for a step""" + op_start_idx = 1 if not step.get("input") else 0 + + for operation_name in step["operations"][op_start_idx:]: + if not isinstance(operation_name, str): + raise ValueError( + f"Operation {operation_name} in step {step['name']} should be a string. " + "If you intend for it to be an equijoin, don't specify an input in the step." + ) + + op_container = OpContainer( + f"{step['name']}/{operation_name}", + self, + self.find_operation(operation_name), + ) + op_container.add_child(self.last_op_container) + self.last_op_container = op_container + self.op_container_map[f"{step['name']}/{operation_name}"] = op_container + + def _add_step_boundary(self, step: Dict) -> None: + """Add a step boundary node""" + step_boundary = StepBoundary( + f"{step['name']}/boundary", + self, + {"type": "step_boundary", "name": f"{step['name']}/boundary"}, + ) + step_boundary.add_child(self.last_op_container) + self.op_container_map[f"{step['name']}/boundary"] = step_boundary + self.last_op_container = step_boundary + + def _compute_operation_hashes(self) -> None: + """Compute hashes for operations to enable caching""" + op_map = {op["name"]: op for op in self.config["operations"]} self.step_op_hashes = defaultdict(dict) + for step in self.config["pipeline"]["steps"]: for idx, op in enumerate(step["operations"]): op_name = op if isinstance(op, str) else list(op.keys())[0] - all_ops_until_and_including_current = [ - op_map[prev_op] for prev_op in step["operations"][:idx] - ] + [op_map[op_name]] + [self.config.get("system_prompt", {})] - # If there's no model in the op, add the default model + all_ops_until_and_including_current = ( + [op_map[prev_op] for prev_op in step["operations"][:idx]] + + [op_map[op_name]] + + [self.config.get("system_prompt", {})] + ) + for op in all_ops_until_and_including_current: if "model" not in op: op["model"] = self.default_model @@ -135,32 +302,125 @@ def syntax_check(self): """ Perform a syntax check on all operations defined in the configuration. """ - syntax_content = "[yellow]Performing syntax check on all operations...[/yellow]\n\n" - + self.console.log("[yellow]Checking operations...[/yellow]") + # Just validate that it's a json file if specified self.get_output_path() + current = self.last_op_container try: - for operation_config in self.config["operations"]: - operation = operation_config["name"] - operation_type = operation_config["type"] - - operation_class = get_operation(operation_type) - operation_class( - self, - operation_config, - self.default_model, - self.max_threads, - self.console, - ) - syntax_content += f"[green]✓[/green] Operation '{operation}' ({operation_type})\n" + # Walk the last op container to check syntax + op_containers = [] + if self.last_op_container: + op_containers = [self.last_op_container] + + while op_containers: + current = op_containers.pop(0) + syntax_result = current.syntax_check() + self.console.log(syntax_result, end="") + # Add all children to the queue + op_containers.extend(current.children) except Exception as e: raise ValueError( - f"Syntax check failed for operation '{operation}': {str(e)}" + f"Syntax check failed for operation '{current.name}': {str(e)}" + ) + + self.console.log("[green]✓ All operations passed syntax check[/green]") + + def print_query_plan(self, show_boundaries=False): + """ + Print a visual representation of the entire query plan using indentation and arrows. + Operations are color-coded by step to show the pipeline structure while maintaining + dependencies between steps. + """ + if not self.last_op_container: + self.console.log("\n[bold]Pipeline Steps:[/bold]") + self.console.log( + Panel("No operations in pipeline", title="Query Plan", width=100) ) + self.console.log() + return + + def _print_op( + op: OpContainer, indent: int = 0, step_colors: Dict[str, str] = None + ) -> str: + # Handle boundary operations based on show_boundaries flag + if isinstance(op, StepBoundary): + if show_boundaries: + output = [] + indent_str = " " * indent + step_name = op.name.split("/")[0] + color = step_colors.get(step_name, "white") + output.append( + f"{indent_str}[{color}][bold]{op.name}[/bold][/{color}]" + ) + output.append(f"{indent_str}Type: step_boundary") + if op.children: + output.append(f"{indent_str}[yellow]▼[/yellow]") + for child in op.children: + output.append(_print_op(child, indent + 1, step_colors)) + return "\n".join(output) + elif op.children: + return _print_op(op.children[0], indent, step_colors) + return "" + + # Build the string for the current operation with indentation + indent_str = " " * indent + output = [] + + # Color code the operation name based on its step + step_name = op.name.split("/")[0] + color = step_colors.get(step_name, "white") + output.append(f"{indent_str}[{color}][bold]{op.name}[/bold][/{color}]") + output.append(f"{indent_str}Type: {op.config['type']}") + + # Add schema if available + if "output" in op.config and "schema" in op.config["output"]: + output.append(f"{indent_str}Output Schema:") + for field, field_type in op.config["output"]["schema"].items(): + escaped_type = escape(str(field_type)) + output.append( + f"{indent_str} {field}: [bright_white]{escaped_type}[/bright_white]" + ) + + # Add children + if op.children: + if op.is_equijoin: + output.append(f"{indent_str}[yellow]▼ LEFT[/yellow]") + output.append(_print_op(op.children[0], indent + 1, step_colors)) + output.append(f"{indent_str}[yellow]▼ RIGHT[/yellow]") + output.append(_print_op(op.children[1], indent + 1, step_colors)) + else: + output.append(f"{indent_str}[yellow]▼[/yellow]") + for child in op.children: + output.append(_print_op(child, indent + 1, step_colors)) + + return "\n".join(output) + + # Get all step boundaries and extract unique step names + step_boundaries = [ + op + for name, op in self.op_container_map.items() + if isinstance(op, StepBoundary) + ] + step_boundaries.sort(key=lambda x: x.name) + + # Create a color map for steps - using distinct colors + colors = ["cyan", "magenta", "green", "yellow", "blue", "red"] + step_names = [b.name.split("/")[0] for b in step_boundaries] + step_colors = { + name: colors[i % len(colors)] for i, name in enumerate(step_names) + } + + # Print the legend + self.console.log("\n[bold]Pipeline Steps:[/bold]") + for step_name, color in step_colors.items(): + self.console.log(f"[{color}]■[/{color}] {step_name}") - syntax_content += "\n[green]Syntax check passed for all operations.[/green]" - self.console.print(Panel.fit(syntax_content, title="[yellow]Syntax Check[/yellow]")) + # Print the full query plan starting from the last step boundary + query_plan = _print_op(self.last_op_container, step_colors=step_colors) + self.console.log(Panel(query_plan, title="Query Plan", width=100)) + self.console.log() def find_operation(self, op_name: str) -> Dict: for operation_config in self.config["operations"]: @@ -171,93 +431,44 @@ def find_operation(self, op_name: str) -> Dict: def load_run_save(self) -> float: """ Execute the entire pipeline defined in the configuration. - - This method loads datasets, executes each step in the pipeline, saves the output, - and returns the total cost of execution. - - Returns: - float: The total cost of executing the pipeline. """ - - # Fail early if we can't save the output... output_path = self.get_output_path(require=True) - self.console.rule("[bold blue]Pipeline Execution[/bold blue]") + # Print the query plan + self.print_query_plan() + start_time = time.time() - output, total_cost = self.run(self.load()) - self.save(output) + if self.last_op_container: + self.load() + self.console.rule("[bold]Pipeline Execution[/bold]") + output, _, _ = self.last_op_container.next() + self.save(output) execution_time = time.time() - start_time - summary_content = ( - "[bold]Pipeline Execution Complete[/bold]\n\n" - f"[bold green]Total cost:[/bold green] [green]${total_cost:.2f}[/green]\n" - f"[bold green]Total time:[/bold green] [green]{execution_time:.2f} seconds[/green]\n" - f"[bold green]Intermediate directory:[/bold green]\n[green]{self.intermediate_dir}[/green]\n" - f"[bold green]Saved output to:[/bold green]\n[green]{output_path}[/green]" - ) - self.console.print(Panel.fit(summary_content, title="[bold green]Execution Summary[/bold green]")) - - return total_cost - - def run(self, datasets) -> float: - """ - Execute the entire pipeline defined in the configuration on some data. - - Args: - datasets (dict[str, Dataset | List[Dict]]): input datasets to transform - Returns: - (List[Dict], float): The transformed data and the total cost of execution. - """ - self.datasets = { - name: ( - dataset - if isinstance(dataset, Dataset) - else Dataset(self, "memory", dataset) + # Print execution summary + summary = ( + f"Cost: [green]${self.total_cost:.2f}[/green]\n" + f"Time: {execution_time:.2f}s\n" + + ( + f"Cache: [dim]{self.intermediate_dir}[/dim]\n" + if self.intermediate_dir + else "" ) - for name, dataset in datasets.items() - } - total_cost = 0 - for step in self.config["pipeline"]["steps"]: - step_name = step["name"] - input_data = ( - self.datasets[step["input"]].load() if "input" in step else None - ) - output_data, step_cost = self.execute_step(step, input_data) - self.datasets[step_name] = Dataset(self, "memory", output_data) - flush_cache(self.console) - total_cost += step_cost - self.console.log( - f"Step [cyan]{step_name}[/cyan] completed. Cost: [green]${step_cost:.2f}[/green]" - ) - - # Save the self.step_op_hashes to a file if self.intermediate_dir exists - if self.intermediate_dir: - os.makedirs(self.intermediate_dir, exist_ok=True) - with open( - os.path.join(self.intermediate_dir, ".docetl_intermediate_config.json"), - "w", - ) as f: - json.dump(self.step_op_hashes, f) - - return ( - self.datasets[self.config["pipeline"]["steps"][-1]["name"]].load(), - total_cost, + + f"Output: [dim]{output_path}[/dim]" ) + self.console.log(Panel(summary, title="Execution Summary")) - def load(self): + return self.total_cost + + def load(self) -> None: """ Load all datasets defined in the configuration. - - This method creates Dataset objects for each dataset in the configuration. - - Raises: - ValueError: If an unsupported dataset type is encountered. """ - dataset_content = "" datasets = {} - + self.console.rule("[bold]Loading Datasets[/bold]") + for name, dataset_config in self.config["datasets"].items(): if dataset_config["type"] == "file": datasets[name] = Dataset( @@ -268,26 +479,28 @@ def load(self): parsing=dataset_config.get("parsing", []), user_defined_parsing_tool_map=self.parsing_tool_map, ) - dataset_content += f"[green]✓[/green] Loaded dataset: [bold]{name}[/bold]\n" + self.console.log( + f"[green]✓[/green] Loaded dataset '{name}' from {dataset_config['path']}" + ) else: raise ValueError(f"Unsupported dataset type: {dataset_config['type']}") - - self.console.print(Panel.fit(dataset_content, title="[cyan]Loading Datasets[/cyan]")) - return datasets - def save(self, data: List[Dict]): + self.datasets = { + name: ( + dataset + if isinstance(dataset, Dataset) + else Dataset(self, "memory", dataset) + ) + for name, dataset in datasets.items() + } + self.console.log() + + def save(self, data: List[Dict]) -> None: """ Save the final output of the pipeline. - - Args: - data (List[Dict]): The data to be saved. - - Raises: - ValueError: If an unsupported output type is specified in the configuration. """ self.get_output_path(require=True) - self.console.rule("[cyan]Saving Output[/cyan]") output_config = self.config["pipeline"]["output"] if output_config["type"] == "file": # Create the directory if it doesn't exist @@ -301,115 +514,18 @@ def save(self, data: List[Dict]): with open(output_config["path"], "w", newline="") as file: writer = csv.DictWriter(file, fieldnames=data[0].keys()) - limited_data = [{k: d.get(k, None) for k in data[0].keys()} for d in data] + limited_data = [ + {k: d.get(k, None) for k in data[0].keys()} for d in data + ] writer.writeheader() writer.writerows(limited_data) - self.console.print( - f"[green italic]💾 Output saved to {output_config['path']}[/green italic]" + self.console.log( + f"[green]✓[/green] Saved to [dim]{output_config['path']}[/dim]\n" ) else: - raise ValueError(f"Unsupported output type: {output_config['type']}. Supported types: file") - - def execute_step( - self, step: Dict, input_data: Optional[List[Dict]] - ) -> Tuple[List[Dict], float]: - """ - Execute a single step in the pipeline. - """ - step_content = f"[bold blue]Step: {step['name']}[/bold blue]\n\n" - total_cost = 0 - - for operation in step["operations"]: - if isinstance(operation, dict): - operation_name = list(operation.keys())[0] - operation_config = self.find_operation(operation_name) - else: - operation_name = operation - operation_config = {} - - # Load from checkpoint if exists - attempted_input_data = self._load_from_checkpoint_if_exists( - step["name"], operation_name + raise ValueError( + f"Unsupported output type: {output_config['type']}. Supported types: file" ) - if attempted_input_data is not None: - input_data = attempted_input_data - step_content += f"[green]✓[/green] [italic]Loaded saved data for operation '{operation_name}'[/italic]\n" - continue - - # Delete existing intermediate file before running operation - if self.intermediate_dir: - checkpoint_path = os.path.join( - self.intermediate_dir, step["name"], f"{operation_name}.json" - ) - if os.path.exists(checkpoint_path): - os.remove(checkpoint_path) - - op_object = self.find_operation(operation_name).copy() - op_object.update(operation_config) - - # If sample is set, sample the input data - if op_object.get("sample"): - if input_data is None: - input_data = self.datasets[step["input"]].sample(op_object["sample"], False) - else: - input_data = input_data[: op_object["sample"]] - - with self.console.status("[bold]Running Operation:[/bold]") as status: - status.update(f"Type: [cyan]{op_object['type']}[/cyan]") - status.update(f"Name: [cyan]{op_object.get('name', 'Unnamed')}[/cyan]") - self.status = status - - operation_class = get_operation(op_object["type"]) - operation_instance = operation_class( - self, - op_object, - self.default_model, - self.max_threads, - self.console, - self.status, - ) - if op_object["type"] == "equijoin": - left_data = self.datasets[next(iter(operation.values()))["left"]].load() - right_data = self.datasets[next(iter(operation.values()))["right"]].load() - input_data, cost = operation_instance.execute(left_data, right_data) - else: - input_data, cost = operation_instance.execute(input_data) - total_cost += cost - step_content += f"[green]✓[/green] Operation [cyan]{operation_name}[/cyan] (Cost: [green]${cost:.2f}[/green])\n" - self.console.print(f"[green]✓[/green] Operation [cyan]{operation_name}[/cyan] completed (Cost: [green]${cost:.2f}[/green])") - - # Checkpoint after each operation - if self.intermediate_dir: - self._save_checkpoint(step["name"], operation_name, input_data) - - # Load existing step op hash, if exists, merge self.step_op_hashes[step["name"]][operation_name] into it - # Save the step op hash - intermediate_config_path = os.path.join( - self.intermediate_dir, ".docetl_intermediate_config.json" - ) - if os.path.exists(intermediate_config_path): - with open(intermediate_config_path, "r") as f: - existing_config = json.load(f) - else: - existing_config = {} - - if step["name"] not in existing_config: - existing_config[step["name"]] = {} - existing_config[step["name"]][operation_name] = self.step_op_hashes[ - step["name"] - ][operation_name] - - # Resave - with open(intermediate_config_path, "w") as f: - json.dump(existing_config, f, indent=2) - - step_content += f"[green]✓[/green] [italic]Saved checkpoint for operation '{operation_name}'[/italic]\n" - - self.console.print(Panel.fit( - step_content, - title=f"[bold blue]Step Execution: {step['name']}[/bold blue]" - )) - return input_data, total_cost def _load_from_checkpoint_if_exists( self, step_name: str, operation_name: str @@ -424,6 +540,13 @@ def _load_from_checkpoint_if_exists( if not os.path.exists(intermediate_config_path): return None + # Make sure the step and op name is in the checkpoint config path + if ( + step_name not in self.step_op_hashes + or operation_name not in self.step_op_hashes[step_name] + ): + return None + # See if the checkpoint config is the same as the current step op hash with open(intermediate_config_path, "r") as f: intermediate_config = json.load(f) @@ -444,7 +567,7 @@ def _load_from_checkpoint_if_exists( self, "file", checkpoint_path, "local" ) - self.console.print( + self.console.log( f"[green]✓[/green] [italic]Loaded checkpoint for operation '{operation_name}' in step '{step_name}' from {checkpoint_path}[/italic]" ) @@ -462,7 +585,9 @@ def clear_intermediate(self) -> None: raise ValueError("Intermediate directory not set. Cannot clear intermediate.") - def _save_checkpoint(self, step_name: str, operation_name: str, data: List[Dict]): + def _save_checkpoint( + self, step_name: str, operation_name: str, data: List[Dict] + ) -> None: """ Save a checkpoint of the current data after an operation. @@ -487,13 +612,18 @@ def _save_checkpoint(self, step_name: str, operation_name: str, data: List[Dict] with open(checkpoint_path, "w") as f: json.dump(data, f) - self.console.print( + self.console.log( f"[green]✓ [italic]Intermediate saved for operation '{operation_name}' in step '{step_name}' at {checkpoint_path}[/italic][/green]" ) - def should_optimize(self, step_name: str, op_name: str, **kwargs) -> Tuple[str, float, List[Dict[str, Any]], List[Dict[str, Any]]]: + def should_optimize( + self, step_name: str, op_name: str, **kwargs + ) -> Tuple[str, float, List[Dict[str, Any]], List[Dict[str, Any]]]: + self.load() builder = Optimizer(self, **kwargs) - return builder.should_optimize(step_name, op_name) + self.optimizer = builder + result = builder.should_optimize(step_name, op_name) + return result def optimize( self, @@ -502,28 +632,77 @@ def optimize( **kwargs, ) -> Tuple[Union[Dict, "DSLRunner"], float]: + if not self.last_op_container: + raise ValueError("No operations in pipeline. Cannot optimize.") + + self.load() + builder = Optimizer( self, **kwargs, ) - cost = builder.optimize() - - # Dump via json - # import json - # with open(f"{self.base_name}_optimizer_output.json", "wb") as f: - # json.dump(builder.captured_output.optimizer_output, f) - + self.optimizer = builder + llm_api_cost = builder.optimize() + self.total_cost += llm_api_cost if save: builder.save_optimized_config(f"{self.base_name}_opt.yaml") self.optimized_config_path = f"{self.base_name}_opt.yaml" if return_pipeline: - return DSLRunner(builder.clean_optimized_config(), self.max_threads), cost + return ( + DSLRunner(builder.clean_optimized_config(), self.max_threads), + self.total_cost, + ) - return builder.clean_optimized_config(), cost + return builder.clean_optimized_config(), self.total_cost + def _run_operation( + self, + op_config: Dict[str, Any], + input_data: Union[List[Dict[str, Any]], Dict[str, Any]], + return_instance: bool = False, + is_build: bool = False, + ) -> Union[List[Dict[str, Any]], Tuple[List[Dict[str, Any]], BaseOperation, float]]: + """ + Run a single operation based on its configuration. -if __name__ == "__main__": - runner = DSLRunner("workloads/medical/map_opt.yaml") - runner.run() + This method creates an instance of the appropriate operation class and executes it. + It also updates the total operation cost. + + Args: + op_config (Dict[str, Any]): The configuration of the operation to run. + input_data (List[Dict[str, Any]]): The input data for the operation. + return_instance (bool, optional): If True, return the operation instance along with the output data. + + Returns: + Union[List[Dict[str, Any]], Tuple[List[Dict[str, Any]], BaseOperation, float]]: + If return_instance is False, returns the output data. + If return_instance is True, returns a tuple of the output data, the operation instance, and the cost. + """ + operation_class = get_operation(op_config["type"]) + + oc_kwargs = { + "runner": self, + "config": op_config, + "default_model": self.config["default_model"], + "max_threads": self.max_threads, + "console": self.console, + "status": self.status, + } + operation_instance = operation_class(**oc_kwargs) + if op_config["type"] == "equijoin": + output_data, cost = operation_instance.execute( + input_data["left_data"], input_data["right_data"] + ) + elif op_config["type"] == "filter": + output_data, cost = operation_instance.execute(input_data, is_build) + else: + output_data, cost = operation_instance.execute(input_data) + + self.total_cost += cost + + if return_instance: + return output_data, operation_instance + else: + return output_data diff --git a/docetl/schemas.py b/docetl/schemas.py index b54f3438..a25c49f4 100644 --- a/docetl/schemas.py +++ b/docetl/schemas.py @@ -1,18 +1,22 @@ -from .base_schemas import * - -from .operations import cluster -from .operations import equijoin -from .operations import filter -from .operations import gather -from .operations import map -from .operations import reduce -from .operations import resolve -from .operations import sample -from .operations import split -from .operations import unnest +from typing import Union from . import dataset +# ruff: noqa: F403 +from .base_schemas import * +from .operations import ( + cluster, + equijoin, + filter, + gather, + map, + reduce, + resolve, + sample, + split, + unnest, +) + MapOp = map.MapOperation.schema ResolveOp = resolve.ResolveOperation.schema ReduceOp = reduce.ReduceOperation.schema diff --git a/docetl/utils.py b/docetl/utils.py index 740eb886..9e2872ba 100644 --- a/docetl/utils.py +++ b/docetl/utils.py @@ -1,43 +1,47 @@ import json +import math import re -from typing import Any, Dict, List from enum import Enum +from typing import Any, Dict, List + import tiktoken import yaml from jinja2 import Environment, meta from litellm import completion_cost as lcc - from lzstring import LZString + class Decryptor: def __init__(self, secret_key: str): self.key = secret_key self.lz = LZString() - + def decrypt(self, encrypted_data: str) -> str: try: # First decompress the data compressed = self.lz.decompressFromBase64(encrypted_data) if not compressed: raise ValueError("Invalid compressed data") - + # Then decode using the key - result = '' + result = "" for i in range(len(compressed)): char_code = ord(compressed[i]) - ord(self.key[i % len(self.key)]) result += chr(char_code) - + return result - + except Exception as e: print(f"Decryption failed: {str(e)}") return None + def decrypt(encrypted_data: str, secret_key: str) -> str: if not secret_key: return encrypted_data return Decryptor(secret_key).decrypt(encrypted_data) + class StageType(Enum): SAMPLE_RUN = "sample_run" SHOULD_OPTIMIZE = "should_optimize" @@ -45,6 +49,7 @@ class StageType(Enum): EVALUATION_RESULTS = "evaluation_results" END = "end" + def get_stage_description(stage_type: StageType) -> str: if stage_type == StageType.SAMPLE_RUN: return "Running samples..." @@ -58,14 +63,15 @@ def get_stage_description(stage_type: StageType) -> str: return "Optimization complete!" raise ValueError(f"Unknown stage type: {stage_type}") + class CapturedOutput: def __init__(self): self.optimizer_output = {} self.step = None - + def set_step(self, step: str): self.step = step - + def save_optimizer_output(self, stage_type: StageType, output: Any): if self.step is None: raise ValueError("Step must be set before saving optimizer output") @@ -76,6 +82,7 @@ def save_optimizer_output(self, stage_type: StageType, output: Any): self.optimizer_output[self.step][stage_type] = output + def extract_jinja_variables(template_string: str) -> List[str]: """ Extract variables from a Jinja2 template string. @@ -222,9 +229,70 @@ def truncate_sample_data( return truncated_data +def smart_sample( + input_data: List[Dict], sample_size_needed: int, max_unique_values: int = 5 +) -> List[Dict]: + """ + Smart sampling strategy that: + 1. Identifies categorical fields by checking for low cardinality (few unique values) + 2. Stratifies on up to 3 categorical fields + 3. Takes largest documents per stratum + + Args: + input_data (List[Dict]): List of input documents + sample_size_needed (int): Number of samples needed + max_unique_values (int): Maximum number of unique values for a field to be considered categorical + + Returns: + List[Dict]: Sampled documents + """ + if not input_data or sample_size_needed >= len(input_data): + return input_data + + # Find fields with low cardinality (categorical fields) + field_unique_values = {} + for field in input_data[0].keys(): + unique_values = set(str(doc.get(field, "")) for doc in input_data) + if len(unique_values) <= max_unique_values: + field_unique_values[field] = len(unique_values) + + # Sort by number of unique values and take top 3 categorical fields + categorical_fields = sorted(field_unique_values.items(), key=lambda x: x[1])[:3] + categorical_fields = [field for field, _ in categorical_fields] + + # If no categorical fields, return largest documents + if not categorical_fields: + return sorted(input_data, key=lambda x: len(json.dumps(x)), reverse=True)[ + :sample_size_needed + ] + + # Group data by categorical fields + groups = {} + for doc in input_data: + key = tuple(str(doc.get(field, "")) for field in categorical_fields) + if key not in groups: + groups[key] = [] + groups[key].append(doc) + + # Calculate samples needed per group (evenly distributed) + samples_per_group = math.ceil(sample_size_needed / len(groups)) + + # Take largest documents from each group + result = [] + for docs in groups.values(): + sorted_docs = sorted(docs, key=lambda x: len(json.dumps(x)), reverse=True) + result.extend(sorted_docs[:samples_per_group]) + + # If we have too many samples, trim to exact size needed + # Sort by size again to ensure we keep the largest documents + return sorted(result, key=lambda x: len(json.dumps(x)), reverse=True)[ + :sample_size_needed + ] + class classproperty(object): def __init__(self, f): self.f = f + def __get__(self, obj, owner): return self.f(owner) diff --git a/server/app/routes/pipeline.py b/server/app/routes/pipeline.py index 6ddc91cd..e09b6ef1 100644 --- a/server/app/routes/pipeline.py +++ b/server/app/routes/pipeline.py @@ -212,7 +212,7 @@ async def run_pipeline(): ) if user_message == "kill": - runner.console.print("Stopping process...") + runner.console.log("Stopping process...") await websocket.send_json({ "type": "error", "message": "Process stopped by user request" diff --git a/tailwind.config.js b/tailwind.config.js deleted file mode 100644 index 2c759cc0..00000000 --- a/tailwind.config.js +++ /dev/null @@ -1,16 +0,0 @@ -{ - theme: { - extend: { - keyframes: { - shake: { - '0%, 100%': { transform: 'translateX(0)' }, - '25%': { transform: 'translateX(-4px)' }, - '75%': { transform: 'translateX(4px)' }, - }, - }, - animation: { - shake: 'shake 0.2s ease-in-out 0s 2', - }, - }, - }, -} \ No newline at end of file diff --git a/tests/basic/test_optimizer.py b/tests/basic/test_optimizer.py index 1cae2101..b821535b 100644 --- a/tests/basic/test_optimizer.py +++ b/tests/basic/test_optimizer.py @@ -2,7 +2,6 @@ import pytest import json import shutil -from docetl.builder import Optimizer from docetl.runner import DSLRunner @pytest.fixture @@ -71,22 +70,13 @@ def runner(test_config): def test_optimize_map_operation(runner, test_dir): """Test that the optimizer can optimize a simple map operation""" - # Initialize optimizer - optimizer = Optimizer( - runner=runner, - model="gpt-4o-mini", - timeout=30 - ) # Run optimization - total_cost = optimizer.optimize() + optimized_config, total_cost = runner.optimize(return_pipeline=False) # Check that optimization completed successfully assert total_cost >= 0 # Cost should be non-negative - # Get the optimized config - optimized_config = optimizer.clean_optimized_config() - # Check that the optimized config contains operations assert "operations" in optimized_config assert len(optimized_config["operations"]) > 0 @@ -101,10 +91,4 @@ def test_optimize_map_operation(runner, test_dir): assert first_step["name"] == "name_extraction" assert "operations" in first_step assert len(first_step["operations"]) > 0 - - # Save the optimized config - output_path = test_dir / "optimized_config.yaml" - optimizer.save_optimized_config(str(output_path)) - - # Check that the file was created - assert output_path.exists() + diff --git a/tests/test_api.py b/tests/test_api.py index 0e2743b3..da4b7cd9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -232,7 +232,7 @@ def test_pipeline_optimization( ) assert isinstance(optimized_pipeline, Pipeline) - assert len(optimized_pipeline.operations) == len(pipeline.operations) + assert len(optimized_pipeline.operations) == len(pipeline.operations) + 1 assert len(optimized_pipeline.steps) == len(pipeline.steps) diff --git a/tests/test_synth_gather.py b/tests/test_synth_gather.py index c832c1c3..b1ae88ed 100644 --- a/tests/test_synth_gather.py +++ b/tests/test_synth_gather.py @@ -2,7 +2,7 @@ import json import tempfile import os -from docetl.builder import Optimizer + from docetl.runner import DSLRunner from docetl.operations.split import SplitOperation from docetl.operations.map import MapOperation diff --git a/tests/test_synth_resolve.py b/tests/test_synth_resolve.py index a050da40..ecc770b2 100644 --- a/tests/test_synth_resolve.py +++ b/tests/test_synth_resolve.py @@ -3,7 +3,6 @@ import json import tempfile import os -from docetl.builder import Optimizer @pytest.fixture diff --git a/website/package-lock.json b/website/package-lock.json index bc61c2ef..0c5dee39 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -92,6 +92,7 @@ "eslint-plugin-unused-imports": "^4.1.4", "globals": "^15.11.0", "postcss": "^8", + "tailwind-scrollbar": "^3.1.0", "tailwindcss": "^3.4.1", "typescript": "^5", "typescript-eslint": "^8.11.0" @@ -12406,6 +12407,19 @@ "url": "https://github.com/sponsors/dcastil" } }, + "node_modules/tailwind-scrollbar": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/tailwind-scrollbar/-/tailwind-scrollbar-3.1.0.tgz", + "integrity": "sha512-pmrtDIZeHyu2idTejfV59SbaJyvp1VRjYxAjZBH0jnyrPRo6HL1kD5Glz8VPagasqr6oAx6M05+Tuw429Z8jxg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12.13.0" + }, + "peerDependencies": { + "tailwindcss": "3.x" + } + }, "node_modules/tailwindcss": { "version": "3.4.11", "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.11.tgz", diff --git a/website/package.json b/website/package.json index 6c371680..be0ef37e 100644 --- a/website/package.json +++ b/website/package.json @@ -93,6 +93,7 @@ "eslint-plugin-unused-imports": "^4.1.4", "globals": "^15.11.0", "postcss": "^8", + "tailwind-scrollbar": "^3.1.0", "tailwindcss": "^3.4.1", "typescript": "^5", "typescript-eslint": "^8.11.0" diff --git a/website/src/app/playground/page.tsx b/website/src/app/playground/page.tsx index a824f37c..04e5fcae 100644 --- a/website/src/app/playground/page.tsx +++ b/website/src/app/playground/page.tsx @@ -723,7 +723,7 @@ const CodeEditorPipelineApp: React.FC = () => { onDragEnd={() => (document.body.style.cursor = "default")} > {showFileExplorer && ( - + { onDragEnd={() => (document.body.style.cursor = "default")} > @@ -756,7 +756,7 @@ const CodeEditorPipelineApp: React.FC = () => { @@ -769,7 +769,11 @@ const CodeEditorPipelineApp: React.FC = () => { )} - + = ({ isWebSocketClosed ? "opacity-50" : "" }`} > -
+
       
diff --git a/website/tailwind.config.ts b/website/tailwind.config.ts index 312113ea..53c28214 100644 --- a/website/tailwind.config.ts +++ b/website/tailwind.config.ts @@ -1,85 +1,87 @@ import type { Config } from "tailwindcss"; +import animate from "tailwindcss-animate"; +import scrollbar from "tailwind-scrollbar"; const config: Config = { - darkMode: ["class"], - content: [ + darkMode: ["class"], + content: [ "./src/pages/**/*.{js,ts,jsx,tsx,mdx}", "./src/components/**/*.{js,ts,jsx,tsx,mdx}", "./src/app/**/*.{js,ts,jsx,tsx,mdx}", ], theme: { - extend: { - colors: { - background: 'hsl(var(--background))', - foreground: 'hsl(var(--foreground))', - card: { - DEFAULT: 'hsl(var(--card))', - foreground: 'hsl(var(--card-foreground))' - }, - popover: { - DEFAULT: 'hsl(var(--popover))', - foreground: 'hsl(var(--popover-foreground))' - }, - primary: { - DEFAULT: 'hsl(var(--primary))', - foreground: 'hsl(var(--primary-foreground))' - }, - secondary: { - DEFAULT: 'hsl(var(--secondary))', - foreground: 'hsl(var(--secondary-foreground))' - }, - muted: { - DEFAULT: 'hsl(var(--muted))', - foreground: 'hsl(var(--muted-foreground))' - }, - accent: { - DEFAULT: 'hsl(var(--accent))', - foreground: 'hsl(var(--accent-foreground))' - }, - destructive: { - DEFAULT: 'hsl(var(--destructive))', - foreground: 'hsl(var(--destructive-foreground))' - }, - border: 'hsl(var(--border))', - input: 'hsl(var(--input))', - ring: 'hsl(var(--ring))', - chart: { - '1': 'hsl(var(--chart-1))', - '2': 'hsl(var(--chart-2))', - '3': 'hsl(var(--chart-3))', - '4': 'hsl(var(--chart-4))', - '5': 'hsl(var(--chart-5))' - } - }, - borderRadius: { - lg: 'var(--radius)', - md: 'calc(var(--radius) - 2px)', - sm: 'calc(var(--radius) - 4px)' - }, - keyframes: { - 'accordion-down': { - from: { - height: '0' - }, - to: { - height: 'var(--radix-accordion-content-height)' - } - }, - 'accordion-up': { - from: { - height: 'var(--radix-accordion-content-height)' - }, - to: { - height: '0' - } - } - }, - animation: { - 'accordion-down': 'accordion-down 0.2s ease-out', - 'accordion-up': 'accordion-up 0.2s ease-out' - } - } + extend: { + colors: { + background: "hsl(var(--background))", + foreground: "hsl(var(--foreground))", + card: { + DEFAULT: "hsl(var(--card))", + foreground: "hsl(var(--card-foreground))", + }, + popover: { + DEFAULT: "hsl(var(--popover))", + foreground: "hsl(var(--popover-foreground))", + }, + primary: { + DEFAULT: "hsl(var(--primary))", + foreground: "hsl(var(--primary-foreground))", + }, + secondary: { + DEFAULT: "hsl(var(--secondary))", + foreground: "hsl(var(--secondary-foreground))", + }, + muted: { + DEFAULT: "hsl(var(--muted))", + foreground: "hsl(var(--muted-foreground))", + }, + accent: { + DEFAULT: "hsl(var(--accent))", + foreground: "hsl(var(--accent-foreground))", + }, + destructive: { + DEFAULT: "hsl(var(--destructive))", + foreground: "hsl(var(--destructive-foreground))", + }, + border: "hsl(var(--border))", + input: "hsl(var(--input))", + ring: "hsl(var(--ring))", + chart: { + "1": "hsl(var(--chart-1))", + "2": "hsl(var(--chart-2))", + "3": "hsl(var(--chart-3))", + "4": "hsl(var(--chart-4))", + "5": "hsl(var(--chart-5))", + }, + }, + borderRadius: { + lg: "var(--radius)", + md: "calc(var(--radius) - 2px)", + sm: "calc(var(--radius) - 4px)", + }, + keyframes: { + "accordion-down": { + from: { + height: "0", + }, + to: { + height: "var(--radix-accordion-content-height)", + }, + }, + "accordion-up": { + from: { + height: "var(--radix-accordion-content-height)", + }, + to: { + height: "0", + }, + }, + }, + animation: { + "accordion-down": "accordion-down 0.2s ease-out", + "accordion-up": "accordion-up 0.2s ease-out", + }, + }, }, - plugins: [require("tailwindcss-animate")], + plugins: [animate, scrollbar], }; export default config;