Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge staging to main (after parsers refactor) #82

Merged
merged 8 commits into from
Oct 8, 2024
37 changes: 13 additions & 24 deletions docetl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,20 +148,14 @@ def _validate_parsing(
for tool in parsing_tools:
if (
not isinstance(tool, dict)
or "input_key" not in tool
or "function" not in tool
or "output_key" not in tool
):
raise ValueError(
"Each parsing tool must be a dictionary with 'input_key', 'function', and 'output_key' keys"
"Each parsing tool must be a dictionary with a 'function' key and any arguments required by that function"
)
if (
not isinstance(tool["input_key"], str)
or not isinstance(tool["function"], str)
or not isinstance(tool["output_key"], str)
):
if not isinstance(tool["function"], str):
raise ValueError(
"'input_key', 'function', and 'output_key' in parsing tools must be strings"
"'function' in parsing tools must be a string"
)
if "function_kwargs" in tool and not isinstance(
tool["function_kwargs"], dict
Expand Down Expand Up @@ -213,19 +207,12 @@ def load(self) -> List[Dict]:
def _process_item(
self,
item: Dict[str, Any],
input_key: str,
output_key: str,
func: Callable,
**function_kwargs: Dict[str, Any],
):
if input_key not in item:
raise ValueError(f"Input key {input_key} not found in item: {item}")
result = func(item[input_key], **function_kwargs)
if isinstance(result, list):
return [item.copy() | {output_key: res} for res in result]
else:
return [item | {output_key: result}]

result = func(item, **function_kwargs)
return [item.copy() | res for res in result]

def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
"""
Apply parsing tools to the data.
Expand All @@ -240,7 +227,13 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
ValueError: If a parsing tool is not found or if an input key is missing from an item.
"""
for tool in self.parsing:
input_key = tool["input_key"]
function_kwargs = dict(tool)
function_kwargs.pop("function")
# FIXME: The following is just for backwards compatibility
# with the existing yaml format...
if "function_kwargs" in function_kwargs:
function_kwargs.update(function_kwargs.pop("function_kwargs"))

try:
func = get_parser(tool["function"])
except KeyError:
Expand All @@ -261,17 +254,13 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
f"Parsing tool {tool['function']} not found. Please define it or use one of our existing parsing tools: {get_parsing_tools()}"
)

output_key = tool["output_key"]
function_kwargs = tool.get("function_kwargs", {})
new_data = []

with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
self._process_item,
item,
input_key,
output_key,
func,
**function_kwargs,
)
Expand Down
1 change: 1 addition & 0 deletions docetl/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def validation_fn(response: Dict[str, Any]):
max_retries_per_timeout=self.config.get(
"max_retries_per_timeout", 2
),
verbose=self.config.get("verbose", False),
),
validation_fn=validation_fn,
val_rule=self.config.get("validate", []),
Expand Down
6 changes: 6 additions & 0 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
Returns:
Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation.
"""
if self.config.get("gleaning", {}).get("validation_prompt", None):
self.console.log(
f"Using gleaning with validation prompt: {self.config.get('gleaning', {}).get('validation_prompt', '')}"
)

reduce_keys = self.config["reduce_key"]
if isinstance(reduce_keys, str):
reduce_keys = [reduce_keys]
Expand Down Expand Up @@ -860,6 +865,7 @@ def _batch_reduce(
console=self.console,
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
verbose=self.config.get("verbose", False),
)
item_cost += gleaning_cost
else:
Expand Down
12 changes: 9 additions & 3 deletions docetl/operations/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
split_key = self.config["split_key"]
method = self.config["method"]
method_kwargs = self.config["method_kwargs"]
encoder = tiktoken.encoding_for_model(
self.config["method_kwargs"].get("model", self.default_model).split("/")[-1]
)
try:
encoder = tiktoken.encoding_for_model(
self.config["method_kwargs"]
.get("model", self.default_model)
.split("/")[-1]
)
except Exception:
encoder = tiktoken.encoding_for_model("gpt-4o")

results = []
cost = 0.0

Expand Down
18 changes: 12 additions & 6 deletions docetl/operations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ def safe_eval(expression: str, output: Dict) -> bool:
# Safely evaluate the expression
return bool(aeval(expression))
except Exception:
return False
# try to evaluate with python eval
try:
return bool(eval(expression, locals={"output": output}))
except Exception:
return False


class APIWrapper(object):
Expand Down Expand Up @@ -720,6 +724,7 @@ def call_llm_with_gleaning(
console: Console = Console(),
timeout_seconds: int = 120,
max_retries_per_timeout: int = 2,
verbose: bool = False,
) -> Tuple[str, float]:
"""
Call LLM with a gleaning process, including validation and improvement rounds.
Expand Down Expand Up @@ -789,7 +794,7 @@ def call_llm_with_gleaning(
# Call LLM for validation
self.runner.rate_limiter.try_acquire("llm_call", weight=1)
validator_response = completion(
model="gpt-4o-mini",
model=model,
messages=truncate_messages(
messages + [{"role": "user", "content": validator_prompt}], model
),
Expand Down Expand Up @@ -817,9 +822,10 @@ def call_llm_with_gleaning(
if not suggestion["should_refine"]:
break

# console.log(
# f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}"
# )
if verbose:
console.log(
f"Validator improvements (gleaning round {rnd + 1}): {suggestion['improvements']}"
)

# Prompt for improvement
improvement_prompt = f"""Based on the validation feedback:
Expand Down Expand Up @@ -1166,4 +1172,4 @@ def rich_as_completed(futures, total=None, desc=None, leave=True, console=None):
with RichLoopBar(total=total, desc=desc, leave=leave, console=console) as pbar:
for future in as_completed(futures):
yield future
pbar.update()
pbar.update()
2 changes: 1 addition & 1 deletion docetl/optimizers/reduce_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ def _calculate_compression_ratio(
reduce_key = op_config["reduce_key"]
input_schema = op_config.get("input", {}).get("schema", {})
output_schema = op_config["output"]["schema"]
model = op_config.get("model", "gpt-4o")
model = op_config.get("model", "gpt-4o-mini")

compression_ratios = {}

Expand Down
47 changes: 36 additions & 11 deletions docetl/parsing_tools.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,49 @@
import importlib
import io
import os
from typing import Dict, List, Optional


def llama_index_simple_directory_reader(filename: str) -> List[str]:
from typing import Dict, List, Optional, Any

def with_input_output_key(fn):
    """Decorator that adapts a simple string parser to the item-dict protocol.

    Wraps a parser function that takes a single string parameter and
    returns a list of strings, turning it into a full parser function
    that takes an item dictionary and returns a list of dictionaries.

    The wrapped function accepts two extra keyword arguments:
        input_key: key in the item whose value is passed to ``fn``
            (default ``"text"``).
        output_key: key under which each result string is stored in the
            returned dictionaries (default ``"text"``).

    Raises:
        ValueError: If ``input_key`` is missing from the item.
    """
    from functools import wraps  # local import: module import block is outside this hunk

    @wraps(fn)  # preserve fn's __name__/__doc__ for debugging and tool registration
    def wrapper(item, input_key="text", output_key="text", **kw):
        if input_key not in item:
            raise ValueError(f"Input key {input_key} not found in item: {item}")
        result = fn(item[input_key], **kw)
        # Normalize a bare (non-list) return so callers always get a list of dicts.
        if not isinstance(result, list):
            result = [result]
        return [{output_key: res} for res in result]

    return wrapper

def llama_index_simple_directory_reader(item: dict[str, Any], input_key: str ="path") -> List[dict[str, Any]]:
from llama_index.core import SimpleDirectoryReader

documents = SimpleDirectoryReader(filename).load_data()
# FIXME: What about doc.metadata? Would be good to include that too...
return [doc.text for doc in documents]
documents = SimpleDirectoryReader(item[input_key]).load_data()
return [{"text": doc.text,
"metadata": doc.metadata}
for doc in documents]


def llama_index_wikipedia_reader(filename: str) -> List[str]:
def llama_index_wikipedia_reader(item: dict[str, Any], input_key: str = "pages") -> List[dict[str, Any]]:
from llama_index.readers.wikipedia import WikipediaReader

loader = WikipediaReader()
pages = [filename]
pages = item[input_key]
if not isinstance(pages, list):
pages = [pages]
documents = loader.load_data(pages=pages, auto_suggest=False)
# The wikipedia reader does not include the page url in the metadata, which is impractical...
for name, doc in zip(pages, documents):
doc.metadata["source"] = "https://en.wikipedia.org/wiki/" + name

# FIXME: What about doc.metadata? Would be good to include that too...
return [doc.text for doc in documents]
return [{"text": doc.text,
"metadata": doc.metadata}
for doc in documents]


@with_input_output_key
def whisper_speech_to_text(filename: str) -> List[str]:
"""
Transcribe speech from an audio file to text using Whisper model via litellm.
Expand Down Expand Up @@ -72,6 +90,7 @@ def whisper_speech_to_text(filename: str) -> List[str]:
return [response.text]


@with_input_output_key
def xlsx_to_string(
filename: str,
orientation: str = "col",
Expand Down Expand Up @@ -128,6 +147,7 @@ def process_sheet(sheet):
return [process_sheet(wb.active)]


@with_input_output_key
def txt_to_string(filename: str) -> List[str]:
"""
Read the content of a text file and return it as a list of strings (only one element).
Expand All @@ -142,6 +162,7 @@ def txt_to_string(filename: str) -> List[str]:
return [file.read()]


@with_input_output_key
def docx_to_string(filename: str) -> List[str]:
"""
Extract text from a Word document.
Expand All @@ -158,6 +179,7 @@ def docx_to_string(filename: str) -> List[str]:
return ["\n".join([paragraph.text for paragraph in doc.paragraphs])]


@with_input_output_key
def pptx_to_string(filename: str, doc_per_slide: bool = False) -> List[str]:
"""
Extract text from a PowerPoint presentation.
Expand Down Expand Up @@ -195,6 +217,7 @@ def pptx_to_string(filename: str, doc_per_slide: bool = False) -> List[str]:
return result


@with_input_output_key
def azure_di_read(
filename: str,
use_url: bool = False,
Expand Down Expand Up @@ -334,6 +357,7 @@ def azure_di_read(
]


@with_input_output_key
def paddleocr_pdf_to_string(
input_path: str,
doc_per_page: bool = False,
Expand Down Expand Up @@ -399,6 +423,7 @@ def paddleocr_pdf_to_string(
return pdf_content


@with_input_output_key
def gptpdf_to_string(
input_path: str,
gpt_model: str,
Expand Down
5 changes: 4 additions & 1 deletion docetl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def truncate_sample_data(
remaining_tokens = available_tokens - current_tokens

# Encode the value
encoder = tiktoken.encoding_for_model(model)
try:
encoder = tiktoken.encoding_for_model(model)
except Exception:
encoder = tiktoken.encoding_for_model("gpt-4o")
encoded_value = encoder.encode(str(data[key]))

# Calculate how many tokens to keep
Expand Down
2 changes: 1 addition & 1 deletion docs/concepts/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ validate:
- all(len(insight["supporting_actions"]) >= 1 for insight in output["insights"])
```

Access variables using dictionary syntax: `input["field"]` or `output["field"]`.
Access variables using dictionary syntax: `output["field"]`. Note that you can't access `input` docs in validation, but the output docs should have all the fields from the input docs (for non-reduce operations), since fields pass through unchanged.

The `num_retries_on_validate_failure` attribute specifies how many times to retry the LLM if any validation statements fail.

Expand Down
Loading
Loading