From 9054ce3d12d36e3ab60d66e76a1e060fb975ad7d Mon Sep 17 00:00:00 2001
From: Shreya Shankar
Date: Mon, 30 Sep 2024 12:31:18 -0700
Subject: [PATCH 1/2] feat: show better progress bars for operations

---
 docetl/operations/equijoin.py |  6 +++
 docetl/operations/filter.py   |  8 +++-
 docetl/operations/map.py      | 16 +++++++-
 docetl/operations/reduce.py   |  8 +++-
 docetl/operations/resolve.py  | 10 ++---
 docetl/runner.py              | 73 ++++++++++++++++-------------------
 6 files changed, 72 insertions(+), 49 deletions(-)

diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py
index f61347f0..51128ecb 100644
--- a/docetl/operations/equijoin.py
+++ b/docetl/operations/equijoin.py
@@ -201,6 +201,9 @@ def get_hashable_key(item: Dict) -> str:
         if len(left_data) == 0 or len(right_data) == 0:
             return [], 0
 
+        if self.status:
+            self.status.stop()
+
         # Initial blocking using multiprocessing
         num_processes = min(cpu_count(), len(left_data))
 
@@ -441,4 +444,7 @@ def get_embeddings(
         )
         self.console.log(f"Equijoin selectivity: {join_selectivity:.4f}")
 
+        if self.status:
+            self.status.start()
+
         return results, total_cost
diff --git a/docetl/operations/filter.py b/docetl/operations/filter.py
index aee1df27..ffd81fb9 100644
--- a/docetl/operations/filter.py
+++ b/docetl/operations/filter.py
@@ -114,6 +114,9 @@ def execute(
                 )
             )
 
+        if self.status:
+            self.status.stop()
+
         def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]:
             prompt_template = Template(self.config["prompt"])
             prompt = prompt_template.render(input=item)
@@ -159,7 +162,7 @@ def validation_fn(response: Dict[str, Any]):
         total_cost = 0
         pbar = RichLoopBar(
             range(len(futures)),
-            desc="Processing filter items",
+            desc=f"Processing {self.config['name']} (filter) on all documents",
             console=self.console,
         )
         for i in pbar:
@@ -174,4 +177,7 @@ def validation_fn(response: Dict[str, Any]):
                 results.append(result)
             pbar.update(1)
 
+        if self.status:
+            self.status.start()
+
         return results, total_cost
diff --git a/docetl/operations/map.py b/docetl/operations/map.py
index 15953b20..c0454839 100644
--- a/docetl/operations/map.py
+++ b/docetl/operations/map.py
@@ -128,6 +128,9 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
                 dropped_results.append(new_item)
             return dropped_results, 0.0  # Return the modified data with no cost
 
+        if self.status:
+            self.status.stop()
+
         def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]:
             prompt_template = Template(self.config["prompt"])
             prompt = prompt_template.render(input=item)
@@ -196,7 +199,7 @@ def validation_fn(response: Dict[str, Any]):
         total_cost = 0
         pbar = RichLoopBar(
             range(len(futures)),
-            desc="Processing map items",
+            desc=f"Processing {self.config['name']} (map) on all documents",
             console=self.console,
         )
         for i in pbar:
@@ -212,6 +215,9 @@ def validation_fn(response: Dict[str, Any]):
                 total_cost += item_cost
             pbar.update(i)
 
+        if self.status:
+            self.status.start()
+
         return results, total_cost
 
     def validate_output(self, output: Dict) -> bool:
@@ -349,6 +355,9 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
                 dropped_results.append(new_item)
             return dropped_results, 0.0  # Return the modified data with no cost
 
+        if self.status:
+            self.status.stop()
+
         def process_prompt(item, prompt_config):
             prompt_template = Template(prompt_config["prompt"])
             prompt = prompt_template.render(input=item)
