Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 0 additions & 123 deletions docs/pre_commit_migration_plan.md

This file was deleted.

31 changes: 24 additions & 7 deletions test_utils/pre_commit/sql_rewriter_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ def rewrite_sql(self, call_id: str, dry_run: bool = True, backup: bool = True):

call = matching_calls[0]

# Check if this call is flagged for manual review
if call.query_builder_equivalent and "# MANUAL:" in call.query_builder_equivalent:
# Check if this call is flagged for manual review or has a conversion error
if call.query_builder_equivalent and (
"# MANUAL:" in call.query_builder_equivalent
or "# Error" in call.query_builder_equivalent
):
print(
f"\n{Colors.YELLOW}Skipping {call.call_id[:12]} - flagged for manual review:{Colors.RESET}"
)
Expand Down Expand Up @@ -199,9 +202,11 @@ def rewrite_sql(self, call_id: str, dry_run: bool = True, backup: bool = True):
# Write the modified content
file_path.write_text(formatted_content, encoding="utf-8")

# Update registry
# Update registry and re-scan the modified file so line numbers stay current
# for any subsequent rewrites on the same file in this session
call.implementation_type = "query_builder"
call.notes = f"Converted by sql_rewriter on {call.updated_at}"
self.registry.scan_file(file_path)
self.registry.save_registry()

print(f"{Colors.GREEN}Successfully converted SQL to Query Builder{Colors.RESET}")
Expand Down Expand Up @@ -235,8 +240,11 @@ def rewrite_batch(
]
if matching:
call = matching[0]
# Skip calls flagged for manual review
if call.query_builder_equivalent and "# MANUAL:" in call.query_builder_equivalent:
# Skip calls flagged for manual review or with conversion errors
if call.query_builder_equivalent and (
"# MANUAL:" in call.query_builder_equivalent
or "# Error" in call.query_builder_equivalent
):
skipped_manual.append(call)
continue
calls_to_rewrite.append(call)
Expand Down Expand Up @@ -395,8 +403,17 @@ def replace_sql_in_content(self, content: str, call) -> str:
# Empty lines
indented_replacement.append("")

# Replace the lines
new_lines = lines[:start_line] + indented_replacement + lines[end_line + 1 :]
# If this call was the iterable in a for loop, reconstruct the for statement
if call.variable_name and call.variable_name.startswith("__for_"):
loop_var = call.variable_name[
len("__for_") : -2
] # strip __for_ prefix and __ suffix
for_line = indent_str + f"for {loop_var} in result:"
new_lines = (
lines[:start_line] + indented_replacement + [for_line] + lines[end_line + 1 :]
)
else:
new_lines = lines[:start_line] + indented_replacement + lines[end_line + 1 :]

return "\n".join(new_lines)

Expand Down
2 changes: 2 additions & 0 deletions test_utils/utils/sql_registry/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ def convert_select_to_qb(
result_prefix = "yield "
elif variable_name == "__expr__":
result_prefix = ""
elif variable_name and variable_name.startswith("__for_"):
result_prefix = "result = "
elif variable_name:
result_prefix = f"{variable_name} = "
else:
Expand Down
40 changes: 23 additions & 17 deletions test_utils/utils/sql_registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,10 @@ def generate_report(self) -> str:
for call in calls.values():
by_type[call.implementation_type] = by_type.get(call.implementation_type, 0) + 1
if call.query_builder_equivalent:
if "# MANUAL:" in call.query_builder_equivalent:
if (
"# MANUAL:" in call.query_builder_equivalent
or "# Error" in call.query_builder_equivalent
):
manual_count += 1
elif "frappe.get_all(" in call.query_builder_equivalent:
orm_count += 1
Expand Down Expand Up @@ -451,22 +454,25 @@ def generate_report(self) -> str:
report += "| Call ID | Status | Line | Function | SQL Preview |\n"
report += "|---------|--------|------|----------|-------------|\n"

for call in sorted_calls:
status = "✅"
if call.query_builder_equivalent:
if "# MANUAL:" in call.query_builder_equivalent:
status = "🔧"
elif "frappe.get_all(" in call.query_builder_equivalent:
status = "💡"
elif "# TODO" in call.query_builder_equivalent:
status = "⚠️"

sql_preview = call.sql_query.replace("\n", " ").strip()[:50]
if len(call.sql_query) > 50:
sql_preview += "..."
sql_preview = sql_preview.replace("|", "\\|")
func_name = call.function_context[:25] if call.function_context else ""
report += f"| `{call.call_id[:8]}` | {status} | {call.line_number} | {func_name} | {sql_preview} |\n"
for call in sorted_calls:
status = "✅"
if call.query_builder_equivalent:
if (
"# MANUAL:" in call.query_builder_equivalent
or "# Error" in call.query_builder_equivalent
):
status = "🔧"
elif "frappe.get_all(" in call.query_builder_equivalent:
status = "💡"
elif "# TODO" in call.query_builder_equivalent:
status = "⚠️"

sql_preview = call.sql_query.replace("\n", " ").strip()[:50]
if len(call.sql_query) > 50:
sql_preview += "..."
sql_preview = sql_preview.replace("|", "\\|")
func_name = call.function_context[:25] if call.function_context else ""
report += f"| `{call.call_id[:8]}` | {status} | {call.line_number} | {func_name} | {sql_preview} |\n"

report += "\n"

Expand Down
9 changes: 9 additions & 0 deletions test_utils/utils/sql_registry/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,13 @@ def extract_variable_name(tree: ast.AST, call_node: ast.Call) -> str | None:
if node.value and node_contains(node.value, call_node):
return "__yield__"

for node in ast.walk(tree):
if isinstance(node, ast.For):
if node_contains(node.iter, call_node):
if isinstance(node.target, ast.Name):
return f"__for_{node.target.id}__"
return "__for_iter__"

for node in ast.walk(tree):
if isinstance(node, ast.Expr):
if node_contains(node.value, call_node):
Expand Down Expand Up @@ -452,6 +459,8 @@ def convert_select_to_orm(
result_prefix = "yield "
elif variable_name == "__expr__":
result_prefix = ""
elif variable_name and variable_name.startswith("__for_"):
result_prefix = "result = "
elif variable_name:
result_prefix = f"{variable_name} = "
else:
Expand Down