Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: show better progress bars for operations #30

Merged: 2 commits were merged on Sep 30, 2024 (source and target branch names were not captured).
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docetl/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
self.default_model = default_model
self.max_threads = max_threads
self.console = console or Console()
self.manually_fix_errors = self.config.get("manually_fix_errors", False)
self.status = status
self.num_retries_on_validate_failure = self.config.get(
"num_retries_on_validate_failure", 0
Expand Down
8 changes: 7 additions & 1 deletion docetl/operations/equijoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def compare_pair(
timeout_seconds=timeout_seconds,
max_retries_per_timeout=max_retries_per_timeout,
)
output = parse_llm_response(response)[0]
output = parse_llm_response(response, {"is_match": "bool"})[0]
return output["is_match"], completion_cost(response)


Expand Down Expand Up @@ -201,6 +201,9 @@ def get_hashable_key(item: Dict) -> str:
if len(left_data) == 0 or len(right_data) == 0:
return [], 0

if self.status:
self.status.stop()

# Initial blocking using multiprocessing
num_processes = min(cpu_count(), len(left_data))

Expand Down Expand Up @@ -441,4 +444,7 @@ def get_embeddings(
)
self.console.log(f"Equijoin selectivity: {join_selectivity:.4f}")

if self.status:
self.status.start()

return results, total_cost
14 changes: 12 additions & 2 deletions docetl/operations/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,19 @@ def execute(
)
)

if self.status:
self.status.start()

def _process_filter_item(item: Dict) -> Tuple[Optional[Dict], float]:
prompt_template = Template(self.config["prompt"])
prompt = prompt_template.render(input=item)

def validation_fn(response: Dict[str, Any]):
output = parse_llm_response(response)[0]
output = parse_llm_response(
response,
self.config["output"]["schema"],
manually_fix_errors=self.manually_fix_errors,
)[0]
for key, value in item.items():
if key not in self.config["output"]["schema"]:
output[key] = value
Expand Down Expand Up @@ -159,7 +166,7 @@ def validation_fn(response: Dict[str, Any]):
total_cost = 0
pbar = RichLoopBar(
range(len(futures)),
desc="Processing filter items",
desc=f"Processing {self.config['name']} (filter) on all documents",
console=self.console,
)
for i in pbar:
Expand All @@ -174,4 +181,7 @@ def validation_fn(response: Dict[str, Any]):
results.append(result)
pbar.update(1)

if self.status:
self.status.start()

return results, total_cost
26 changes: 22 additions & 4 deletions docetl/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,19 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
dropped_results.append(new_item)
return dropped_results, 0.0 # Return the modified data with no cost

if self.status:
self.status.stop()

def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]:
prompt_template = Template(self.config["prompt"])
prompt = prompt_template.render(input=item)

def validation_fn(response: Dict[str, Any]):
output = parse_llm_response(
response, tools=self.config.get("tools", None)
response,
schema=self.config["output"]["schema"],
tools=self.config.get("tools", None),
manually_fix_errors=self.manually_fix_errors,
)[0]
for key, value in item.items():
if key not in self.config["output"]["schema"]:
Expand Down Expand Up @@ -196,7 +202,7 @@ def validation_fn(response: Dict[str, Any]):
total_cost = 0
pbar = RichLoopBar(
range(len(futures)),
desc="Processing map items",
desc=f"Processing {self.config['name']} (map) on all documents",
console=self.console,
)
for i in pbar:
Expand All @@ -212,6 +218,9 @@ def validation_fn(response: Dict[str, Any]):
total_cost += item_cost
pbar.update(i)

if self.status:
self.status.start()

return results, total_cost

def validate_output(self, output: Dict) -> bool:
Expand Down Expand Up @@ -349,6 +358,9 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
dropped_results.append(new_item)
return dropped_results, 0.0 # Return the modified data with no cost

if self.status:
self.status.stop()

def process_prompt(item, prompt_config):
prompt_template = Template(prompt_config["prompt"])
prompt = prompt_template.render(input=item)
Expand All @@ -368,7 +380,10 @@ def process_prompt(item, prompt_config):
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
)
output = parse_llm_response(
response, tools=prompt_config.get("tools", None)
response,
schema=local_output_schema,
tools=prompt_config.get("tools", None),
manually_fix_errors=self.manually_fix_errors,
)[0]
return output, completion_cost(response)