@@ -384,7 +393,7 @@ def process_prompt(item, prompt_config):
         # Process results in order
         pbar = RichLoopBar(
             range(len(all_futures)),
-            desc="Processing parallel map items",
desc=f"Processing {self.config['name']} (parallel map) on all documents", console=self.console, ) for i in pbar: @@ -418,5 +427,8 @@ def process_prompt(item, prompt_config): for key in drop_keys: item.pop(key, None) + if self.status: + self.status.start() + # Return the results in order return [results[i] for i in range(len(input_data)) if i in results], total_cost diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index df9d8b16..bc2da2d6 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -252,6 +252,9 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: reduce_keys = [reduce_keys] input_schema = self.config.get("input", {}).get("schema", {}) + if self.status: + self.status.stop() + # Check if we need to group everything into one group if reduce_keys == ["_all"] or reduce_keys == "_all": grouped_data = [("_all", input_data)] @@ -341,7 +344,7 @@ def process_group( for future in rich_as_completed( futures, total=len(futures), - desc="Processing reduce items", + desc=f"Processing {self.config['name']} (reduce) on all documents", leave=True, console=self.console, ): @@ -358,6 +361,9 @@ def process_group( self.intermediates[key] ) + if self.status: + self.status.start() + return results, total_cost def _get_embeddings( diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index 5d0666af..f1a766a6 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -199,11 +199,11 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: blocking_keys = self.config.get("blocking_keys", []) blocking_threshold = self.config.get("blocking_threshold") blocking_conditions = self.config.get("blocking_conditions", []) + if self.status: + self.status.stop() if not blocking_threshold and not blocking_conditions: # Prompt the user for confirmation - if self.status: - self.status.stop() if not Confirm.ask( f"[yellow]Warning: No blocking keys or conditions specified. " f"This may result in a large number of comparisons. " @@ -212,9 +212,6 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: ): raise ValueError("Operation cancelled by user.") - if self.status: - self.status.start() - input_schema = self.config.get("input", {}).get("schema", {}) if not blocking_keys: # Set them to all keys in the input data @@ -467,4 +464,7 @@ def process_cluster(cluster): ) self.console.log(f"Self-join selectivity: {true_match_selectivity:.4f}") + if self.status: + self.status.start() + return results, total_cost diff --git a/docetl/runner.py b/docetl/runner.py index dc783f1c..d022e395 100644 --- a/docetl/runner.py +++ b/docetl/runner.py @@ -122,21 +122,16 @@ def run(self) -> float: start_time = time.time() self.load_datasets() total_cost = 0 - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=self.console, - ) as progress: - for step in self.config["pipeline"]["steps"]: - step_name = step["name"] - input_data = self.datasets[step["input"]] if "input" in step else None - output_data, step_cost = self.execute_step(step, input_data, progress) - self.datasets[step_name] = output_data - flush_cache(self.console) - total_cost += step_cost - self.console.log( - f"Step [cyan]{step_name}[/cyan] completed. 
Cost: [green]${step_cost:.2f}[/green]" - ) + for step in self.config["pipeline"]["steps"]: + step_name = step["name"] + input_data = self.datasets[step["input"]] if "input" in step else None + output_data, step_cost = self.execute_step(step, input_data) + self.datasets[step_name] = output_data + flush_cache(self.console) + total_cost += step_cost + self.console.log( + f"Step [cyan]{step_name}[/cyan] completed. Cost: [green]${step_cost:.2f}[/green]" + ) self.save_output(self.datasets[self.config["pipeline"]["steps"][-1]["name"]]) self.console.rule("[bold green]Execution Summary[/bold green]") @@ -188,7 +183,7 @@ def save_output(self, data: List[Dict]): raise ValueError(f"Unsupported output type: {output_config['type']}") def execute_step( - self, step: Dict, input_data: Optional[List[Dict]], progress: Progress + self, step: Dict, input_data: Optional[List[Dict]] ) -> Tuple[List[Dict], float]: """ Execute a single step in the pipeline. @@ -199,7 +194,6 @@ def execute_step( Args: step (Dict): The step configuration. input_data (Optional[List[Dict]]): Input data for the step. - progress (Progress): Progress tracker for rich output. Returns: Tuple[List[Dict], float]: A tuple containing the output data and the total cost of the step. @@ -221,30 +215,29 @@ def execute_step( if op_object.get("sample"): input_data = input_data[: op_object["sample"]] - self.console.print("[bold]Running Operation:[/bold]") - self.console.print(f" Type: [cyan]{op_object['type']}[/cyan]") - self.console.print( - f" Name: [cyan]{op_object.get('name', 'Unnamed')}[/cyan]" - ) + with self.console.status("[bold]Running Operation:[/bold]") as status: + status.update(f"Type: [cyan]{op_object['type']}[/cyan]") + status.update(f"Name: [cyan]{op_object.get('name', 'Unnamed')}[/cyan]") + self.status = status - operation_class = get_operation(op_object["type"]) - operation_instance = operation_class( - op_object, - self.default_model, - self.max_threads, - self.console, - self.status, - ) - if op_object["type"] == "equijoin": - left_data = self.datasets[op_object["left"]] - right_data = self.datasets[op_object["right"]] - input_data, cost = operation_instance.execute(left_data, right_data) - else: - input_data, cost = operation_instance.execute(input_data) - total_cost += cost - self.console.log( - f"\tOperation [cyan]{operation_name}[/cyan] completed. Cost: [green]${cost:.2f}[/green]" - ) + operation_class = get_operation(op_object["type"]) + operation_instance = operation_class( + op_object, + self.default_model, + self.max_threads, + self.console, + self.status, + ) + if op_object["type"] == "equijoin": + left_data = self.datasets[op_object["left"]] + right_data = self.datasets[op_object["right"]] + input_data, cost = operation_instance.execute(left_data, right_data) + else: + input_data, cost = operation_instance.execute(input_data) + total_cost += cost + self.console.log( + f"\tOperation [cyan]{operation_name}[/cyan] completed. 
Cost: [green]${cost:.2f}[/green]" + ) # Checkpoint after each operation if self.intermediate_dir: From a9918ab112ccbc411273a859d2b4d499d27d57ec Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Mon, 30 Sep 2024 14:47:11 -0700 Subject: [PATCH 2/2] feat: show better progress bars for llm operations and upgrade litellm --- docetl/operations/base.py | 1 + docetl/operations/equijoin.py | 2 +- docetl/operations/filter.py | 6 +- docetl/operations/map.py | 10 ++- docetl/operations/reduce.py | 14 ++++- docetl/operations/resolve.py | 11 +++- docetl/operations/utils.py | 114 ++++++++++++++++++++++++++++++++-- poetry.lock | 91 ++++++++++++++++++++++++--- tests/test_map.py | 6 +- tests/test_ollama.py | 6 +- 10 files changed, 230 insertions(+), 31 deletions(-) diff --git a/docetl/operations/base.py b/docetl/operations/base.py index 2b5d9e0e..8b67f6e2 100644 --- a/docetl/operations/base.py +++ b/docetl/operations/base.py @@ -33,6 +33,7 @@ def __init__( self.default_model = default_model self.max_threads = max_threads self.console = console or Console() + self.manually_fix_errors = self.config.get("manually_fix_errors", False) self.status = status self.num_retries_on_validate_failure = self.config.get( "num_retries_on_validate_failure", 0 diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index 51128ecb..ad573208 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -86,7 +86,7 @@ def compare_pair( timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, ) - output = parse_llm_response(response)[0] + output = parse_llm_response(response, {"is_match": "bool"})[0] return output["is_match"], completion_cost(response) diff --git a/docetl/operations/filter.py b/docetl/operations/filter.py index ffd81fb9..35bd06af 100644 --- a/docetl/operations/filter.py +++ b/docetl/operations/filter.py @@ -122,7 +122,11 @@ def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]: prompt = prompt_template.render(input=item) def validation_fn(response: Dict[str, Any]): - output = parse_llm_response(response)[0] + output = parse_llm_response( + response, + self.config["output"]["schema"], + manually_fix_errors=self.manually_fix_errors, + )[0] for key, value in item.items(): if key not in self.config["output"]["schema"]: output[key] = value diff --git a/docetl/operations/map.py b/docetl/operations/map.py index c0454839..c6e53ee8 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -137,7 +137,10 @@ def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]: def validation_fn(response: Dict[str, Any]): output = parse_llm_response( - response, tools=self.config.get("tools", None) + response, + schema=self.config["output"]["schema"], + tools=self.config.get("tools", None), + manually_fix_errors=self.manually_fix_errors, )[0] for key, value in item.items(): if key not in self.config["output"]["schema"]: @@ -377,7 +380,10 @@ def process_prompt(item, prompt_config): max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) output = parse_llm_response( - response, tools=prompt_config.get("tools", None) + response, + schema=local_output_schema, + tools=prompt_config.get("tools", None), + manually_fix_errors=self.manually_fix_errors, )[0] return output, completion_cost(response) diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index bc2da2d6..c0389540 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -700,7 +700,11 @@ def _increment_fold( 
             timeout_seconds=self.config.get("timeout", 120),
             max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
         )
-        folded_output = parse_llm_response(response)[0]
+        folded_output = parse_llm_response(
+            response,
+            self.config["output"]["schema"],
+            manually_fix_errors=self.manually_fix_errors,
+        )[0]
         folded_output.update(dict(zip(self.config["reduce_key"], key)))
         fold_cost = completion_cost(response)
 
@@ -741,7 +745,7 @@ def _merge_results(
             timeout_seconds=self.config.get("timeout", 120),
             max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
         )
-        merged_output = parse_llm_response(response)[0]
+        merged_output = parse_llm_response(response, self.config["output"]["schema"])[0]
         merged_output.update(dict(zip(self.config["reduce_key"], key)))
         merge_cost = completion_cost(response)
         end_time = time.time()
@@ -850,7 +854,11 @@ def _batch_reduce(
 
         item_cost += completion_cost(response)
 
-        output = parse_llm_response(response)[0]
+        output = parse_llm_response(
+            response,
+            self.config["output"]["schema"],
+            manually_fix_errors=self.manually_fix_errors,
+        )[0]
         output.update(dict(zip(self.config["reduce_key"], key)))
 
         if validate_output(self.config, output, self.console):
diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py
index f1a766a6..d72ee7ce 100644
--- a/docetl/operations/resolve.py
+++ b/docetl/operations/resolve.py
@@ -63,7 +63,10 @@ def compare_pair(
         timeout_seconds=timeout_seconds,
         max_retries_per_timeout=max_retries_per_timeout,
     )
-    output = parse_llm_response(response)[0]
+    output = parse_llm_response(
+        response,
+        {"is_match": "bool"},
+    )[0]
     return output["is_match"], completion_cost(response)
 
 
@@ -410,7 +413,11 @@ def process_cluster(cluster):
                     "max_retries_per_timeout", 2
                 ),
             )
-            reduction_output = parse_llm_response(reduction_response)[0]
+            reduction_output = parse_llm_response(
+                reduction_response,
+                self.config["output"]["schema"],
+                manually_fix_errors=self.manually_fix_errors,
+            )[0]
             reduction_cost = completion_cost(reduction_response)
 
             if validate_output(self.config, reduction_output, self.console):
diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py
index 30ebffca..41ae1169 100644
--- a/docetl/operations/utils.py
+++ b/docetl/operations/utils.py
@@ -14,6 +14,7 @@
 from litellm import completion, embedding, model_cost
 from docetl.utils import completion_cost
 from rich.console import Console
+from rich.prompt import Prompt
 from tqdm import tqdm
 from diskcache import Cache
 import tiktoken
@@ -357,6 +358,45 @@ def call_llm_with_validation(
     return parsed_output, cost, False
 
 
+def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Prompt the user for input for each key in the schema using Rich,
+    then parse the input values with json.loads().
+
+    Args:
+        schema (Dict[str, Any]): The schema dictionary.
+
+    Returns:
+        Dict[str, Any]: A dictionary with user inputs parsed according to the schema.
+    """
+    user_input = {}
+
+    for key, value_type in schema.items():
+        prompt_text = f"Enter value for '{key}' ({value_type}): "
+        user_value = Prompt.ask(prompt_text)
+
+        try:
+            # Parse the input value using json.loads()
+            parsed_value = json.loads(user_value)
+
+            # Check if the parsed value matches the expected type
+            if isinstance(parsed_value, eval(value_type)):
+                user_input[key] = parsed_value
+            else:
+                rprint(
+                    f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}."
+                )
+                return get_user_input_for_schema(schema)  # Recursive call to retry
+
+        except json.JSONDecodeError:
+            rprint(
+                f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again."
+            )
+            return get_user_input_for_schema(schema)  # Recursive call to retry
+
+    return user_input
+
+
 def call_llm(
     model: str,
     op_type: str,
@@ -417,6 +457,25 @@ def call_llm(
     return {}
 
 
+class InvalidOutputError(Exception):
+    """
+    Custom exception raised when the LLM output is invalid or cannot be parsed.
+
+    Attributes:
+        message (str): Explanation of the error.
+        output (str): The invalid output that caused the exception.
+    """
+
+    def __init__(self, message: str, output: str, expected_schema: Dict[str, Any]):
+        self.message = message
+        self.output = output
+        self.expected_schema = expected_schema
+        super().__init__(self.message)
+
+    def __str__(self):
+        return f"{self.message}\nInvalid output: {self.output}\nExpected schema: {self.expected_schema}"
+
+
 def timeout(seconds):
     def decorator(func):
         @functools.wraps(func)
@@ -663,7 +722,7 @@ def call_llm_with_gleaning(
     cost = 0.0
 
     # Parse the response
-    parsed_response = parse_llm_response(response)
+    parsed_response = parse_llm_response(response, output_schema)
     output = parsed_response[0]
 
     messages = (
@@ -764,7 +823,7 @@ def call_llm_with_gleaning(
         messages.append(
             {
                 "role": "assistant",
-                "content": json.dumps(parse_llm_response(response)[0]),
+                "content": json.dumps(parse_llm_response(response, output_schema)[0]),
             }
         )
 
@@ -772,7 +831,36 @@ def call_llm_with_gleaning(
 
 
 def parse_llm_response(
-    response: Any, tools: Optional[List[Dict[str, str]]] = None
+    response: Any,
+    schema: Dict[str, Any] = {},
+    tools: Optional[List[Dict[str, str]]] = None,
+    manually_fix_errors: bool = False,
+) -> List[Dict[str, Any]]:
+    """
+    Parse the response from a language model.
+    This function extracts the tool calls from the LLM response and returns the arguments
+    """
+    try:
+        return parse_llm_response_helper(response, schema, tools)
+    except InvalidOutputError as e:
+        if manually_fix_errors:
+            rprint(
+                f"[bold red]Could not parse LLM output:[/bold red] {e.message}\n"
+                f"\tExpected Schema: {e.expected_schema}\n"
+                f"\tPlease manually set this output."
+            )
+            rprint(f"\n[bold yellow]LLM-Generated Response:[/bold yellow]\n{response}")
+            output = get_user_input_for_schema(schema)
+
+            return [output]
+        else:
+            raise e
+
+
+def parse_llm_response_helper(
+    response: Any,
+    schema: Dict[str, Any] = {},
+    tools: Optional[List[Dict[str, str]]] = None,
 ) -> List[Dict[str, Any]]:
     """
     Parse the response from a language model.
@@ -778,13 +866,17 @@ def parse_llm_response_helper(
     Extract the tool calls from the response and return the arguments passed to them.
 
     Args:
         response (Any): The response object from the language model.
+        schema (Optional[Dict[str, Any]]): The schema that was passed to the LLM.
         tools (Optional[List[Dict[str, str]]]): The tools that were passed to the LLM.
 
     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the parsed output.
+
+    Raises:
+        InvalidOutputError: If the response is not valid.
""" if not response: - return [{}] + raise InvalidOutputError("No response from LLM", [{}], schema) # Parse the response based on the provided tools if tools: @@ -817,7 +909,7 @@ def parse_llm_response( tool_calls = response.choices[0].message.tool_calls if not tool_calls: - raise ValueError("No tool calls found in response") + raise InvalidOutputError("No tool calls in LLM response", [{}], schema) outputs = [] for tool_call in tool_calls: @@ -839,7 +931,17 @@ def parse_llm_response( pass outputs.append(output_dict) except json.JSONDecodeError: - return [{}] + raise InvalidOutputError( + "Could not decode LLM JSON response", + [tool_call.function.arguments], + schema, + ) + except Exception as e: + raise InvalidOutputError( + f"Error parsing LLM response: {e}", + [tool_call.function.arguments], + schema, + ) return outputs else: diff --git a/poetry.lock b/poetry.lock index f840816a..e7f10f5e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -797,6 +797,76 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jiter" +version = "0.5.0" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.5.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b599f4e89b3def9a94091e6ee52e1d7ad7bc33e238ebb9c4c63f211d74822c3f"}, + {file = "jiter-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a063f71c4b06225543dddadbe09d203dc0c95ba352d8b85f1221173480a71d5"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acc0d5b8b3dd12e91dd184b87273f864b363dfabc90ef29a1092d269f18c7e28"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c22541f0b672f4d741382a97c65609332a783501551445ab2df137ada01e019e"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63314832e302cc10d8dfbda0333a384bf4bcfce80d65fe99b0f3c0da8945a91a"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a25fbd8a5a58061e433d6fae6d5298777c0814a8bcefa1e5ecfff20c594bd749"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503b2c27d87dfff5ab717a8200fbbcf4714516c9d85558048b1fc14d2de7d8dc"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6d1f3d27cce923713933a844872d213d244e09b53ec99b7a7fdf73d543529d6d"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c95980207b3998f2c3b3098f357994d3fd7661121f30669ca7cb945f09510a87"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:afa66939d834b0ce063f57d9895e8036ffc41c4bd90e4a99631e5f261d9b518e"}, + {file = "jiter-0.5.0-cp310-none-win32.whl", hash = "sha256:f16ca8f10e62f25fd81d5310e852df6649af17824146ca74647a018424ddeccf"}, + {file = "jiter-0.5.0-cp310-none-win_amd64.whl", hash = "sha256:b2950e4798e82dd9176935ef6a55cf6a448b5c71515a556da3f6b811a7844f1e"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d4c8e1ed0ef31ad29cae5ea16b9e41529eb50a7fba70600008e9f8de6376d553"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c6f16e21276074a12d8421692515b3fd6d2ea9c94fd0734c39a12960a20e85f3"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5280e68e7740c8c128d3ae5ab63335ce6d1fb6603d3b809637b11713487af9e6"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", 
hash = "sha256:583c57fc30cc1fec360e66323aadd7fc3edeec01289bfafc35d3b9dcb29495e4"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26351cc14507bdf466b5f99aba3df3143a59da75799bf64a53a3ad3155ecded9"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829df14d656b3fb87e50ae8b48253a8851c707da9f30d45aacab2aa2ba2d614"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a42a4bdcf7307b86cb863b2fb9bb55029b422d8f86276a50487982d99eed7c6e"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04d461ad0aebf696f8da13c99bc1b3e06f66ecf6cfd56254cc402f6385231c06"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e6375923c5f19888c9226582a124b77b622f8fd0018b843c45eeb19d9701c403"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cec323a853c24fd0472517113768c92ae0be8f8c384ef4441d3632da8baa646"}, + {file = "jiter-0.5.0-cp311-none-win32.whl", hash = "sha256:aa1db0967130b5cab63dfe4d6ff547c88b2a394c3410db64744d491df7f069bb"}, + {file = "jiter-0.5.0-cp311-none-win_amd64.whl", hash = "sha256:aa9d2b85b2ed7dc7697597dcfaac66e63c1b3028652f751c81c65a9f220899ae"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9f664e7351604f91dcdd557603c57fc0d551bc65cc0a732fdacbf73ad335049a"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:044f2f1148b5248ad2c8c3afb43430dccf676c5a5834d2f5089a4e6c5bbd64df"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:702e3520384c88b6e270c55c772d4bd6d7b150608dcc94dea87ceba1b6391248"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:528d742dcde73fad9d63e8242c036ab4a84389a56e04efd854062b660f559544"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8cf80e5fe6ab582c82f0c3331df27a7e1565e2dcf06265afd5173d809cdbf9ba"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44dfc9ddfb9b51a5626568ef4e55ada462b7328996294fe4d36de02fce42721f"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c451f7922992751a936b96c5f5b9bb9312243d9b754c34b33d0cb72c84669f4e"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:308fce789a2f093dca1ff91ac391f11a9f99c35369117ad5a5c6c4903e1b3e3a"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7f5ad4a7c6b0d90776fdefa294f662e8a86871e601309643de30bf94bb93a64e"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ea189db75f8eca08807d02ae27929e890c7d47599ce3d0a6a5d41f2419ecf338"}, + {file = "jiter-0.5.0-cp312-none-win32.whl", hash = "sha256:e3bbe3910c724b877846186c25fe3c802e105a2c1fc2b57d6688b9f8772026e4"}, + {file = "jiter-0.5.0-cp312-none-win_amd64.whl", hash = "sha256:a586832f70c3f1481732919215f36d41c59ca080fa27a65cf23d9490e75b2ef5"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f04bc2fc50dc77be9d10f73fcc4e39346402ffe21726ff41028f36e179b587e6"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f433a4169ad22fcb550b11179bb2b4fd405de9b982601914ef448390b2954f3"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:ad4a6398c85d3a20067e6c69890ca01f68659da94d74c800298581724e426c7e"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6baa88334e7af3f4d7a5c66c3a63808e5efbc3698a1c57626541ddd22f8e4fbf"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ece0a115c05efca597c6d938f88c9357c843f8c245dbbb53361a1c01afd7148"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:335942557162ad372cc367ffaf93217117401bf930483b4b3ebdb1223dbddfa7"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649b0ee97a6e6da174bffcb3c8c051a5935d7d4f2f52ea1583b5b3e7822fbf14"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4be354c5de82157886ca7f5925dbda369b77344b4b4adf2723079715f823989"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5206144578831a6de278a38896864ded4ed96af66e1e63ec5dd7f4a1fce38a3a"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8120c60f8121ac3d6f072b97ef0e71770cc72b3c23084c72c4189428b1b1d3b6"}, + {file = "jiter-0.5.0-cp38-none-win32.whl", hash = "sha256:6f1223f88b6d76b519cb033a4d3687ca157c272ec5d6015c322fc5b3074d8a5e"}, + {file = "jiter-0.5.0-cp38-none-win_amd64.whl", hash = "sha256:c59614b225d9f434ea8fc0d0bec51ef5fa8c83679afedc0433905994fb36d631"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0af3838cfb7e6afee3f00dc66fa24695199e20ba87df26e942820345b0afc566"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:550b11d669600dbc342364fd4adbe987f14d0bbedaf06feb1b983383dcc4b961"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:489875bf1a0ffb3cb38a727b01e6673f0f2e395b2aad3c9387f94187cb214bbf"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b250ca2594f5599ca82ba7e68785a669b352156260c5362ea1b4e04a0f3e2389"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ea18e01f785c6667ca15407cd6dabbe029d77474d53595a189bdc813347218e"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:462a52be85b53cd9bffd94e2d788a09984274fe6cebb893d6287e1c296d50653"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92cc68b48d50fa472c79c93965e19bd48f40f207cb557a8346daa020d6ba973b"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c834133e59a8521bc87ebcad773608c6fa6ab5c7a022df24a45030826cf10bc"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab3a71ff31cf2d45cb216dc37af522d335211f3a972d2fe14ea99073de6cb104"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cccd3af9c48ac500c95e1bcbc498020c87e1781ff0345dd371462d67b76643eb"}, + {file = "jiter-0.5.0-cp39-none-win32.whl", hash = "sha256:368084d8d5c4fc40ff7c3cc513c4f73e02c85f6009217922d0823a48ee7adf61"}, + {file = "jiter-0.5.0-cp39-none-win_amd64.whl", hash = "sha256:ce03f7b4129eb72f1687fa11300fbf677b02990618428934662406d2a76742a1"}, + {file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"}, +] + [[package]] name = "joblib" version = "1.4.2" @@ -881,13 +951,13 @@ test = ["coverage", "pytest", "pytest-cov"] [[package]] name = "litellm" -version = "1.43.0" +version = "1.48.7" 
 description = "Library to easily interface with LLM API providers"
 optional = false
 python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
 files = [
-    {file = "litellm-1.43.0-py3-none-any.whl", hash = "sha256:0a91b84637e9f5b2cf31f210f682e04134f9c1837501763685dc1a58ec25f685"},
-    {file = "litellm-1.43.0.tar.gz", hash = "sha256:0d0faa25470783dc9f30894a3bb86c90cd3ff72cf769393eaac563e53d2ec162"},
+    {file = "litellm-1.48.7-py3-none-any.whl", hash = "sha256:4971a9e681188635c2ee6dc44fe35bb2774586e9018682adcccdbb516b839c64"},
+    {file = "litellm-1.48.7.tar.gz", hash = "sha256:ff1fef7049e9afa09598f98d1e510a6d5f252ec65c0526b8bfaf13eadfcf65e5"},
 ]
 
 [package.dependencies]
@@ -896,7 +966,7 @@ click = "*"
 importlib-metadata = ">=6.8.0"
 jinja2 = ">=3.1.2,<4.0.0"
 jsonschema = ">=4.22.0,<5.0.0"
-openai = ">=1.27.0"
+openai = ">=1.45.0"
 pydantic = ">=2.0.0,<3.0.0"
 python-dotenv = ">=0.2.0"
 requests = ">=2.31.0,<3.0.0"
@@ -904,8 +974,8 @@ tiktoken = ">=0.7.0"
 tokenizers = "*"
 
 [package.extras]
-extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "pynacl (>=1.5.0,<2.0.0)", "resend (>=0.8.0,<0.9.0)"]
-proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"]
+extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "resend (>=0.8.0,<0.9.0)"]
+proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"]
 
 [[package]]
 name = "markdown"
@@ -1444,23 +1514,24 @@ files = [
 
 [[package]]
 name = "openai"
-version = "1.37.0"
+version = "1.50.2"
 description = "The official Python library for the openai API"
 optional = false
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.37.0-py3-none-any.whl", hash = "sha256:a903245c0ecf622f2830024acdaa78683c70abb8e9d37a497b851670864c9f73"},
-    {file = "openai-1.37.0.tar.gz", hash = "sha256:dc8197fc40ab9d431777b6620d962cc49f4544ffc3011f03ce0a805e6eb54adb"},
+    {file = "openai-1.50.2-py3-none-any.whl", hash = "sha256:822dd2051baa3393d0d5406990611975dd6f533020dc9375a34d4fe67e8b75f7"},
+    {file = "openai-1.50.2.tar.gz", hash = "sha256:3987ae027152fc8bea745d60b02c8f4c4a76e1b5c70e73565fa556db6f78c9e6"},
 ]
 
 [package.dependencies]
 anyio = ">=3.5.0,<5"
 distro = ">=1.7.0,<2"
 httpx = ">=0.23.0,<1"
+jiter = ">=0.4.0,<1"
 pydantic = ">=1.9.0,<3"
 sniffio = "*"
 tqdm = ">4"
-typing-extensions = ">=4.7,<5"
+typing-extensions = ">=4.11,<5"
 
 [package.extras]
 datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
diff --git a/tests/test_map.py b/tests/test_map.py
index ccf390a4..60aee877 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -1,3 +1,4 @@
+import docetl
 import pytest
 
 from docetl.operations.map import MapOperation
@@ -99,6 +100,5 @@ def test_map_operation_with_timeout(simple_map_config, simple_sample_data):
     operation = MapOperation(map_config_with_timeout, "gpt-4o-mini", 4)
 
-    # Execute the operation and expect empty results
-    results, cost = operation.execute(simple_sample_data)
-    for result in results:
-        assert "sentiment" not in result
+    # Execute the operation and expect an InvalidOutputError to be raised
+    with pytest.raises(docetl.operations.utils.InvalidOutputError):
+        operation.execute(simple_sample_data)
diff --git a/tests/test_ollama.py b/tests/test_ollama.py
index 81d8fbfb..08647804 100644
--- a/tests/test_ollama.py
+++ b/tests/test_ollama.py
@@ -54,7 +54,7 @@ def map_config():
         type="map",
         prompt="Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.",
         output={"schema": {"sentiment": "string"}},
-        model="ollama/llama3",
+        model="ollama/llama3.1",
     )
 
 
@@ -66,7 +66,7 @@ def reduce_config():
         reduce_key="group",
         prompt="Summarize the following group of values: {{ inputs }} Provide a total and any other relevant statistics.",
         output={"schema": {"total": "number", "avg": "number"}},
-        model="ollama/llama3",
+        model="ollama/llama3.1",
     )
 
 
@@ -95,7 +95,7 @@ def test_ollama_map_reduce_pipeline(
         output=PipelineOutput(
             type="file", path=temp_output_file, intermediate_dir=temp_intermediate_dir
         ),
-        default_model="ollama/llama3",
+        default_model="ollama/llama3.1",
     )
 
     cost = pipeline.run()