-
Notifications
You must be signed in to change notification settings - Fork 242
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix java export for huge models #306
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
|
||
from m2cgen import ast | ||
from m2cgen.interpreters.interpreter import BaseToCodeInterpreter | ||
from m2cgen.interpreters.utils import chunks | ||
|
||
|
||
class BinExpressionDepthTrackingMixin(BaseToCodeInterpreter): | ||
|
@@ -89,7 +90,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(), **kwargs): | |
*extra_func_args) | ||
|
||
|
||
Subroutine = namedtuple('Subroutine', ['name', 'expr']) | ||
Subroutine = namedtuple('Subroutine', ['name', 'idx', 'expr']) | ||
|
||
|
||
class SubroutinesMixin(BaseToCodeInterpreter): | ||
|
@@ -99,9 +100,11 @@ class SubroutinesMixin(BaseToCodeInterpreter): | |
|
||
Subclasses only need to implement `create_code_generator` method. | ||
|
||
Their code generators should implement 3 methods: | ||
Their code generators should implement 5 methods: | ||
- function_definition; | ||
- function_invocation; | ||
- module_definition; | ||
- module_function_invocation; | ||
- add_return_statement. | ||
|
||
Interpreter should prepare at least one subroutine using method | ||
|
@@ -112,6 +115,7 @@ class SubroutinesMixin(BaseToCodeInterpreter): | |
# disabled by default | ||
ast_size_check_frequency = sys.maxsize | ||
ast_size_per_subroutine_threshold = sys.maxsize | ||
subroutine_per_group_threshold = sys.maxsize | ||
|
||
def __init__(self, *args, **kwargs): | ||
self._subroutine_idx = 0 | ||
|
@@ -124,15 +128,33 @@ def process_subroutine_queue(self, top_code_generator): | |
subroutine queue. | ||
""" | ||
self._subroutine_idx = 0 | ||
subroutines = [] | ||
|
||
while len(self.subroutine_expr_queue): | ||
while self.subroutine_expr_queue: | ||
self._reset_reused_expr_cache() | ||
subroutine = self.subroutine_expr_queue.pop(0) | ||
subroutine_code = self._process_subroutine(subroutine) | ||
subroutines.append((subroutine, subroutine_code)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If my assumption in the previous comment is correct I don't think we actually need |
||
|
||
subroutines.sort(key=lambda subroutine: subroutine[0].idx) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this sort given that subroutines are added in a specific order already (the |
||
|
||
groups = chunks(subroutines, self.subroutine_per_group_threshold) | ||
for _, subroutine_code in next(groups): | ||
top_code_generator.add_code_lines(subroutine_code) | ||
|
||
def enqueue_subroutine(self, name, expr): | ||
self.subroutine_expr_queue.append(Subroutine(name, expr)) | ||
for index, subroutine_group in enumerate(groups): | ||
cg = self.create_code_generator() | ||
|
||
with cg.module_definition( | ||
module_name=self._format_group_name(index + 1)): | ||
for _, subroutine_code in subroutine_group: | ||
cg.add_code_lines(subroutine_code) | ||
|
||
top_code_generator.add_code_lines( | ||
cg.finalize_and_get_generated_code()) | ||
|
||
def enqueue_subroutine(self, name, idx, expr): | ||
self.subroutine_expr_queue.append(Subroutine(name, idx, expr)) | ||
|
||
def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs): | ||
if isinstance(expr, ast.BinExpr) and not expr.to_reuse: | ||
|
@@ -145,7 +167,16 @@ def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs): | |
ast_size = ast.count_exprs(expr) | ||
if ast_size > self.ast_size_per_subroutine_threshold: | ||
function_name = self._get_subroutine_name() | ||
self.enqueue_subroutine(function_name, expr) | ||
|
||
self.enqueue_subroutine(function_name, self._subroutine_idx, expr) | ||
|
||
group_idx = self._subroutine_idx // self.subroutine_per_group_threshold | ||
if group_idx != 0: | ||
return self._cg.module_function_invocation( | ||
self._format_group_name(group_idx), | ||
function_name, | ||
self._feature_array_name), kwargs | ||
|
||
return self._cg.function_invocation(function_name, self._feature_array_name), kwargs | ||
|
||
kwargs['ast_size_check_counter'] = ast_size_check_counter | ||
|
@@ -191,6 +222,10 @@ def _get_subroutine_name(self): | |
self._subroutine_idx += 1 | ||
return subroutine_name | ||
|
||
@staticmethod | ||
def _format_group_name(group_idx): | ||
return f"SubroutineGroup{group_idx}" | ||
|
||
# Methods to be implemented by subclasses. | ||
|
||
def create_code_generator(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,3 +28,9 @@ def array_index_access(self, array_name, index): | |
|
||
def vector_init(self, values): | ||
return f"c({', '.join(values)})" | ||
|
||
def module_definition(self, module_name): | ||
raise NotImplementedError("Modules in r is not supported") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since "Modules" is plural I suggest to you use |
||
|
||
def module_function_invocation(self, module_name, function_name, *args): | ||
raise NotImplementedError("Modules in r is not supported") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like
subroutine_per_module_threshold
seems more accurate