diff --git a/motion/operations/gather.py b/motion/operations/gather.py
index a6df9464..ad181e0a 100644
--- a/motion/operations/gather.py
+++ b/motion/operations/gather.py
@@ -84,7 +84,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
             "main_chunk_start", "--- Begin Main Chunk ---"
         )
         main_chunk_end = self.config.get("main_chunk_end", "--- End Main Chunk ---")
-        doc_header_keys = self.config.get("doc_header_keys", [])
+        doc_header_key = self.config.get("doc_header_key", None)
 
         results = []
         cost = 0.0
@@ -111,7 +111,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
                 order_key,
                 main_chunk_start,
                 main_chunk_end,
-                doc_header_keys,
+                doc_header_key,
             )
 
             result = chunk.copy()
@@ -129,7 +129,7 @@ def render_chunk_with_context(
         order_key: str,
         main_chunk_start: str,
         main_chunk_end: str,
-        doc_header_keys: List[Dict[str, Any]],
+        doc_header_key: str,
     ) -> str:
         """
         Render a chunk with its peripheral context and headers.
@@ -142,7 +142,7 @@ def render_chunk_with_context(
             order_key (str): Key for the order of each chunk.
             main_chunk_start (str): String to mark the start of the main chunk.
             main_chunk_end (str): String to mark the end of the main chunk.
-            doc_header_keys (List[Dict[str, Any]]): List of dicts containing 'header' and 'level' keys.
+            doc_header_key (str): The key for the headers in the current chunk.
 
         Returns:
             str: Rendered chunk with context and headers.
@@ -164,7 +164,7 @@ def render_chunk_with_context(
         # Process main chunk
         main_chunk = chunks[current_index]
         headers = self.render_hierarchy_headers(
-            main_chunk, chunks[: current_index + 1], doc_header_keys
+            main_chunk, chunks[: current_index + 1], doc_header_key
         )
         if headers:
             combined_parts.append(headers)
@@ -270,7 +270,7 @@ def render_hierarchy_headers(
         self,
         current_chunk: Dict,
         chunks: List[Dict],
-        doc_header_keys: List[Dict[str, Any]],
+        doc_header_key: str,
     ) -> str:
         """
         Render headers for the current chunk's hierarchy.
@@ -278,16 +278,18 @@
         Args:
             current_chunk (Dict): The current chunk being processed.
             chunks (List[Dict]): List of chunks up to and including the current chunk.
-            doc_header_keys (List[Dict[str, Any]]): List of dicts containing 'header' and 'level' keys.
-
+            doc_header_key (str): The key for the headers in the current chunk.
         Returns:
             str: Rendered headers in the current chunk's hierarchy.
         """
         rendered_headers = []
         current_hierarchy = {}
 
+        if doc_header_key is None:
+            return ""
+
         # Find the largest/highest level in the current chunk
-        current_chunk_headers = current_chunk.get(doc_header_keys, [])
+        current_chunk_headers = current_chunk.get(doc_header_key, [])
         highest_level = float("inf")  # Initialize with positive infinity
         for header_info in current_chunk_headers:
             level = header_info.get("level")
@@ -299,7 +301,7 @@
             highest_level = None
 
         for chunk in chunks:
-            for header_info in chunk.get(doc_header_keys, []):
+            for header_info in chunk.get(doc_header_key, []):
                 header = header_info["header"]
                 level = header_info["level"]
                 if header and level:
diff --git a/motion/optimizers/map_optimizer/operation_creators.py b/motion/optimizers/map_optimizer/operation_creators.py
index 3823947e..362b5005 100644
--- a/motion/optimizers/map_optimizer/operation_creators.py
+++ b/motion/optimizers/map_optimizer/operation_creators.py
@@ -120,7 +120,7 @@ def create_split_map_gather_operations(
             "content_key": content_key,
             "doc_id_key": f"{split_name}_id",
             "order_key": f"{split_name}_chunk_num",
-            "doc_header_keys": ("headers" if header_output_schema else []),
+            "doc_header_key": "headers" if header_output_schema else None,
             "peripheral_chunks": {},
         }
 
diff --git a/tests/test_synth_gather.py b/tests/test_synth_gather.py
index 9e34bce3..1fa350d1 100644
--- a/tests/test_synth_gather.py
+++ b/tests/test_synth_gather.py
@@ -119,7 +119,7 @@ def test_synth_gather(config_yaml):
             assert "doc_id_key" in synthesized_op
             assert "order_key" in synthesized_op
             assert "peripheral_chunks" in synthesized_op
-            assert "doc_header_keys" in synthesized_op
+            assert "doc_header_key" in synthesized_op
             break
 
     if synthesized_gather_found:
@@ -204,7 +204,7 @@ def test_split_map_gather(sample_data):
            "previous": {"tail": {"count": 1}},
            "next": {"head": {"count": 1}},
        },
-        "doc_header_keys": "headers",
+        "doc_header_key": "headers",
     }
 
     # Initialize operations
diff --git a/todos.md b/todos.md
index cddcba6f..939f1593 100644
--- a/todos.md
+++ b/todos.md
@@ -64,6 +64,8 @@ TODO:
 - [x] Encode this in API somehow
 - [ ] Support this kind of chunking in the optimizer
 - [x] Extract headers & levels from documents, and add the level hierarchy to the chunk.
+- [ ] Support tool use in operations
+- [ ] Fix DSL to handle inputs like we've done in the Overleaf writeup
 - [ ] Support prompts exceeding context windows; figure out how to throw out data / prioritize elements
 - [ ] Support retries in the optimizers
 - [ ] Write tests for optimizers