Expand All @@ -384,7 +399,7 @@ def process_prompt(item, prompt_config):
# Process results in order
pbar = RichLoopBar(
range(len(all_futures)),
desc="Processing parallel map items",
desc=f"Processing {self.config['name']} (parallel map) on all documents",
console=self.console,
)
for i in pbar:
Expand Down Expand Up @@ -418,5 +433,8 @@ def process_prompt(item, prompt_config):
for key in drop_keys:
item.pop(key, None)

if self.status:
self.status.start()

# Return the results in order
return [results[i] for i in range(len(input_data)) if i in results], total_cost
22 changes: 18 additions & 4 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
reduce_keys = [reduce_keys]
input_schema = self.config.get("input", {}).get("schema", {})

if self.status:
self.status.stop()

# Check if we need to group everything into one group
if reduce_keys == ["_all"] or reduce_keys == "_all":
grouped_data = [("_all", input_data)]
Expand Down Expand Up @@ -341,7 +344,7 @@ def process_group(
for future in rich_as_completed(
futures,
total=len(futures),
desc="Processing reduce items",
desc=f"Processing {self.config['name']} (reduce) on all documents",
leave=True,
console=self.console,
):
Expand All @@ -358,6 +361,9 @@ def process_group(
self.intermediates[key]
)

if self.status:
self.status.start()

return results, total_cost

def _get_embeddings(
Expand Down Expand Up @@ -694,7 +700,11 @@ def _increment_fold(
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
)
folded_output = parse_llm_response(response)[0]
folded_output = parse_llm_response(
response,
self.config["output"]["schema"],
manually_fix_errors=self.manually_fix_errors,
)[0]

folded_output.update(dict(zip(self.config["reduce_key"], key)))
fold_cost = completion_cost(response)
Expand Down Expand Up @@ -735,7 +745,7 @@ def _merge_results(
timeout_seconds=self.config.get("timeout", 120),
max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2),
)
merged_output = parse_llm_response(response)[0]
merged_output = parse_llm_response(response, self.config["output"]["schema"])[0]
merged_output.update(dict(zip(self.config["reduce_key"], key)))
merge_cost = completion_cost(response)
end_time = time.time()
Expand Down Expand Up @@ -844,7 +854,11 @@ def _batch_reduce(

item_cost += completion_cost(response)

output = parse_llm_response(response)[0]
output = parse_llm_response(
response,
self.config["output"]["schema"],
manually_fix_errors=self.manually_fix_errors,
)[0]
output.update(dict(zip(self.config["reduce_key"], key)))

if validate_output(self.config, output, self.console):
Expand Down
21 changes: 14 additions & 7 deletions docetl/operations/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def compare_pair(
timeout_seconds=timeout_seconds,
max_retries_per_timeout=max_retries_per_timeout,
)
output = parse_llm_response(response)[0]
output = parse_llm_response(
response,
{"is_match": "bool"},
)[0]
return output["is_match"], completion_cost(response)


Expand Down Expand Up @@ -199,11 +202,11 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
blocking_keys = self.config.get("blocking_keys", [])
blocking_threshold = self.config.get("blocking_threshold")
blocking_conditions = self.config.get("blocking_conditions", [])
if self.status:
self.status.stop()

if not blocking_threshold and not blocking_conditions:
# Prompt the user for confirmation
if self.status:
self.status.stop()
if not Confirm.ask(
f"[yellow]Warning: No blocking keys or conditions specified. "
f"This may result in a large number of comparisons. "
Expand All @@ -212,9 +215,6 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
):
raise ValueError("Operation cancelled by user.")

if self.status:
self.status.start()

input_schema = self.config.get("input", {}).get("schema", {})
if not blocking_keys:
# Set them to all keys in the input data
Expand Down Expand Up @@ -413,7 +413,11 @@ def process_cluster(cluster):
"max_retries_per_timeout", 2
),
)
reduction_output = parse_llm_response(reduction_response)[0]
reduction_output = parse_llm_response(
reduction_response,
self.config["output"]["schema"],
manually_fix_errors=self.manually_fix_errors,
)[0]
reduction_cost = completion_cost(reduction_response)

if validate_output(self.config, reduction_output, self.console):
Expand Down Expand Up @@ -467,4 +471,7 @@ def process_cluster(cluster):
)
self.console.log(f"Self-join selectivity: {true_match_selectivity:.4f}")

if self.status:
self.status.start()

return results, total_cost
Loading
Loading