Skip to content

Commit ae89efb

Browse files
authored
Merge pull request #66 from trishullab/bug/complicated-have
Major release with final round of fixes.
2 parents 150f8b2 + aab3ba1 commit ae89efb

File tree

5 files changed

+115
-30
lines changed

5 files changed

+115
-30
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = [
55
build-backend = "hatchling.build"
66
[project]
77
name = "itp_interface"
8-
version = "1.1.19"
8+
version = "1.2.0"
99
authors = [
1010
{ name="Amitayush Thakur", email="amitayush@utexas.edu" },
1111
]

src/itp_interface/rl/simple_proof_env.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,12 @@ def render(self):
286286
self.logger.info("-"*50)
287287
pass
288288

289-
def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[str, typing.Any] = None):
289+
def collect_proof_search_result(self, additional_info: typing.Dict[str, typing.Any] = None) -> ProofSearchResult:
290290
assert self._loaded, "Env not loaded, call reset() first"
291+
if not hasattr(self, 'proof_search_res'):
292+
self.proof_search_res = None
293+
if self.proof_search_res is not None:
294+
return self.proof_search_res
291295
self.goal_end_time = time.time()
292296
self.time_taken = self.goal_end_time - self.goal_start_time
293297
proof_steps = [TheoremProvingTrainingDataFormat(proof_steps=tactic.proof_steps) for _, tactic in self._p_tree.tactics]
@@ -306,6 +310,27 @@ def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[st
306310
longest_success_path=-1,
307311
additional_info=additional_info,
308312
language=self.language)
313+
return self.proof_search_res
314+
315+
def max_proof_step_length(self) -> int|None:
316+
assert self._loaded, "Env not loaded, call reset() first"
317+
# The tactic-length threshold only exists for Lean 4; other languages get None
318+
if self.language != ProofAction.Language.LEAN4:
319+
return None
320+
assert isinstance(self._dynamic_proof_executor, DynamicLean4ProofExecutor), "Dynamic proof executor must be of type DynamicLean4ProofExecutor"
321+
return self._dynamic_proof_executor.max_threshold_for_tactic_length
322+
323+
def set_max_proof_step_length(self, max_length: int):
324+
assert self._loaded, "Env not loaded, call reset() first"
325+
# The tactic-length threshold can only be changed for Lean 4
326+
if self.language != ProofAction.Language.LEAN4:
327+
raise NotImplementedError("set_max_proof_step_length is only implemented for Lean 4")
328+
assert isinstance(self._dynamic_proof_executor, DynamicLean4ProofExecutor), "Dynamic proof executor must be of type DynamicLean4ProofExecutor"
329+
self._dynamic_proof_executor.max_threshold_for_tactic_length = max_length
330+
331+
def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[str, typing.Any] = None):
332+
assert self._loaded, "Env not loaded, call reset() first"
333+
self.proof_search_res = self.collect_proof_search_result(additional_info=additional_info)
309334
self.logger.info(f"Dumping proof search result:\n {self.proof_search_res}")
310335
if dump_file_name is not None:
311336
opening_mode = 'a' if os.path.exists(dump_file_name) else 'w'

