From 391786ef38ed76c287ee33dd5044f72b197f32c6 Mon Sep 17 00:00:00 2001 From: Joseph Astier Date: Mon, 13 Nov 2023 15:21:58 -0700 Subject: [PATCH] MATLAB switch statement support (#634) ## MATLAB switch statement support This PR adds Tree-sitter coverage of MATLAB switch statements. CI tests are also included. ``` matlab switch x case 1 n = 1; j = 1; case {3, three, 'three'} n = 3; case {{1, 2, 3,}, {7, 8, 9}} n = 6; otherwise n = 0; end ``` ## Relevant features - Switch statements are translated into CAST conditional logic - Multiple-argument cases are reduced to a single list-inclusion test - Case arguments may be of differing datatypes (string, number, identifier) ## Related issues - Resolves issue #561 --------- Co-authored-by: Joseph Astier --- skema/program_analysis/CAST/matlab/cast_out | 60 +++-- .../CAST/matlab/matlab_to_cast.py | 227 ++++++++++-------- .../CAST/matlab/node_helper.py | 16 +- .../CAST/matlab/tests/test_assignment.py | 46 +++- .../matlab/tests/test_binary_operation.py | 23 -- ...nditional_logic.py => test_conditional.py} | 41 +++- .../CAST/matlab/tests/test_operators.py | 38 +++ .../CAST/matlab/tests/test_switch.py | 70 ++++++ .../CAST/matlab/tests/utils.py | 21 +- skema/program_analysis/CAST/matlab/tree_out | 5 +- .../CAST/matlab/variable_context.py | 8 - 11 files changed, 375 insertions(+), 180 deletions(-) delete mode 100644 skema/program_analysis/CAST/matlab/tests/test_binary_operation.py rename skema/program_analysis/CAST/matlab/tests/{test_conditional_logic.py => test_conditional.py} (57%) create mode 100644 skema/program_analysis/CAST/matlab/tests/test_operators.py create mode 100644 skema/program_analysis/CAST/matlab/tests/test_switch.py diff --git a/skema/program_analysis/CAST/matlab/cast_out b/skema/program_analysis/CAST/matlab/cast_out index 91cb0adcc04..2d332dac26b 100755 --- a/skema/program_analysis/CAST/matlab/cast_out +++ b/skema/program_analysis/CAST/matlab/cast_out @@ -2,25 +2,55 @@ import json import sys from
skema.program_analysis.CAST.matlab.matlab_to_cast import MatlabToCast +from typing import List + +# Show a CAST object as pretty printed JSON with keys filtered for clarity. + +# Keys to remove from the output +#KEY_FILTER = ["source_refs", "default_value", "interpreter"] +KEY_FILTER = ["source_refs"] + +def remove_keys(json_obj, target_keys: List): + """ remove all instances of the target keys from the json object""" + + def remove_key(json_obj, target_key): + """ remove all instances of the target key from the json object""" + if isinstance(json_obj, dict): + for target_key in target_keys: + if target_key in json_obj.keys(): + json_obj.pop(target_key) + for key in json_obj.keys(): + remove_key(json_obj[key], target_keys) + elif isinstance(json_obj, list): + for item in json_obj: + remove_key(item, target_keys) + return(json_obj) + + print(f"Removed keys: {target_keys}") + return remove_key(json_obj, target_keys) + + +def show_cast(filename): + """ Run a file of any type through the Tree-sitter MATLAB parser""" + parser = MatlabToCast(filename) + print("\nINPUT FILE:") + print(parser.filename) + print("\nSOURCE:") + print(parser.source) + print('\nCAST:') + cast_list = parser.out_cast + for cast in cast_list: + json_obj = cast.to_json_object() + # declutter JSON by filtering keys + json_obj = remove_keys(json_obj, KEY_FILTER) + # pretty print JSON to string + output = json.dumps(json_obj, sort_keys=True, indent=2) + print(output) -""" Run a file of any type through the Tree-sitter MATLAB parser""" if __name__ == "__main__": if len(sys.argv) > 1: for i in range(1, len(sys.argv)): - parser = MatlabToCast(sys.argv[i]) - print("\n\nINPUT:") - print(parser.filename) - print("\nSOURCE:") - print(parser.source) - print('\nCAST:') - cast_list = parser.out_cast - for cast_index in range(0, len(cast_list)): - jd = json.dumps( - cast_list[cast_index].to_json_object(), - sort_keys=True, - indent=2, - ) - print(jd) + show_cast(sys.argv[i]) else: print("Please enter one 
filename to parse") diff --git a/skema/program_analysis/CAST/matlab/matlab_to_cast.py b/skema/program_analysis/CAST/matlab/matlab_to_cast.py index d29bc8bfeb0..01305fff5f5 100644 --- a/skema/program_analysis/CAST/matlab/matlab_to_cast.py +++ b/skema/program_analysis/CAST/matlab/matlab_to_cast.py @@ -30,6 +30,7 @@ from skema.program_analysis.CAST.matlab.variable_context import VariableContext from skema.program_analysis.CAST.matlab.node_helper import ( + get_all, get_children_by_types, get_control_children, get_first_child_by_type, @@ -38,6 +39,7 @@ get_non_control_children, remove_comments, NodeHelper, + valid ) from skema.program_analysis.tree_sitter_parsers.build_parsers import INSTALLED_LANGUAGES_FILEPATH @@ -45,8 +47,10 @@ MATLAB_VERSION='matlab_version_here' class MatlabToCast(object): + + literal_types = ["number","string", "boolean", "array_literal"] + def __init__(self, source_path = "", source = ""): - """docstring""" # if a source file path is provided, read source from file if not source_path == "": @@ -78,6 +82,7 @@ def generate_cast(self) -> List[CAST]: """Interface for generating CAST.""" # remove comments from tree before processing + modules = self.run(remove_comments(self.tree.root_node)) return [CAST([module], "matlab") for module in modules] @@ -96,7 +101,6 @@ def run_old(self, root) -> List[Module]: "function_definition", "subroutine", "assignment", - "switch_statement" ] outer_body_nodes = get_children_by_types(root, body_node_names) @@ -118,7 +122,9 @@ def run_old(self, root) -> List[Module]: def visit(self, node): """Switch execution based on node type""" - # print(f"\nvisit node type = {node.type}") + # print(f"\nvisit {node.type}") + if node == None: + return None if node.type in ["program", "module", "source_file"] : return self.visit_module(node) @@ -144,7 +150,7 @@ def visit(self, node): "math_expression", "relational_expression" ]: return self.visit_math_expression(node) - elif node.type in ["number", "array", "string", "boolean"]: + 
elif node.type in self.literal_types: return self.visit_literal(node) elif node.type == "keyword_statement": return self.visit_keyword_statement(node) @@ -152,12 +158,10 @@ def visit(self, node): return self.visit_extent_specifier(node) elif node.type == "do_loop_statement": return self.visit_do_loop_statement(node) + elif node.type == "switch_statement": + return self.visit_switch_statement(node) elif node.type == "if_statement": return self.visit_if_statement(node) - elif node.type == "elseif_clause": - return self.visit_elseif_clause(node) - elif node.type == "else_clause": - return self.visit_else_clause(node) elif node.type == "derived_type_definition": return self.visit_derived_type(node) elif node.type == "derived_type_member_expression": @@ -167,7 +171,6 @@ def visit(self, node): def visit_module(self, node: Node) -> Module: """Visitor for program and module statement. Returns a Module object""" - # print('visit_module') self.variable_context.push_context() program_body = [] @@ -188,13 +191,10 @@ def visit_module(self, node: Node) -> Module: def visit_internal_procedures(self, node: Node) -> List[FunctionDef]: """Visitor for internal procedures. 
Returns list of FunctionDef""" - # print('visit_internal_procedures') internal_procedures = get_children_by_types(node, ["function_definition", "subroutine"]) return [self.visit(procedure) for procedure in internal_procedures] def visit_name(self, node): - """Docstring""" - # print('visit_name') # Node structure # (name) @@ -208,8 +208,6 @@ def visit_name(self, node): ) def visit_function_def(self, node): - """Docstring""" - # print('visit_function_def') # TODO: Refactor function def code to use new helper functions # Node structure # (subroutine) @@ -316,8 +314,6 @@ def visit_function_def(self, node): ) def visit_function_call(self, node): - """Docstring""" - # print('visit_function_call') # Pull relevent nodes if node.type == "subroutine_call": function_node = node.children[1] @@ -359,8 +355,6 @@ def visit_function_call(self, node): ) def visit_keyword_statement(self, node): - """Docstring""" - # print('visit_keyword_statement') # Currently, the only keyword_identifier produced by tree-sitter is Return # However, there may be other instances @@ -396,8 +390,6 @@ def visit_keyword_statement(self, node): ) def visit_use_statement(self, node): - """Docstring""" - # print('visit_use_statement') # (use) # (use) # (module_name) @@ -457,7 +449,6 @@ def visit_do_loop_statement(self, node) -> Loop: (body) ... 
""" - # print('visit_do_loop_statement') # First check for # TODO: Add do until Loop support while_statement_node = get_first_child_by_type(node, "while_statement") @@ -578,73 +569,128 @@ def visit_do_loop_statement(self, node) -> Loop: source_refs=[self.node_helper.get_source_ref(node)], ) - def visit_if_statement(self, node): - """ return a ModelIf if, elseif, and else clauses""" - - # print('visit_if_statement') - - # if_statement Tree-sitter syntax tree: - # if - # comparison_operator - # body block with 1-n elements - # elseif_clause (0-n of these) - # else_clause (0-1 of these) - # end - - # the initial ModelIf node is built just like the else-if clause - mi = self.visit_elseif_clause(node) - - # get 0-n elseif_clauses - elseif_clauses = get_children_by_types(node, ("elseif_clause")) - for child in elseif_clauses: - elseif_node = self.visit(child) - if elseif_node: - if not mi.orelse: - mi.orelse = list() - mi.orelse.append(elseif_node) - - # get 0-1 else_clauses - else_clauses = get_children_by_types(node, ("else_clause")) - for child in else_clauses: - else_node = self.visit(child) - if else_node.body: - for body_node in else_node.body: - if not mi.orelse: - mi.orelse = list() - mi.orelse.append(body_node) - - return mi + + def visit_switch_statement(self, node): + """ return a conditional statement based on the switch statement """ + # node types used for case comparison + case_node_types = self.literal_types + ["identifier"] + + def get_node_value(ast_node): + """ return the CAST node value or var name """ + if isinstance(ast_node, Var): + return ast_node.val.name + return ast_node.value + + def get_operator(op, operands, source_refs): + """ return an Operator representing the case test """ + return Operator( + source_language = "matlab", + interpreter = None, + version = MATLAB_VERSION, + op = op, + operands = operands, + source_refs = source_refs + ) + + def get_case_expression(case_node, identifier): + """ return an Operator representing the case test 
""" + cell_node = get_first_child_by_type(case_node, "cell") + source_refs = self.node_helper.get_source_ref(case_node) + # multiple case arguments + if (cell_node): + nodes = get_all(cell_node, case_node_types) + ast_nodes = valid([self.visit(node) for node in nodes]) + operand = LiteralValue( + value_type="List", + value = [get_node_value(node) for node in ast_nodes], + source_code_data_type=["matlab", MATLAB_VERSION, "unknown"], + source_refs=[self.node_helper.get_source_ref(cell_node)] + ) + return get_operator("in", [identifier, operand], source_refs) + # single case argument + nodes = get_children_by_types(case_node, case_node_types) + operand = valid([self.visit(node) for node in nodes])[0] + return get_operator("==", [identifier, operand], source_refs) + + def get_case_body(case_node): + """ return the instruction block for the case """ + block = get_first_child_by_type(case_node, "block") + if block: + return valid([self.visit(child) for child in block.children]) + return None + + def get_model_if(case_node, identifier): + """ return conditional logic representing the case """ + return ModelIf( + expr = get_case_expression(case_node, identifier), + body = get_case_body(case_node), + source_refs=[self.node_helper.get_source_ref(case_node)] + ) + + # switch statement identifier + identifier = self.visit(get_first_child_by_type(node, "identifier")) + + # n case clauses as 'if then' nodes + case_nodes = get_children_by_types(node, ["case_clause"]) + model_ifs = [get_model_if(node, identifier) for node in case_nodes] + for i, model_if in enumerate(model_ifs[1:]): + model_ifs[i].orelse = [model_if] + + # otherwise clause as 'else' node after last 'if then' node + otherwise_clause = get_first_child_by_type(node, "otherwise_clause") + if otherwise_clause: + block = get_first_child_by_type(otherwise_clause, "block") + if block: + last = model_ifs[len(model_ifs)-1] + last.orelse = valid([self.visit(child) for child in block.children]) + + return model_ifs[0] - def 
visit_elseif_clause(self, node): - """ return a ModelIf with comparison and body nodes. """ - # get ModelIf with body nodes - mi = self.visit_else_clause(node) - # addd comparison operator - comp: Operator = get_first_child_by_type(node, "comparison_operator") - mi.expr = self.visit(comp) + def visit_if_statement(self, node): + """ return a node describing if, elseif, else conditional logic""" + + def conditional(conditional_node): + """ return a ModelIf struct for the conditional logic node. """ + + # comparison_operator + expr = self.visit(get_first_child_by_type( + conditional_node, + "comparison_operator" + )) + + ret = ModelIf( + expr = expr, + source_refs=[self.node_helper.get_source_ref(conditional_node)] + ) - return mi - + # instruction_block + block = get_first_child_by_type(conditional_node, "block") + if block: + ret.body = valid([self.visit(child) for child in block.children]) + + return ret - def visit_else_clause(self, node): - """ Return a ModelIf with body nodes only. """ - # get the top level body nodes - mi = ModelIf() - block = get_first_child_by_type(node, "block") - for child in block.children: - body_node = self.visit(child) - if body_node: - if not mi.body: - mi.body = list() - mi.body.append(body_node) + # the if statement is returned as a ModelIf AstNode + model_ifs = [conditional(node)] - return mi + # add 0-n elseif clauses + elseif_clauses = get_children_by_types(node, ["elseif_clause"]) + model_ifs += [conditional(child) for child in elseif_clauses] + for i, model_if in enumerate(model_ifs[1:]): + model_ifs[i].orelse = [model_if] + # add 0-1 else clause + else_clause = get_first_child_by_type(node, "else_clause") + if else_clause: + block = get_first_child_by_type(else_clause, "block") + if block: + last = model_ifs[len(model_ifs)-1] + last.orelse = valid([self.visit(child) for child in block.children]) + return model_ifs[0] + def visit_assignment(self, node): - """Docstring""" - # print('visit_assignment') left, _, right = 
node.children return Assignment( @@ -655,7 +701,6 @@ def visit_assignment(self, node): def visit_literal(self, node) -> LiteralValue: """Visitor for literals. Returns a LiteralValue""" - # print('visit_literal') literal_type = node.type literal_value = self.node_helper.get_identifier(node) literal_source_ref = self.node_helper.get_source_ref(node) @@ -711,10 +756,7 @@ def visit_literal(self, node) -> LiteralValue: source_refs=[literal_source_ref], ) - def visit_identifier(self, node): - """Docstring""" - # print('visit_identifier') # By default, this is unknown, but can be updated by other visitors identifier = self.node_helper.get_identifier(node) if self.variable_context.is_variable(identifier): @@ -737,8 +779,6 @@ def visit_identifier(self, node): ) def visit_math_expression(self, node): - """Docstring""" - # print('visit_math_expression') op = self.node_helper.get_identifier( get_control_children(node)[0] ) # The operator will be the first control character @@ -864,8 +904,6 @@ def visit_variable_declaration(self, node) -> List: return vars def visit_extent_specifier(self, node): - """Docstring""" - # print('visit_extent_specifier') # Node structure # (extent_specifier) # (identifier) @@ -906,8 +944,6 @@ def visit_derived_type(self, node: Node) -> RecordDef: (BODY_NODES) ... """ - # print('visit_derived_type') - record_name = self.node_helper.get_identifier( get_first_child_by_type(node, "type_name", recurse=True) @@ -989,7 +1025,6 @@ def visit_derived_type_member_expression(self, node) -> Attribute: (argument_list) (type_member) """ - # print('visit_derived_type_member_expression') # If we are accessing an attribute of a scalar type, we can simply pull the name node from the variable context. # However, if this is a dimensional type, we must convert it to a call to _get. @@ -1016,8 +1051,6 @@ def visit_derived_type_member_expression(self, node) -> Attribute: # NOTE: This function starts with _ because it will never be dispatched to directly. 
There is not a get node in the tree-sitter parse tree. # From context, we will determine when we are accessing an element of a List, and call this function, def _visit_get(self, node): - """Docstring""" - # print('_visit_get') # Node structure # (call_expression) # (identifier) @@ -1056,8 +1089,6 @@ def _visit_get(self, node): ) def _visit_set(self, node): - """Docstring""" - # print('_visit_set') # Node structure # (assignment) # (call_expression) @@ -1084,7 +1115,6 @@ def _visit_while(self, node) -> Loop: (...) ... (body) ... """ - # print('_visit_while') while_statement_node = get_first_child_by_type(node, "while_statement") # The first body node will be the node after the while_statement @@ -1113,7 +1143,6 @@ def _visit_while(self, node) -> Loop: def _visit_implied_do_loop(self, node) -> Call: """Custom visitor for implied_do_loop array literal. This form gets converted to a call to range""" # TODO: This loop_control is the same as the do loop. Can we turn this into one visitor? - # print('_visit_implied_do_loop') loop_control_node = get_first_child_by_type( node, "loop_control_expression", recurse=True ) @@ -1142,8 +1171,6 @@ def _visit_implied_do_loop(self, node) -> Call: ) def _visit_passthrough(self, node): - """Docstring""" - # print('_visit_passthrough') if len(node.children) == 0: return None @@ -1153,8 +1180,6 @@ def _visit_passthrough(self, node): return child_cast def get_gromet_function_node(self, func_name: str) -> Name: - """Docstring""" - # print('get_gromet_function_node') # Idealy, we would be able to create a dummy node and just call the name visitor. # However, tree-sitter does not allow you to create or modify nodes, so we have to recreate the logic here. 
if self.variable_context.is_variable(func_name): diff --git a/skema/program_analysis/CAST/matlab/node_helper.py b/skema/program_analysis/CAST/matlab/node_helper.py index a64af619599..bda494b5f2b 100644 --- a/skema/program_analysis/CAST/matlab/node_helper.py +++ b/skema/program_analysis/CAST/matlab/node_helper.py @@ -27,7 +27,6 @@ class NodeHelper(): def __init__(self, source: str, source_file_name: str): - """Docstring""" self.source = source self.source_file_name = source_file_name @@ -89,6 +88,19 @@ def get_first_child_index(node, type: str): if child.type == type: return i +def get_all(node, types): + """ return all nodes with type in types from the entire node tree """ + def search(node, types, ret): + if node.type in types: + ret += [node] + for child in node.children: + search(child, types, ret) + return ret + return search(node, types, []) + +def valid(nodes): + """ return the node list without any None elements """ + return [node for node in nodes if node] def remove_comments(node: Node): """Remove comment nodes from tree-sitter parse tree""" @@ -114,12 +126,10 @@ def get_last_child_index(node, type: str): def get_control_children(node: Node): - """Docstring""" return get_children_by_types(node, CONTROL_CHARACTERS) def get_non_control_children(node: Node): - """Docstring""" children = [] for child in node.children: if child.type not in CONTROL_CHARACTERS: diff --git a/skema/program_analysis/CAST/matlab/tests/test_assignment.py b/skema/program_analysis/CAST/matlab/tests/test_assignment.py index 898432720c4..2b38dd5702e 100644 --- a/skema/program_analysis/CAST/matlab/tests/test_assignment.py +++ b/skema/program_analysis/CAST/matlab/tests/test_assignment.py @@ -1,17 +1,45 @@ from skema.program_analysis.CAST.matlab.tests.utils import ( assert_assignment, - first_cast_node + cast_nodes ) -# Test the CAST returned by processing the simplest MATLAB assignment +# Test CAST from assignment -def test_assignment(): - """ Test CAST from MATLAB 'assignment' 
statement.""" +def test_literal(): + """ Test assignment of literal types (number, string, boolean).""" - source = 'x = 5' + source = """ + x = 5 + y = 1.8 + x = 'single' + y = "double" + yes = true + no = false + """ - # The root of the CAST should be Assignment - assignment = first_cast_node(source) + nodes = cast_nodes(source) + assert len(nodes) == 6 - # The module body should contain a single assignment node - assert_assignment(assignment, left = "x", right = "5") + # number + assert_assignment(nodes[0], left = "x", right = "5") + assert_assignment(nodes[1], left = "y", right = "1.8") + # string + assert_assignment(nodes[2], left = "x", right = "'single'") + assert_assignment(nodes[3], left = "y", right = "\"double\"") + # boolean + assert_assignment(nodes[4], left = 'yes', right = 'true') + assert_assignment(nodes[5], left = 'no', right = 'false') + + +def test_identifier(): + """ Test assignment of identifiers.""" + + source = """ + x = y + """ + + nodes = cast_nodes(source) + assert len(nodes) == 1 + + # identifier + assert_assignment(nodes[0], left = 'x', right = 'y') diff --git a/skema/program_analysis/CAST/matlab/tests/test_binary_operation.py b/skema/program_analysis/CAST/matlab/tests/test_binary_operation.py deleted file mode 100644 index b0452445992..00000000000 --- a/skema/program_analysis/CAST/matlab/tests/test_binary_operation.py +++ /dev/null @@ -1,23 +0,0 @@ -from skema.program_analysis.CAST.matlab.tests.utils import ( - assert_var, - assert_expression, - first_cast_node -) -from skema.program_analysis.CAST2FN.model.cast import Assignment - -# Test the CAST returned by processing the simplest MATLAB binary operation - -def test_binary_operation(): - """ Test CAST from MATLAB binary operation statement.""" - - source = 'z = x + y' - - # The root of the CAST should be Assignment - assignment = first_cast_node(source) - assert isinstance(assignment, Assignment) - - # Left operand of this assignment node is the variable - assert_var(assignment.left, 
name = "z") - - # right operand of this assignment node is a binary expression - assert_expression(assignment.right, op = "+", left = "x", right = "y") diff --git a/skema/program_analysis/CAST/matlab/tests/test_conditional_logic.py b/skema/program_analysis/CAST/matlab/tests/test_conditional.py similarity index 57% rename from skema/program_analysis/CAST/matlab/tests/test_conditional_logic.py rename to skema/program_analysis/CAST/matlab/tests/test_conditional.py index 59f4d18f815..08a61f7265b 100644 --- a/skema/program_analysis/CAST/matlab/tests/test_conditional_logic.py +++ b/skema/program_analysis/CAST/matlab/tests/test_conditional.py @@ -1,7 +1,7 @@ from skema.program_analysis.CAST.matlab.tests.utils import ( assert_assignment, assert_expression, - first_cast_node + cast_nodes ) from skema.program_analysis.CAST2FN.model.cast import ModelIf @@ -9,15 +9,16 @@ def test_if(): """ Test CAST from MATLAB 'if' conditional logic.""" source = """ - if x > 5 + if x == 5 y = 6 end """ - mi = first_cast_node(source) + mi = cast_nodes(source)[0] + # if assert isinstance(mi, ModelIf) - assert_expression(mi.expr, op = ">", left = "x", right = "5") + assert_expression(mi.expr, op = "==", left = "x", right = "5") assert_assignment(mi.body[0], left="y", right = "6") def test_if_else(): @@ -26,18 +27,43 @@ def test_if_else(): source = """ if x > 5 y = 6 + three = 3 else y = x + foo = 'bar' end """ - mi = first_cast_node(source) + mi = cast_nodes(source)[0] # if assert isinstance(mi, ModelIf) assert_expression(mi.expr, op = ">", left = "x", right = "5") assert_assignment(mi.body[0], left="y", right = "6") + assert_assignment(mi.body[1], left="three", right = "3") # else assert_assignment(mi.orelse[0], left="y", right = "x") + assert_assignment(mi.orelse[1], left="foo", right = "'bar'") + +def test_if_elseif(): + """ Test CAST from MATLAB 'if elseif else' conditional logic.""" + + source = """ + if x >= 5 + y = 6 + elseif x <= 0 + y = x + end + """ + + mi = cast_nodes(source)[0] + # 
if + assert isinstance(mi, ModelIf) + assert_expression(mi.expr, op = ">=", left = "x", right = "5") + assert_assignment(mi.body[0], left="y", right = "6") + # elseif + assert isinstance(mi.orelse[0], ModelIf) + assert_expression(mi.orelse[0].expr, op = "<=", left = "x", right = "0") + assert_assignment(mi.orelse[0].body[0], left="y", right = "x") def test_if_elseif_else(): """ Test CAST from MATLAB 'if elseif else' conditional logic.""" @@ -52,7 +78,7 @@ def test_if_elseif_else(): end """ - mi = first_cast_node(source) + mi = cast_nodes(source)[0] # if assert isinstance(mi, ModelIf) assert_expression(mi.expr, op = ">", left = "x", right = "5") @@ -62,5 +88,4 @@ def test_if_elseif_else(): assert_expression(mi.orelse[0].expr, op = ">", left = "x", right = "0") assert_assignment(mi.orelse[0].body[0], left="y", right = "x") # else - assert_assignment(mi.orelse[1], left="y", right = "0") - + assert_assignment(mi.orelse[0].orelse[0], left="y", right = "0") diff --git a/skema/program_analysis/CAST/matlab/tests/test_operators.py b/skema/program_analysis/CAST/matlab/tests/test_operators.py new file mode 100644 index 00000000000..bd74e82d63b --- /dev/null +++ b/skema/program_analysis/CAST/matlab/tests/test_operators.py @@ -0,0 +1,38 @@ +from skema.program_analysis.CAST.matlab.tests.utils import ( + assert_var, + assert_expression, + cast_nodes +) +from skema.program_analysis.CAST2FN.model.cast import Assignment + +def test_binary_operator(): + """ Test CAST from MATLAB binary operation statement.""" + + source = 'z = x + y' + + # cast nodes should be one assignment + nodes = cast_nodes(source) + assert len(nodes) == 1 + assert isinstance(nodes[0], Assignment) + + # Left assignment operand is the variable + assert_var(nodes[0].left, name = "z") + + # right assignment operand is a binary expression + assert_expression(nodes[0].right, op = "+", left = "x", right = "y") + +def do_not_test_unary_operator(): + """ Test CAST from MATLAB binary operation statement.""" + + source = 
'z = -6' + + # cast nodes should be one assignment + nodes = cast_nodes(source) + assert len(nodes) == 1 + assert isinstance(nodes[0], Assignment) + + # Left assignment operand is the variable + assert_var(nodes[0].left, name = "z") + + # right assignment operand is a binary expression + assert_expression(nodes[0].right, op = "+", left = "x", right = "y") diff --git a/skema/program_analysis/CAST/matlab/tests/test_switch.py b/skema/program_analysis/CAST/matlab/tests/test_switch.py new file mode 100644 index 00000000000..bc4af8ea109 --- /dev/null +++ b/skema/program_analysis/CAST/matlab/tests/test_switch.py @@ -0,0 +1,70 @@ +from skema.program_analysis.CAST.matlab.tests.utils import ( + assert_assignment, + assert_expression, + cast_nodes +) +from skema.program_analysis.CAST2FN.model.cast import ModelIf + +def test_switch_single_values(): + """ Test CAST from MATLAB switch statement.""" + + source = """ + switch s + case 'one' + n = 1; + case 'two' + n = 2; + x = y; + otherwise + n = 0; + end + """ + + # case clause 'one' + mi0 = cast_nodes(source)[0] + assert isinstance(mi0, ModelIf) + assert_assignment(mi0.body[0], left="n", right = "1") + assert_expression(mi0.expr, op="==", left = "s", right = "'one'") + + # case clause 'two' + mi1 = mi0.orelse[0] + assert isinstance(mi1, ModelIf) + assert_assignment(mi1.body[0], left="n", right = "2") + assert_assignment(mi1.body[1], left="x", right = "y") + + # otherwise clause + assert_assignment(mi1.orelse[0], left="n", right = "0") + +def test_switch_multiple_values(): + """ Test CAST from MATLAB switch statement.""" + + source = """ + switch s + case {'one', 'two', 'three'} + n = 1; + case 2 + n = 2; + otherwise + n = 0; + end + """ + + # case clause {'one', 'two', 'three'} + mi0 = cast_nodes(source)[0] + assert isinstance(mi0, ModelIf) + assert_assignment(mi0.body[0], left="n", right = "1") + assert_expression( + mi0.expr, + op="in", + left = 's', + right = ["'one'", "'two'", "'three'"] + ) + + # case clause 2 + mi1 = 
mi0.orelse[0] + assert isinstance(mi1, ModelIf) + assert_assignment(mi1.body[0], left="n", right = "2") + assert_expression(mi1.expr, op="==", left = 's', right = "2") + + # otherwise clause + assert_assignment(mi1.orelse[0], left="n", right = "0") diff --git a/skema/program_analysis/CAST/matlab/tests/utils.py b/skema/program_analysis/CAST/matlab/tests/utils.py index 73a7cf49300..b145497b19c 100644 --- a/skema/program_analysis/CAST/matlab/tests/utils.py +++ b/skema/program_analysis/CAST/matlab/tests/utils.py @@ -8,15 +8,20 @@ Var ) +# Tests now check for null source_refs nodes + def assert_var(var, name = ""): """ Test the Var for correct type and name. """ assert isinstance(var, Var) + assert not var.source_refs == None assert isinstance(var.val, Name) + assert not var.val.source_refs == None assert var.val.name == name def assert_literal_value(literal_value, value = ""): """ Test the LiteralValue for correct type and value. """ assert isinstance(literal_value, LiteralValue) + assert not literal_value.source_refs == None assert literal_value.value == value def assert_operand(operand, value = ""): @@ -29,30 +34,28 @@ def assert_operand(operand, value = ""): assert(False) def assert_assignment(assignment, left = "", right = ""): - """ Test an Assignment correct type and operands. """ + """ Test an Assignment for correct type and operands. """ assert isinstance(assignment, Assignment) + assert not assignment.source_refs == None assert_operand(assignment.left, left) assert_operand(assignment.right, right) def assert_expression(expression, op = "", left = "", right = ""): """ Test an Operator for correct type, operation, and operands. 
""" assert isinstance(expression, Operator) + assert not expression.source_refs == None assert expression.op == op assert_operand(expression.operands[0], left) assert_operand(expression.operands[1], right) -def first_cast_node(source): - """ Return the first node from the first Module of MatlabToCast output """ - +def cast_nodes(source): + """ Return the CAST nodes from the first Module of MatlabToCast output """ # there should only be one CAST object in the cast output list cast = MatlabToCast(source = source).out_cast assert len(cast) == 1 - # there should be one module in the CAST object assert len(cast[0].nodes) == 1 module = cast[0].nodes[0] assert isinstance(module, Module) - - # currently we support one node per module. This may change - assert len(module.body) == 1 - return module.body[0] + # return the module body node list + return module.body diff --git a/skema/program_analysis/CAST/matlab/tree_out b/skema/program_analysis/CAST/matlab/tree_out index 4e52a78f6a7..c0352e3077a 100755 --- a/skema/program_analysis/CAST/matlab/tree_out +++ b/skema/program_analysis/CAST/matlab/tree_out @@ -11,10 +11,7 @@ from skema.program_analysis.tree_sitter_parsers.build_parsers import ( def print_tree(node: Node, indent = ''): """Display the node branch in pretty format""" for child in node.children: - if child.type == "\n": - # print(f"{indent} ") - pass - else: + if not child.type == "\n": print(f"{indent} {child.type}") print_tree(child, indent + ' ') diff --git a/skema/program_analysis/CAST/matlab/variable_context.py b/skema/program_analysis/CAST/matlab/variable_context.py index 6ef6088964a..b2ab8af51b9 100644 --- a/skema/program_analysis/CAST/matlab/variable_context.py +++ b/skema/program_analysis/CAST/matlab/variable_context.py @@ -6,7 +6,6 @@ class VariableContext(object): def __init__(self): - """Docstring""" self.context = [{}] # Stack of context dictionaries self.context_return_values = [set()] # Stack of context return values self.all_symbols = {} @@ -78,11 +77,9 
@@ def is_variable(self, symbol: str) -> bool: return symbol in self.all_symbols def get_node(self, symbol: str) -> Dict: - """Docstring""" return self.all_symbols[symbol]["node"] def get_type(self, symbol: str) -> str: - """Docstring""" return self.all_symbols[symbol]["type"] def update_type(self, symbol: str, type: str): @@ -92,22 +89,18 @@ def update_type(self, symbol: str, type: str): self.all_symbols[full_symbol_name]["type"] = type def add_return_value(self, symbol): - """Docstring""" self.context_return_values[-1].add(symbol) def remove_return_value(self, symbol): - """Docstring""" self.context_return_values[-1].discard(symbol) def generate_iterator(self): - """Docstring""" symbol = f"generated_iter_{self.iterator_id}" self.iterator_id += 1 return self.add_variable(symbol, "iterator", None) def generate_stop_condition(self): - """Docstring""" symbol = f"sc_{self.stop_condition_id}" self.stop_condition_id += 1 @@ -126,5 +119,4 @@ def set_internal(self): self.internal = True def unset_internal(self): - """Docstring""" self.internal = False