From 79cd50dad7d8140ad6c3de6457f1fa94bd59c015 Mon Sep 17 00:00:00 2001 From: gareth-cross Date: Fri, 1 Jul 2022 14:05:06 -0700 Subject: [PATCH] Add format_matrix_subscript_accessor and update_template_data to CodegenConfig --- symforce/codegen/codegen.py | 3 ++- symforce/codegen/codegen_config.py | 22 ++++++++++++++++++++++ symforce/codegen/codegen_util.py | 3 ++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/symforce/codegen/codegen.py b/symforce/codegen/codegen.py index a7905ff3c..0d6155d2b 100644 --- a/symforce/codegen/codegen.py +++ b/symforce/codegen/codegen.py @@ -398,8 +398,9 @@ def generate_function( self.namespace = namespace template_data = dict(self.common_data(), spec=self) - template_dir = self.config.template_dir() + self.config.update_template_data(data=template_data) + template_dir = self.config.template_dir() backend_name = self.config.backend_name() if skip_directory_nesting: out_function_dir = output_dir diff --git a/symforce/codegen/codegen_config.py b/symforce/codegen/codegen_config.py index dda67284a..f608b9284 100644 --- a/symforce/codegen/codegen_config.py +++ b/symforce/codegen/codegen_config.py @@ -77,3 +77,25 @@ def format_data_accessor(prefix: str, index: int) -> str: Format data for accessing a data array in code. """ return f"{prefix}.data[{index}]" + + # TODO: Move this into code printer. + @staticmethod + def format_matrix_subscript_accessor(key: str, i: int, j: int) -> str: + """ + Format accessing a matrix element (i, j) in code. + Args: + key (str): Name of the variable being accessed. + i (int): Row + j (int): Column + """ + return f"{key}({i}, {j})" + + def update_template_data(self, data: T.Dict[str, T.Any]) -> None: + """ + Derived classes may override this to customize the "template data" dict. This dict + is passed to jinja at code-generation time. + + Args: + data: Dict passed by Codegen. The function should modify this in-place. + """ + pass diff --git a/symforce/codegen/codegen_util.py b/symforce/codegen/codegen_util.py index b8a6fd360..7597fdf78 100644 --- a/symforce/codegen/codegen_util.py +++ b/symforce/codegen/codegen_util.py @@ -351,7 +351,8 @@ def get_formatted_list( formatted_symbols = [] for j in range(value.shape[1]): for i in range(value.shape[0]): - formatted_symbols.append(sf.Symbol(f"{key}({i}, {j})")) + formatted_subscript = config.format_matrix_subscript_accessor(key=key, i=i, j=j) + formatted_symbols.append(sf.Symbol(formatted_subscript)) flattened_value = ops.StorageOps.to_storage(value)