src/itp_interface/tools/simple_lean4_sync_executor.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(self,
115115
self._last_tactic_was_modified = False
116116
self._last_modified_tactic : str | None = None
117117
self._recursion_depth = 0
118+
self.max_threshold_for_tactic_length = 575 # Max 575 characters for a tactic
118119
if self._enable_search:
119120
pass
120121
pass
@@ -166,6 +167,7 @@ def reset(self,
166167
self._last_tactic_was_modified = False
167168
self._last_modified_tactic : str | None = None
168169
self._recursion_depth = 0
170+
self.max_threshold_for_tactic_length = 575 # Max 575 characters for a tactic
169171
if self._enable_search:
170172
pass
171173
pass
@@ -285,7 +287,7 @@ def _get_lean_code_with_tactics(self, idx: int, stmt: str):
285287
tactics_so_far = self._get_tactics_so_far()
286288
assert len(tactics_so_far) > 0, "There should be at least one tactic so far"
287289
_ , _, theorem_stmt = self._last_theorem
288-
return theorem_stmt + tactics_so_far
290+
return theorem_stmt + "\n" + tactics_so_far + "\n"
289291

290292
def _backtrack_tactic_line(self, idx: int):
291293
# identify the keys to remove
@@ -354,7 +356,8 @@ def _reset_proof_context(self):
354356
def _set_proof_context(self,
355357
proof_is_running: bool,
356358
proof_goal_messages: List[str],
357-
last_tactic: LeanLineInfo):
359+
last_tactic: LeanLineInfo,
360+
errors: List[ErrorInfo]):
358361
self._proof_running = proof_is_running
359362
if self._proof_running:
360363
proof_goals = []
@@ -363,13 +366,23 @@ def _set_proof_context(self,
363366
else:
364367
proof_goals = [g_text for g_text in proof_goal_messages
365368
if g_text is not None and len(g_text) > 0]
366-
self.proof_context = self._parse_proof_context(proof_goals)
367-
if self.proof_context == ProofContext.empty() and \
368-
((self._enforce_qed and last_tactic.text.strip() == "done") or not self._enforce_qed):
369-
self._reset_proof_context()
369+
if len(proof_goals) == 0 and len(errors) > 0:
370+
# No goals came back although errors were reported. Treat this as a likely
371+
# alignment or indentation problem in the submitted tactic:
372+
# ask for the indentation to be fixed, or a line break added, so the state can be read.
373+
self.lean_error_messages = [
374+
"The tactic seems to be correct but seems to have indentation issues."
375+
" Please check the proof steps and try adding appropriate indentation or add line breaks to fix the issue."
376+
]
377+
else:
378+
self.proof_context = self._parse_proof_context(proof_goals)
379+
if self.proof_context == ProofContext.empty() and \
380+
((self._enforce_qed and last_tactic.text.strip() == "done") or not self._enforce_qed):
381+
self._reset_proof_context()
382+
self.lean_error_messages.clear()
370383
else:
371384
self.proof_context : ProofContext | None = None
372-
self.lean_error_messages.clear()
385+
self.lean_error_messages.clear()
373386

374387
def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[ErrorInfo]) -> int:
375388
# See all goal related error messages
@@ -383,7 +396,7 @@ def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[Erro
383396
if tactic.text.strip().startswith("have"):
384397
# Check if there is any goal related error after this tactic
385398
for error in goal_related:
386-
if error.position.line == tactic.end_line:
399+
if error.position.line == tactic.end_line or error.position.line - 1 == tactic.end_line:
387400
nested_have_count += 1
388401
return nested_have_count
389402

@@ -461,7 +474,7 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors:
461474
if have_error_message is None:
462475
self._nested_have_counts = self._get_nested_haves_count(tactics, errors)
463476
self._nested_calc_counts = self._get_nested_calc_count(tactics, errors)
464-
self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic)
477+
self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic, errors)
465478
else:
466479
self._backtrack_tactic_line(idx)
467480
self.lean_error_messages = [have_error_message]
@@ -504,13 +517,19 @@ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors:
504517
def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = False):
505518
assert self.tactic_parser is not None, "Tactic parser is not initialized"
506519
assert self._content_till_last_theorem_stmt is not None, "Content till last theorem statement should not be None"
507-
if len(stmt) > SimpleLean4SyncExecutor.max_threshold_for_tactic_length:
520+
if len(stmt) > self.max_threshold_for_tactic_length:
508521
self.lean_error_messages = [
509522
"The tactic length exceeds the maximum threshold of"
510-
f" {SimpleLean4SyncExecutor.max_threshold_for_tactic_length} characters."
523+
f" {self.max_threshold_for_tactic_length} characters."
511524
" Please break down the tactic into smaller steps. And execute them one by one."
512525
]
513526
return
527+
if '✝' in stmt:
528+
self.lean_error_messages = [
529+
"The tactic tries to use hypothesis ending with '✝', which are hidden."
530+
" Please use the `rename_i` tactic to rename such hypotheses, before using them."
531+
]
532+
return
514533
if ("sorry" in stmt or "admit" in stmt) and self._proof_running:
515534
# We don't need to run the sorry statements. This should be treated as a failed proof step
516535
self.lean_error_messages = ["The tactic 'sorry/admit' was found in the statement, this is not allowed"]
@@ -535,7 +554,7 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool =
535554
proof_should_run = False
536555
if theorem_started:
537556
# Load the theorem context at once
538-
self.tactic_parser.parse(
557+
full_parse_tacitcs, errors = self.tactic_parser.parse(
539558
self._content_till_last_theorem_stmt,
540559
fail_on_error=True,
541560
parse_type=RequestType.CHKPT_TACTICS
@@ -549,7 +568,6 @@ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool =
549568
while not code_was_executed:
550569
# Run the statement in tactic mode
551570
code = self._get_lean_code_with_tactics(idx, stmt)
552-
self.logger.info(f"Running tactic on lean server at line {self.line_num}:\n{code}")
553571
tactics, error_info = self.tactic_parser.parse(
554572
code,
555573
fail_on_error=False,

src/itp_interface/tools/tactic_parser.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,26 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
747747
print_tactics(tactics)
748748
if errors:
749749
print(f"Error: {errors}")
750+
750751
p_path = "/home/amthakur/Projects/copra/data/test/miniF2F-lean4"
752+
with TacticParser(project_path=p_path) as parser:
753+
# Example 1a: Simple proof with multiple tactics
754+
lean_code = """import MiniF2F.Minif2fImport
755+
open BigOperators Real Nat Topology
756+
757+
theorem amc12_2000_p1
758+
(i m o : ℕ)
759+
(h₀ : i ≠ m ∧ m ≠ o ∧ o ≠ i)
760+
(h₁ : i*m*o = 2001) :
761+
i+m+o ≤ 671 :=by
762+
have hprimes : i ∈ {3, 23, 29} ∧ m ∈ {3, 23, 29} ∧ o ∈ {3, 23, 29} := by
763+
"""
764+
print("Parsing example 1a...")
765+
tactics, errors = parser.parse(lean_code, fail_on_error=False)
766+
print_tactics(tactics)
767+
if errors:
768+
print(f"Error: {errors}")
769+
751770
with TacticParser(project_path=p_path) as parser:
752771
# Example 1a: Simple proof with multiple tactics
753772
lean_code = """
@@ -762,6 +781,27 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
762781
z / x = 7 / 25 :=
763782
by
764783
have h1': x = 5 * y / 2 := by ring
784+
"""
785+
print("Parsing example 1a...")
786+
tactics, errors = parser.parse(lean_code, fail_on_error=False)
787+
print_tactics(tactics)
788+
if errors:
789+
print(f"Error: {errors}")
790+
791+
with TacticParser(project_path=p_path) as parser:
792+
# Example 1a: Simple proof with multiple tactics
793+
lean_code = """import MiniF2F.Minif2fImport
794+
open BigOperators Real Nat Topology
795+
796+
theorem mathd_numbertheory_495
797+
(a b : ℕ)
798+
(h₀ : 0 < a ∧ 0 < b)
799+
(h₁ : a % 10 = 2)
800+
(h₂ : b % 10 = 4)
801+
(h₃ : Nat.gcd a b = 6) :
802+
108 ≤ Nat.lcm a b :=
803+
by
804+
apply?
765805
"""
766806
print("Parsing example 1a...")
767807
tactics, errors = parser.parse(lean_code, fail_on_error=False)
@@ -793,7 +833,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
793833
lean_code2 = "example (r: Nat) (p q : Prop) (hp : p) (hq : q) : p ∧ q := by\n apply And.intro\n exact hp\n exact hq"
794834

795835
print("\nParsing example 2...")
796-
tactics2, errors = parser.parse(lean_code2)
836+
tactics2, errors = parser.parse(lean_code2, fail_on_error=False)
797837
print_tactics(tactics2)
798838
if errors:
799839
print(f"Error: {errors}")

src/test/simple_env_test.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import unittest
22
import os
33
from itp_interface.tools.tactic_parser import build_lean4_project, build_tactic_parser_if_needed
4-
from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
54

65
def pretty_print(s1, s2, proof_step, done):
76
print(f"Current Goal:")
@@ -287,6 +286,7 @@ def test_simple_lean_calc(self):
287286
"_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]"
288287
]
289288
with env:
289+
env.set_max_proof_step_length(10000)
290290
proof_was_finished = False
291291
for proof_step in proof_steps:
292292
state, _, next_state, _, done, info = env.step(ProofAction(
@@ -347,6 +347,7 @@ def test_simple_lean_calc_with_validation(self):
347347
"_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]"
348348
]
349349
with env:
350+
env.set_max_proof_step_length(10000)
350351
proof_was_finished = False
351352
for proof_step in proof_steps:
352353
state, _, next_state, _, done, info = env.step(ProofAction(
@@ -415,6 +416,7 @@ def test_simple_lean_enforce_done_test(self):
415416
"done"
416417
]
417418
with env:
419+
env.set_max_proof_step_length(10000)
418420
proof_finished = False
419421
for proof_step in proof_steps:
420422
state, _, next_state, _, done, info = env.step(ProofAction(
@@ -523,18 +525,19 @@ def test_simple_lean4_have_test(self):
523525
'rw [Nat.gcd_rec]',
524526
'rw [Nat.gcd_rec]',
525527
'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by',
526-
' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
527-
' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
528-
' have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith',
529-
' rw [Nat.mod_eq_of_lt]',
530-
' rw [Nat.mod_eq_of_lt]',
531-
' exact h₂',
532-
' rw [Nat.mod_eq_of_lt]',
533-
' exact h₂',
534-
' exact h₂',
528+
' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
529+
' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
530+
' have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith',
531+
' rw [Nat.mod_eq_of_lt]',
532+
' rw [Nat.mod_eq_of_lt]',
533+
' exact h₂',
534+
' rw [Nat.mod_eq_of_lt]',
535+
' exact h₂',
536+
' exact h₂',
535537
'rw [eq₂]'
536538
]
537539
with env:
540+
env.set_max_proof_step_length(10000)
538541
for proof_step in proof_steps:
539542
state, m_action, next_state, _, done, info = env.step(ProofAction(
540543
ProofAction.ActionType.RUN_TACTIC,
@@ -675,18 +678,17 @@ def test_simple_lean4_multiline_multigoal(self):
675678
assert proof_was_finished, "Proof was not finished"
676679

677680
def main():
678-
SimpleLean4SyncExecutor.max_threshold_for_tactic_length = 10000
679-
unittest.main()
681+
# unittest.main()
680682
# Run only the Lean 4 tests
681-
# t = Lean4Test()
683+
t = Lean4Test()
682684
# t.test_simple_lean4_multiline_multigoal()
683685
# t.test_simple_lean4()
684686
# t.test_lean4_backtracking()
685687
# t.test_simple_lean4_done_test()
686688
# t.test_simple_lean_calc()
687689
# t.test_simple_lean_calc_with_validation()
688690
# t.test_simple_lean4_with_error()
689-
# t.test_simple_lean4_have_test()
691+
t.test_simple_lean4_have_test()
690692
# t.test_simple_lean_enforce_done_test()
691693

692694

0 commit comments

Comments
 (0)