Skip to content

Commit

Permalink
Add test implementation of computation graph cache
Browse files Browse the repository at this point in the history
  • Loading branch information
evhub committed Oct 15, 2024
1 parent 1a8c4f9 commit 74003ce
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ test-pyright: clean
python ./coconut/tests/dest/extras.py

# same as test-univ but includes verbose output for better debugging
# regex for getting non-timing lines: ^(?!'|\s*(Time|Packrat|Loaded|Saving|Adaptive|Errorless|Grammar|Failed|Incremental|Pruned|Compiled)\s)[^\n]*\n*
# regex for getting non-informational lines to delete: ^(?!'|\s*(Time|Packrat|Loaded|Saving|Adaptive|Errorless|Grammar|Failed|Incremental|Pruned|Compiled|Computation)\s)[^\n]*\n*
.PHONY: test-verbose
test-verbose: export COCONUT_USE_COLOR=TRUE
test-verbose: clean
Expand Down
12 changes: 10 additions & 2 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@
match_in,
transform,
parse,
cached_parse,
get_target_info_smart,
split_leading_comments,
compile_regex,
Expand Down Expand Up @@ -405,6 +406,7 @@ class Compiler(Grammar, pickleable_obj):
"""The Coconut compiler."""
lock = Lock()
current_compiler = None
computation_graph_caches = defaultdict(dict)

preprocs = [
lambda self: self.prepare,
Expand Down Expand Up @@ -1372,14 +1374,20 @@ def parse_line_by_line(self, init_parser, line_parser, original):
while cur_loc < len(original):
self.remaining_original = original[cur_loc:]
ComputationNode.add_to_loc = cur_loc
results = parse(init_parser if init else line_parser, self.remaining_original, inner=False)
parser = init_parser if init else line_parser
results = cached_parse(
self.computation_graph_caches[("line_by_line", parser)],
parser,
self.remaining_original,
inner=False,
)
if len(results) == 1:
got_loc, = results
else:
got, got_loc = results
out_parts.append(got)
got_loc = int(got_loc)
internal_assert(got_loc >= cur_loc and (init or got_loc > cur_loc), "invalid line by line parse", (cur_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0]))
internal_assert(got_loc >= cur_loc and (init or got_loc > cur_loc), "invalid line by line parse", (cur_loc, got_loc, results), extra=lambda: "in: " + repr(self.remaining_original.split("\n", 1)[0]))
cur_loc = got_loc
init = False
return "".join(out_parts)
Expand Down
2 changes: 0 additions & 2 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2781,8 +2781,6 @@ class Grammar(object):
rparen,
) + end_marker,
tco_return_handle,
# this is the root in what it's used for, so might as well evaluate greedily
greedy=True,
))

rest_of_lambda = Forward()
Expand Down
67 changes: 63 additions & 4 deletions coconut/compiler/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def evaluate_tokens(tokens, **kwargs):

if not USE_COMPUTATION_GRAPH:
return tokens
final_evaluate_tokens.enabled = True # special variable used by cached_parse

if isinstance(tokens, ParseResults):

