Fix tests
shreyashankar committed Aug 27, 2024
1 parent e5729aa commit bcd0242
Showing 4 changed files with 17 additions and 13 deletions.
motion/operations/gather.py (12 additions & 10 deletions)
@@ -84,7 +84,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
             "main_chunk_start", "--- Begin Main Chunk ---"
         )
         main_chunk_end = self.config.get("main_chunk_end", "--- End Main Chunk ---")
-        doc_header_keys = self.config.get("doc_header_keys", [])
+        doc_header_key = self.config.get("doc_header_key", None)
         results = []
         cost = 0.0
@@ -111,7 +111,7 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
                 order_key,
                 main_chunk_start,
                 main_chunk_end,
-                doc_header_keys,
+                doc_header_key,
             )

             result = chunk.copy()
@@ -129,7 +129,7 @@ def render_chunk_with_context(
         order_key: str,
         main_chunk_start: str,
         main_chunk_end: str,
-        doc_header_keys: List[Dict[str, Any]],
+        doc_header_key: str,
     ) -> str:
         """
         Render a chunk with its peripheral context and headers.
@@ -142,7 +142,7 @@ def render_chunk_with_context(
             order_key (str): Key for the order of each chunk.
             main_chunk_start (str): String to mark the start of the main chunk.
             main_chunk_end (str): String to mark the end of the main chunk.
-            doc_header_keys (List[Dict[str, Any]]): List of dicts containing 'header' and 'level' keys.
+            doc_header_key (str): The key for the headers in the current chunk.

         Returns:
             str: Renderted chunk with context and headers.
@@ -164,7 +164,7 @@ def render_chunk_with_context(
         # Process main chunk
         main_chunk = chunks[current_index]
         headers = self.render_hierarchy_headers(
-            main_chunk, chunks[: current_index + 1], doc_header_keys
+            main_chunk, chunks[: current_index + 1], doc_header_key
         )
         if headers:
             combined_parts.append(headers)
@@ -270,24 +270,26 @@ def render_hierarchy_headers(
         self,
         current_chunk: Dict,
         chunks: List[Dict],
-        doc_header_keys: List[Dict[str, Any]],
+        doc_header_key: str,
     ) -> str:
         """
         Render headers for the current chunk's hierarchy.
         Args:
             current_chunk (Dict): The current chunk being processed.
             chunks (List[Dict]): List of chunks up to and including the current chunk.
-            doc_header_keys (List[Dict[str, Any]]): List of dicts containing 'header' and 'level' keys.
+            doc_header_key (str): The key for the headers in the current chunk.

         Returns:
             str: Renderted headers in the current chunk's hierarchy.
         """
         rendered_headers = []
         current_hierarchy = {}

+        if doc_header_key is None:
+            return ""
+
         # Find the largest/highest level in the current chunk
-        current_chunk_headers = current_chunk.get(doc_header_keys, [])
+        current_chunk_headers = current_chunk.get(doc_header_key, [])
         highest_level = float("inf")  # Initialize with positive infinity
         for header_info in current_chunk_headers:
             level = header_info.get("level")
@@ -299,7 +301,7 @@ def render_hierarchy_headers(
             highest_level = None

         for chunk in chunks:
-            for header_info in chunk.get(doc_header_keys, []):
+            for header_info in chunk.get(doc_header_key, []):
                 header = header_info["header"]
                 level = header_info["level"]
                 if header and level:
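
For reference, here is a minimal standalone sketch (not the actual implementation) of how the renamed doc_header_key parameter appears to be used, based only on the hunks above: each chunk is expected to carry a list of {"header": ..., "level": ...} dicts under that key, and header rendering is skipped entirely when the key is None. The sample rows and the flattened rendering logic are illustrative assumptions.

```python
from typing import Dict, List, Optional


def render_headers_sketch(chunks: List[Dict], doc_header_key: Optional[str]) -> str:
    """Simplified stand-in for render_hierarchy_headers, inferred from the diff above."""
    if doc_header_key is None:
        return ""  # new early exit when no header key is configured

    lines: List[str] = []
    for chunk in chunks:
        # Each entry stored under doc_header_key is assumed to look like
        # {"header": "Some Title", "level": 2}.
        for header_info in chunk.get(doc_header_key, []):
            header, level = header_info["header"], header_info["level"]
            if header and level:
                lines.append("#" * level + " " + header)
    return "\n".join(lines)


# Hypothetical chunk rows shaped the way the gather operation seems to expect.
sample_chunks = [
    {"chunk_num": 1, "headers": [{"header": "Introduction", "level": 1}]},
    {"chunk_num": 2, "headers": [{"header": "Methods", "level": 2}]},
]
print(render_headers_sketch(sample_chunks, "headers"))
print(repr(render_headers_sketch(sample_chunks, None)))  # -> ''
```
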
motion/optimizers/map_optimizer/operation_creators.py (1 addition & 1 deletion)
@@ -120,7 +120,7 @@ def create_split_map_gather_operations(
             "content_key": content_key,
             "doc_id_key": f"{split_name}_id",
             "order_key": f"{split_name}_chunk_num",
-            "doc_header_keys": ("headers" if header_output_schema else []),
+            "doc_header_key": "headers" if header_output_schema else [],
             "peripheral_chunks": {},
         }

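
To make the renamed config key concrete, here is a hedged sketch of the gather config that create_split_map_gather_operations appears to build after this change. Only the keys visible in the hunk above come from the diff; split_name, content_key, header_output_schema, and the "type" field are invented for the example.

```python
# Hypothetical inputs; in the real optimizer these come from the synthesized
# split operation and its header-extraction map (if any).
split_name = "report_split"
content_key = f"{split_name}_chunk"
header_output_schema = {"headers": "list"}  # truthy => headers were extracted

gather_config = {
    "type": "gather",  # assumed; the hunk only shows the keys below
    "content_key": content_key,
    "doc_id_key": f"{split_name}_id",
    "order_key": f"{split_name}_chunk_num",
    # Renamed from "doc_header_keys": now the name of a single per-chunk key
    # ("headers"), or [] when no header-extraction map was synthesized.
    "doc_header_key": "headers" if header_output_schema else [],
    "peripheral_chunks": {},
}
print(gather_config["doc_header_key"])  # -> headers
```
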
tests/test_synth_gather.py (2 additions & 2 deletions)
@@ -119,7 +119,7 @@ def test_synth_gather(config_yaml):
             assert "doc_id_key" in synthesized_op
             assert "order_key" in synthesized_op
             assert "peripheral_chunks" in synthesized_op
-            assert "doc_header_keys" in synthesized_op
+            assert "doc_header_key" in synthesized_op

             break
     if synthesized_gather_found:
@@ -204,7 +204,7 @@ def test_split_map_gather(sample_data):
             "previous": {"tail": {"count": 1}},
             "next": {"head": {"count": 1}},
         },
-        "doc_header_keys": "headers",
+        "doc_header_key": "headers",
     }

     # Initialize operations
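
As a rough illustration of what the test's gather config asks for (one previous chunk of tail context, one next chunk of head context, and headers read from the "headers" key), here is a sketch of the kind of rows such an operation might consume. Every field name other than "headers" is invented for the example.

```python
# Invented rows for one document split into three ordered chunks.
sample_rows = [
    {"doc_id": "doc-1", "chunk_num": 1, "chunk_content": "Intro text...",
     "headers": [{"header": "Introduction", "level": 1}]},
    {"doc_id": "doc-1", "chunk_num": 2, "chunk_content": "Middle text...",
     "headers": []},
    {"doc_id": "doc-1", "chunk_num": 3, "chunk_content": "Closing text...",
     "headers": [{"header": "Conclusion", "level": 1}]},
]

# With peripheral_chunks = {"previous": {"tail": {"count": 1}},
#                           "next": {"head": {"count": 1}}},
# the rendered text for chunk 2 should place chunk 1 before the
# "--- Begin Main Chunk ---" marker and chunk 3 after "--- End Main Chunk ---",
# with applicable headers (here, "Introduction") rendered above the main chunk.
```
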
todos.md (2 additions & 0 deletions)
@@ -64,6 +64,8 @@ TODO:
 - [x] Encode this in API somehow
 - [ ] Support this kind of chunking in the optimizer
 - [x] Extract headers & levels from documents, and add the level hierarchy to the chunk.
+- [ ] Support tool use in operations
+- [ ] Fix DSL to handle inputs like we've done in the Overleaf writeup
 - [ ] Support prompts exceeding context windows; figure out how to throw out data / prioritize elements
 - [ ] Support retries in the optimizers
 - [ ] Write tests for optimizers
