[Feature][Executor] Add resume to batch engine #2003

Merged · 17 commits · Feb 23, 2024
6 changes: 6 additions & 0 deletions src/promptflow/promptflow/_core/_errors.py
@@ -155,3 +155,9 @@ class DuplicateToolMappingError(ValidationException):
"""Exception raised when multiple tools are linked to the same deprecated tool id."""

pass


class ResumeCopyError(SystemErrorException):
"""Exception raised when failed to copy the results when resuming the run."""

pass
15 changes: 14 additions & 1 deletion src/promptflow/promptflow/_utils/execution_utils.py
@@ -4,8 +4,9 @@
from typing import AbstractSet, Any, Dict, List, Mapping

from promptflow._utils.logger_utils import logger
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputValueType
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputAssignment, InputValueType
from promptflow.contracts.run_info import FlowRunInfo, Status
from promptflow.executor import _input_assignment_parser


def apply_default_value_for_input(inputs: Dict[str, FlowInputDefinition], line_inputs: Mapping) -> Dict[str, Any]:
@@ -56,3 +57,15 @@ def get_aggregation_inputs_properties(flow: Flow) -> AbstractSet[str]:
def collect_lines(indexes: List[int], kvs: Mapping[str, List]) -> Mapping[str, List]:
"""Collect the values from the kvs according to the indexes."""
return {k: [v[i] for i in indexes] for k, v in kvs.items()}


def extract_aggregation_inputs(flow: Flow, nodes_outputs: dict) -> Dict[str, Any]:
"""Extract the aggregation inputs of a flow from the nodes outputs."""
_aggregation_inputs_references = get_aggregation_inputs_properties(flow)
return {prop: _parse_aggregation_input(nodes_outputs, prop) for prop in _aggregation_inputs_references}


def _parse_aggregation_input(nodes_outputs: dict, aggregation_input_property: str):
"""Parse the value of the aggregation input from the nodes outputs."""
assign = InputAssignment.deserialize(aggregation_input_property)
return _input_assignment_parser.parse_value(assign, nodes_outputs, {})
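
For context on the two helpers above: each aggregation-input reference is a serialized input assignment such as `${node_name.output}`, and `extract_aggregation_inputs` resolves every such reference against the per-line node outputs. Below is a minimal, self-contained sketch of that resolution idea; it deliberately avoids importing promptflow and only mimics the common `${node.output}` case, so the function name and the simplified parsing are illustrative rather than the library's actual `InputAssignment` logic.

```python
# Illustrative stand-in for _parse_aggregation_input: resolve a "${node.output}"
# style reference against a dict of node outputs.
def resolve_reference(reference: str, nodes_outputs: dict):
    # "${summarize.output}" -> "summarize.output"
    inner = reference[2:-1] if reference.startswith("${") and reference.endswith("}") else reference
    node_name, _, _section = inner.partition(".")
    return nodes_outputs[node_name]


nodes_outputs = {"summarize": "final summary text"}
print(resolve_reference("${summarize.output}", nodes_outputs))  # -> "final summary text"
```
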
37 changes: 37 additions & 0 deletions src/promptflow/promptflow/_utils/utils.py
@@ -13,6 +13,7 @@
import logging
import os
import re
import shutil
import time
import traceback
from datetime import datetime
@@ -70,6 +71,14 @@ def dump_list_to_jsonl(file_path: Union[str, Path], list_data: List[Dict]):
jsonl_file.write("\n")


def load_list_from_jsonl(file: Union[str, Path]):
content = []
with open(file, "r", encoding=DEFAULT_ENCODING) as fin:
for line in fin:
content.append(json.loads(line))
return content


def transpose(values: List[Dict[str, Any]], keys: Optional[List] = None) -> Dict[str, List]:
keys = keys or list(values[0].keys())
return {key: [v.get(key) for v in values] for key in keys}
@@ -329,3 +338,31 @@ def _match_reference(env_val: str):
return None, None
name, key = m.groups()
return name, key


def copy_file_except(src_dir, dst_dir, exclude_file):
"""
Copy all files from src_dir to dst_dir recursively, excluding a specific file
directly under the root of src_dir.

:param src_dir: Source directory path
:type src_dir: str
:param dst_dir: Destination directory path
:type dst_dir: str
:param exclude_file: Name of the file to exclude from copying
:type exclude_file: str
"""
os.makedirs(dst_dir, exist_ok=True)

for root, dirs, files in os.walk(src_dir):
rel_path = os.path.relpath(root, src_dir)
current_dst_dir = os.path.join(dst_dir, rel_path)

os.makedirs(current_dst_dir, exist_ok=True)

for file in files:
if rel_path == "." and file == exclude_file:
continue # Skip the excluded file
src_file_path = os.path.join(root, file)
dst_file_path = os.path.join(current_dst_dir, file)
shutil.copy2(src_file_path, dst_file_path)
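
A small usage sketch of the two new utilities. The directory names are hypothetical, it assumes the promptflow package from this branch is importable, and `dump_list_to_jsonl` (which already existed) is used only to set up the round trip.

```python
from pathlib import Path

from promptflow._utils.utils import copy_file_except, dump_list_to_jsonl, load_list_from_jsonl

# Hypothetical paths; in the batch engine these correspond to the previous run's
# output directory and the new run's output directory.
previous_output_dir = Path("previous_run_outputs")
new_output_dir = Path("new_run_outputs")
previous_output_dir.mkdir(parents=True, exist_ok=True)

# Round-trip a JSONL file with the new loader.
lines = [{"line_number": 0, "answer": "foo"}, {"line_number": 1, "answer": "bar"}]
dump_list_to_jsonl(previous_output_dir / "output.jsonl", lines)
assert load_list_from_jsonl(previous_output_dir / "output.jsonl") == lines

# Copy everything except the root-level output.jsonl (e.g. images stored next to it).
copy_file_except(str(previous_output_dir), str(new_output_dir), "output.jsonl")
```
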
92 changes: 86 additions & 6 deletions src/promptflow/promptflow/batch/_batch_engine.py
@@ -1,7 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import asyncio
import signal
import threading
@@ -11,20 +10,23 @@
from typing import Any, Dict, List, Mapping, Optional

from promptflow._constants import LANGUAGE_KEY, LINE_NUMBER_KEY, LINE_TIMEOUT_SEC, FlowLanguage
from promptflow._core._errors import UnexpectedError
from promptflow._core._errors import ResumeCopyError, UnexpectedError
from promptflow._core.operation_context import OperationContext
from promptflow._utils.async_utils import async_run_allowing_running_loop
from promptflow._utils.context_utils import _change_working_dir
from promptflow._utils.execution_utils import (
apply_default_value_for_input,
collect_lines,
extract_aggregation_inputs,
get_aggregation_inputs_properties,
handle_line_failures,
)
from promptflow._utils.logger_utils import bulk_logger
from promptflow._utils.utils import (
copy_file_except,
dump_list_to_jsonl,
get_int_env_var,
load_list_from_jsonl,
log_progress,
resolve_dir_to_absolute,
transpose,
@@ -42,7 +44,7 @@
from promptflow.executor._line_execution_process_pool import signal_handler
from promptflow.executor._result import AggregationResult, LineResult
from promptflow.executor.flow_validator import FlowValidator
from promptflow.storage._run_storage import AbstractBatchRunStorage, AbstractRunStorage
from promptflow.storage import AbstractBatchRunStorage, AbstractRunStorage

OUTPUT_FILE_NAME = "output.jsonl"
DEFAULT_CONCURRENCY = 10
@@ -192,9 +194,21 @@ def run(
batch_inputs = batch_input_processor.process_batch_inputs(input_dirs, inputs_mapping)
# resolve output dir
output_dir = resolve_dir_to_absolute(self._working_dir, output_dir)

previous_run_results = None
if resume_from_run_storage and resume_from_run_output_dir:
previous_run_results = self._copy_previous_run_result(
resume_from_run_storage, resume_from_run_output_dir, batch_inputs, output_dir
)

# run flow in batch mode
return async_run_allowing_running_loop(
self._exec_in_task, batch_inputs, run_id, output_dir, raise_on_line_failure
self._exec_in_task,
batch_inputs,
run_id,
output_dir,
raise_on_line_failure,
previous_run_results,
)
finally:
async_run_allowing_running_loop(self._executor_proxy.destroy)
@@ -213,6 +227,66 @@ def run(
)
raise unexpected_error from e

def _copy_previous_run_result(
self,
resume_from_run_storage: AbstractBatchRunStorage,
resume_from_run_output_dir: Path,
batch_inputs: List,
output_dir: Path,
) -> List[LineResult]:
"""Duplicate the previous debug_info from resume_from_run_storage and output from resume_from_run_output_dir
to the storage of new run,
return the list of previous line results for the usage of aggregation and summarization.
"""
# Load the previous flow run output from output.jsonl
previous_run_output = load_list_from_jsonl(resume_from_run_output_dir / "output.jsonl")
previous_run_output_dict = {
each_line_output[LINE_NUMBER_KEY]: each_line_output for each_line_output in previous_run_output
}

# Copy other files from resume_from_run_output_dir to output_dir in case there are images
copy_file_except(resume_from_run_output_dir, output_dir, "output.jsonl")

try:
previous_run_results = []
for i in range(len(batch_inputs)):
previous_run_info = resume_from_run_storage.load_flow_run_info(i)

if previous_run_info and previous_run_info.status == Status.Completed:
# Load previous node run info
previous_node_run_infos = resume_from_run_storage.load_node_run_info_for_line(i)
previous_node_run_infos_dict = {node_run.node: node_run for node_run in previous_node_run_infos}
previous_node_run_outputs = {
node_info.node: node_info.output for node_info in previous_node_run_infos
}

# Extract aggregation inputs for flow with aggregation node
aggregation_inputs = extract_aggregation_inputs(self._flow, previous_node_run_outputs)

# Persist previous run info and node run info
self._storage.persist_flow_run(previous_run_info)
for node_run_info in previous_node_run_infos:
self._storage.persist_node_run(node_run_info)

# Create LineResult object for previous line result
previous_line_result = LineResult(
output=previous_run_output_dict[i],
aggregation_inputs=aggregation_inputs,
run_info=previous_run_info,
node_run_infos=previous_node_run_infos_dict,
)
previous_run_results.append(previous_line_result)

return previous_run_results
except Exception as e:
bulk_logger.error(f"Error occurred while copying previous run result. Exception: {str(e)}")
resume_copy_error = ResumeCopyError(
target=ErrorTarget.BATCH,
message_format="Failed to copy results when resuming the run. Error: {error_type_and_message}.",
error_type_and_message=f"({e.__class__.__name__}) {e}",
)
raise resume_copy_error from e

def cancel(self):
"""Cancel the batch run"""
self._is_canceled = True
@@ -223,11 +297,13 @@ async def _exec_in_task(
run_id: str = None,
output_dir: Path = None,
raise_on_line_failure: bool = False,
previous_line_results: List[LineResult] = None,
) -> BatchResult:
# if the batch run is canceled, asyncio.CancelledError will be raised and no results will be returned,
# so we pass empty line results list and aggr results and update them in _exec so that when the batch
# run is canceled we can get the current completed line results and aggr results.
line_results: List[LineResult] = []
line_results.extend(previous_line_results or [])
aggr_result = AggregationResult({}, {}, {})
task = asyncio.create_task(
self._exec(line_results, aggr_result, batch_inputs, run_id, output_dir, raise_on_line_failure)
@@ -260,13 +336,18 @@ async def _exec(
batch_inputs = [
apply_default_value_for_input(self._flow.inputs, each_line_input) for each_line_input in batch_inputs
]

existing_results_line_numbers = set([r.run_info.index for r in line_results])
bulk_logger.info(f"Skipped the execution of {len(existing_results_line_numbers)} existing results.")
inputs_to_run = [input for input in batch_inputs if input[LINE_NUMBER_KEY] not in existing_results_line_numbers]

run_id = run_id or str(uuid.uuid4())

# execute lines
is_timeout = False
if isinstance(self._executor_proxy, PythonExecutorProxy):
results, is_timeout = self._executor_proxy._exec_batch(
batch_inputs,
inputs_to_run,
output_dir,
run_id,
batch_timeout_sec=self._batch_timeout_sec,
Expand All @@ -278,7 +359,6 @@ async def _exec(
# TODO: Enable batch timeout for other api based executor proxy
await self._exec_batch(line_results, batch_inputs, run_id)
handle_line_failures([r.run_info for r in line_results], raise_on_line_failure)

# persist outputs to output dir
outputs = [
{LINE_NUMBER_KEY: r.run_info.index, **r.output}
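
To summarize the resume path added in this file: completed lines from the previous run are rebuilt as `LineResult` objects, passed into `_exec_in_task` as `previous_line_results`, and `_exec` then excludes their line numbers from the inputs it actually executes. A minimal, plain-Python sketch of that filtering step follows; it does not use the promptflow types, and `"line_number"` stands in for `LINE_NUMBER_KEY`.

```python
LINE_NUMBER_KEY = "line_number"  # stand-in for promptflow's constant


def filter_inputs_to_run(batch_inputs, completed_line_numbers):
    """Keep only the inputs whose line has no completed result from the previous run."""
    return [inp for inp in batch_inputs if inp[LINE_NUMBER_KEY] not in completed_line_numbers]


batch_inputs = [{LINE_NUMBER_KEY: i, "question": f"q{i}"} for i in range(4)]
completed = {0, 2}  # lines that already completed in the resumed-from run
print(filter_inputs_to_run(batch_inputs, completed))  # lines 1 and 3 remain
```
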
12 changes: 2 additions & 10 deletions src/promptflow/promptflow/executor/flow_executor.py
@@ -36,6 +36,7 @@
from promptflow._utils.execution_utils import (
apply_default_value_for_input,
collect_lines,
extract_aggregation_inputs,
get_aggregation_inputs_properties,
)
from promptflow._utils.logger_utils import flow_logger, logger
@@ -660,15 +661,6 @@ def _exec_in_thread(self, args) -> LineResult:
self._completed_idx[line_number] = thread_name
return results

def _extract_aggregation_inputs(self, nodes_outputs: dict):
return {
prop: self._extract_aggregation_input(nodes_outputs, prop) for prop in self._aggregation_inputs_references
}

def _extract_aggregation_input(self, nodes_outputs: dict, aggregation_input_property: str):
assign = InputAssignment.deserialize(aggregation_input_property)
return _input_assignment_parser.parse_value(assign, nodes_outputs, {})

def exec_line(
self,
inputs: Mapping[str, Any],
@@ -833,7 +825,7 @@ def _exec_inner(
run_tracker.persist_selected_node_runs(run_info, generator_output_nodes)
run_tracker.allow_generator_types = allow_generator_output
run_tracker.end_run(run_info.run_id, result=output)
aggregation_inputs = self._extract_aggregation_inputs(nodes_outputs)
aggregation_inputs = extract_aggregation_inputs(self._flow, nodes_outputs)
return output, aggregation_inputs

def _exec(