Expand Down Expand Up @@ -374,6 +375,8 @@ def evaluate(self):
# note that this should never cache, since if a greedy Wrap that doesn't add to the packrat context
# hits the cache, it'll get the same ComputationNode object, but since it's greedy that object needs
# to actually be reevaluated
if logger.tracing and not final_evaluate_tokens.enabled:
logger.log_tag("cached_parse invalidated by", self)
evaluated_toks = evaluate_tokens(self.tokens)
if logger.tracing: # avoid the overhead of the call if not tracing
logger.log_trace(self.name, self.original, self.loc, evaluated_toks, self.tokens)
Expand Down Expand Up @@ -523,12 +526,17 @@ def attach(item, action, ignore_no_tokens=None, ignore_one_token=None, ignore_ar

def final_evaluate_tokens(tokens):
    """Same as evaluate_tokens but should only be used once a parse is assured."""
    if final_evaluate_tokens.enabled:
        evaluated = evaluate_tokens(tokens, is_final=True)
        # clear the packrat cache only after evaluation so that error creation
        # still gets to inspect the cache contents
        clear_packrat_cache()
        return evaluated
    # evaluation disabled: cached_parse wants the raw, unevaluated computation graph
    return tokens


final_evaluate_tokens.enabled = True


def final(item):
"""Collapse the computation graph upon parsing the given item."""
# evaluate_tokens expects a computation graph, so we just call add_action directly
Expand Down Expand Up @@ -674,17 +682,66 @@ def parse(grammar, text, inner=None, eval_parse_tree=True):
return result


def cached_parse(computation_graph_cache, grammar, text, inner=None, eval_parse_tree=True):
    """Version of parse that caches the result when it's a pure ComputationNode.

    computation_graph_cache maps (prefix, is_at_end) -> the unevaluated tokens
    produced by a previous parse of a text beginning with prefix. Only usable
    on cPyparsing, since it relies on parseString(returnLoc=True).
    """
    if not CPYPARSING:  # caching is only supported on cPyparsing
        return parse(grammar, text, inner)

    for (prefix, is_at_end), tokens in computation_graph_cache.items():
        # the assumption here is that if the prior parse didn't make it to the end,
        # then we can freely change the text after the end of where it made it,
        # but if it did make it to the end, then we can't add more text after that
        if (
            is_at_end and text == prefix
            or not is_at_end and text.startswith(prefix)
        ):
            if DEVELOP:
                logger.record_stat("cached_parse", True)
                logger.log_tag("cached_parse hit", (prefix, text[len(prefix):], tokens))
            break
    else:  # no break
        # disable token evaluation by final() to allow us to get a ComputationNode;
        # this makes long parses very slow, however, so once a greedy parse action
        # is hit such that evaluate_tokens gets called, evaluate_tokens will set
        # final_evaluate_tokens.enabled back to True, which speeds up the rest of the
        # parse and tells us that something greedy happened so we can't cache
        final_evaluate_tokens.enabled = False
        try:
            with parsing_context(inner):
                loc, tokens = prep_grammar(grammar, for_scan=False).parseString(text, returnLoc=True)
            # only cache if nothing greedy re-enabled evaluation mid-parse,
            # i.e. the result is still a pure, unevaluated ComputationNode
            if not final_evaluate_tokens.enabled:
                prefix = text[:loc + 1]
                is_at_end = loc >= len(text)
                computation_graph_cache[(prefix, is_at_end)] = tokens
        finally:
            if DEVELOP:
                logger.record_stat("cached_parse", False)
                logger.log_tag(
                    "cached_parse miss " + ("-> stored" if not final_evaluate_tokens.enabled else "(not stored)"),
                    text,
                    multiline=True,
                )
            # always restore normal token evaluation, even if the parse raised
            final_evaluate_tokens.enabled = True

    if eval_parse_tree:
        tokens = unpack(tokens)
    return tokens


def try_parse(grammar, text, inner=None, eval_parse_tree=True, computation_graph_cache=None):
    """Attempt to parse text using grammar else None.

    If computation_graph_cache is given, route through cached_parse so pure
    ComputationNode results can be reused across calls; otherwise use plain
    parse. Returns None on any ParseBaseException.
    """
    try:
        if computation_graph_cache is None:
            return parse(grammar, text, inner, eval_parse_tree)
        else:
            return cached_parse(computation_graph_cache, grammar, text, inner, eval_parse_tree)
    except ParseBaseException:
        return None


def does_parse(grammar, text, inner=None, **kwargs):
    """Determine if text can be parsed using grammar.

    Extra keyword arguments (e.g. computation_graph_cache) are forwarded to
    try_parse; eval_parse_tree=False since only success/failure matters here.
    """
    return try_parse(grammar, text, inner, eval_parse_tree=False, **kwargs)


def all_matches(grammar, text, inner=None, eval_parse_tree=True):
Expand Down Expand Up @@ -1370,6 +1427,8 @@ def parseImpl(self, original, loc, *args, **kwargs):
with self.wrapped_context():
parse_loc, tokens = super(Wrap, self).parseImpl(original, loc, *args, **kwargs)
if self.greedy:
if logger.tracing and not final_evaluate_tokens.enabled:
logger.log_tag("cached_parse invalidated by", self)
tokens = evaluate_tokens(tokens)
if reparse and parse_loc is None:
raise CoconutInternalException("illegal double reparse in", self)
Expand Down
2 changes: 1 addition & 1 deletion coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ def get_path_env_var(env_var, default):

# min versions are inclusive
unpinned_min_versions = {
"cPyparsing": (2, 4, 7, 2, 4, 0),
"cPyparsing": (2, 4, 7, 2, 4, 1),
("pre-commit", "py3"): (3,),
("psutil", "py>=27"): (6,),
"jupyter": (1, 1),
Expand Down
16 changes: 11 additions & 5 deletions coconut/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,14 +561,18 @@ def trace(self, item):
return item

def record_stat(self, stat_name, stat_bool):
    """Record the given boolean statistic for the given stat_name.

    All stats recorded here must have some printing logic added to
    gather_parsing_stats or log_compiler_stats. Printed stats should also be
    added to the regex in the Makefile for getting non-informational lines.
    """
    # recorded_stats[stat_name] maps False/True -> occurrence count
    self.recorded_stats[stat_name][stat_bool] += 1

@contextmanager
def gather_parsing_stats(self):
"""Times parsing if --verbose."""
if self.verbose:
self.recorded_stats.pop("adaptive", None)
self.recorded_stats.pop("cached_parse", None)
start_time = get_clock_time()
try:
yield
Expand All @@ -584,6 +588,9 @@ def gather_parsing_stats(self):
if "adaptive" in self.recorded_stats:
failures, successes = self.recorded_stats["adaptive"]
self.printlog("\tAdaptive parsing stats:", successes, "successes;", failures, "failures")
if "cached_parse" in self.recorded_stats:
misses, hits = self.recorded_stats["cached_parse"]
self.printlog("\tComputation graph cache stats:", hits, "hits;", misses, "misses")
if maybe_make_safe is not None:
hits, misses = maybe_make_safe.stats
self.printlog("\tErrorless parsing stats:", hits, "errorless;", misses, "with errors")
Expand All @@ -595,10 +602,9 @@ def log_compiler_stats(self, comp):
if self.verbose:
self.log("Grammar init time: " + str(comp.grammar_init_time) + " secs / Total init time: " + str(get_clock_time() - first_import_time) + " secs")
for stat_name, (no_copy, yes_copy) in self.recorded_stats.items():
if not stat_name.startswith("maybe_copy_"):
continue
name = assert_remove_prefix(stat_name, "maybe_copy_")
self.printlog("\tGrammar copying stats (" + name + "):", no_copy, "not copied;", yes_copy, "copied")
if stat_name.startswith("maybe_copy_"):
name = assert_remove_prefix(stat_name, "maybe_copy_")
self.printlog("\tGrammar copying stats (" + name + "):", no_copy, "not copied;", yes_copy, "copied")

total_block_time = defaultdict(int)

Expand Down

0 comments on commit 74003ce

Please sign in to comment.