From 2eae12114f1bdd8da65176cd57c0c4a0728919ac Mon Sep 17 00:00:00 2001 From: Joseph Wood Date: Wed, 17 May 2023 15:22:04 -0400 Subject: [PATCH 1/6] calls to variables within a type no longer get added to call graph --- ford/sourceform.py | 144 ++++++++++++++++++++++++++++++++-------- test/test_sourceform.py | 2 +- 2 files changed, 119 insertions(+), 27 deletions(-) diff --git a/ford/sourceform.py b/ford/sourceform.py index 76d106e5..525e3817 100644 --- a/ford/sourceform.py +++ b/ford/sourceform.py @@ -584,7 +584,7 @@ class FortranContainer(FortranBase): ) ARITH_GOTO_RE = re.compile(r"go\s*to\s*\([0-9,\s]+\)", re.IGNORECASE) CALL_RE = re.compile( - r"(?:^|[^a-zA-Z0-9_% ]\s*(?:\w+%)?)(\w+)(?=\s*\(\s*(?:.*?)\s*\))", re.IGNORECASE + r"(?:^|[^a-zA-Z0-9_ ]\s*)(\w+)(?=\s*\(\s*(?:.*?)\s*\))", re.IGNORECASE ) SUBCALL_RE = re.compile( r"^(?:if\s*\(.*\)\s*)?call\s+(?:\w+%)?(\w+)\s*(?:\(\s*(.*?)\s*\))?$", @@ -923,6 +923,40 @@ def __init__( else: self.print_error(line, "Unexpected USE statement") else: + def get_call_chain(match) -> dict: + bracket = 0 + spacer = False + call_chain = [] + tmp = '' + for c in line[match.start(1)-1::-1]: + if bracket < 0: + break + if c == ')': + bracket += 1 + elif c == '(': + bracket -= 1 + elif bracket == 0: + if c.isalnum() or c == '_': + if spacer: + if len(tmp) == 0: + spacer = False + else: + break + tmp = c + tmp + elif c == '%': + if(len(tmp) != 0): + call_chain.append(tmp.lower()) + tmp = '' + elif c == ' ': + spacer = True + continue + else: + break + if(len(tmp) != 0): + call_chain.append(tmp.lower()) + call_chain.reverse() + return {"name": match.group(1).lower(), "call_chain": call_chain} + if self.CALL_RE.search(line): if hasattr(self, "calls"): # Arithmetic GOTOs looks little like function references: @@ -931,13 +965,16 @@ def __init__( # expression doesn't catch that so we first rule such a # GOTO out. if not self.ARITH_GOTO_RE.search(line): - callvals = self.CALL_RE.findall(line) + callvals = [] + matches = self.CALL_RE.finditer(line) + for match in matches: + callvals.append(get_call_chain(match)) for val in callvals: if ( - val.lower() not in self.calls - and val.lower() not in INTRINSICS + not any(d["name"] == val["name"] for d in self.calls) + and val["name"] not in INTRINSICS ): - self.calls.append(val.lower()) + self.calls.append(val) else: pass # Not raising an error here as too much possibility that something @@ -945,8 +982,11 @@ def __init__( if match := self.SUBCALL_RE.match(line): # Need this to catch any subroutines called without argument lists if hasattr(self, "calls"): - callval = match.group(1).lower() - if callval not in self.calls and callval not in INTRINSICS: + callval = get_call_chain(match) + if ( + not any(d["name"] == callval["name"] for d in self.calls) + and callval["name"] not in INTRINSICS + ): self.calls.append(callval) else: self.print_error(line, "Unexpected procedure call") @@ -992,11 +1032,13 @@ def correlate(self, project): self.all_procs.update(self.parent_submodule.all_procs) self.all_absinterfaces.update(self.parent_submodule.all_absinterfaces) self.all_types.update(self.parent_submodule.all_types) + self.all_vars.update(self.parent_submodule.pub_vars) elif type(getattr(self, "ancestor_module", "")) not in [str, type(None)]: self.ancestor_module.descendants.append(self) self.all_procs.update(self.ancestor_module.all_procs) self.all_absinterfaces.update(self.ancestor_module.all_absinterfaces) self.all_types.update(self.ancestor_module.all_types) + self.all_vars.update(self.ancestor_module.pub_vars) # Module procedures will be missing (some/all?) metadata, so # now we copy it from the interface @@ -1055,39 +1097,92 @@ def filter_public(collection: dict) -> dict: typelist[dtype] = set([]) typeorder = toposort.toposort_flatten(typelist) + # Correlate types + for dtype in typeorder: + if dtype in self.types: + dtype.correlate(project) + # Match up called procedures if hasattr(self, "calls"): tmplst = [] for call in self.calls: - call_low = call.lower() + call_low = call["name"].lower() argname = False for a in getattr(self, "args", []): # Consider allowing procedures passed as arguments to be included in callgraphs argname |= call_low == a.name.lower() if hasattr(self, "retvar"): argname |= call_low == self.retvar.name.lower() + + # travers the call chain of the call to descover the context the call is made on + context = self + call_chain = call["call_chain"] + if len(call_chain) > 0: + class NoLabel(Exception): + pass + try: + call_type = None + # strip off the "type()" or "class()" if it's there + strip_type = lambda s: re.match(r"^(type|class)\((.*?)(?:\(.*\))?\)$", s, re.IGNORECASE).group(2) if re.match(r"^(type|class)\((.*?)(?:\(.*\))?\)$", s, re.IGNORECASE) else s + try: # try call type is a variable + vars = self.all_vars + if hasattr(self, "args"): + vars = {**vars, **{a.name.lower(): a for a in self.args}} + if hasattr(self, "retvar"): + vars = {**vars, **{self.retvar.name.lower(): self.retvar}} + call_type = vars[call_chain[0].lower()].full_type + call_type = strip_type(call_type) + call_type = self.all_types[call_type.lower()] + except (KeyError, AttributeError): # not a variable, try call type is a procedure + try: + call_type = self.all_procs[call_chain[0].lower()].retvar.full_type + call_type = strip_type(call_type) + call_type = self.all_procs[call_chain[0].lower()].all_types[call_type.lower()] + except (KeyError, AttributeError): # not a procedure, give up + raise NoLabel + for c in [c.lower() for c in call_chain[1:]]: # traverse the call chain + try: # try call type is a variable + new_call_type = None + for var in call_type.variables: + if var.name.lower() == c: + new_call_type = var.full_type + break + if new_call_type is None: + raise KeyError(f"Variable {c} not found in type {call_type.name}") + new_call_type = strip_type(new_call_type) + call_type = call_type.all_types[new_call_type.lower()] + except (KeyError, AttributeError): # not a variable, try call type is a procedure + try: + new_call_type = None + for proc in call_type.boundprocs: + if proc.name.lower() == c: + new_call_type = proc.retvar.full_type + break + if new_call_type is None: + raise KeyError(f"Procedure {c} not found in type {call_type.name}") + except (KeyError, AttributeError): # not a procedure, give up + raise NoLabel + context = call_type + except NoLabel: + pass + all_vars = {} + if hasattr(context, "all_vars"): + all_vars = context.all_vars + if hasattr(context, "variables"): + all_vars = {**all_vars, **{v.name.lower(): v for v in context.variables}} + if ( not argname - and call_low not in self.all_vars + and call_low not in all_vars and (call_low not in self.all_types or call_low in self.all_procs) ): + try: + call = context.all_procs[call["name"].lower()] + except KeyError: + call = call["name"] tmplst.append(call) self.calls = tmplst - procedures = ( - {proc.name.lower(): proc for proc in self.parent.routines} - if self.parobj == "sourcefile" - else {} - ) - procedures.update({proc.name.lower(): proc for proc in project.procedures}) - procedures.update(self.all_procs) - - for i, call in enumerate(self.calls): - try: - self.calls[i] = procedures[call.lower()] - except KeyError: - pass - if self.obj == "submodule": self.ancestry = [] item = self @@ -1097,9 +1192,6 @@ def filter_public(collection: dict) -> dict: self.ancestry.insert(0, item.ancestor_module) # Recurse - for dtype in typeorder: - if dtype in self.types: - dtype.correlate(project) for entity in self.iterator( "functions", "subroutines", diff --git a/test/test_sourceform.py b/test/test_sourceform.py index 18681eb2..057efbc8 100644 --- a/test/test_sourceform.py +++ b/test/test_sourceform.py @@ -231,7 +231,7 @@ def test_function_and_subroutine_call_on_same_line(parse_fortran_file): program = fortran_file.programs[0] assert len(program.calls) == 2 expected_calls = {"bar", "foo"} - assert set(program.calls) == expected_calls + assert set([call["name"] for call in program.calls]) == expected_calls def test_component_access(parse_fortran_file): From b65f6a7270a965bc7c30584d0f872adb9dcad556 Mon Sep 17 00:00:00 2001 From: Joseph Wood Date: Tue, 23 May 2023 14:55:25 -0400 Subject: [PATCH 2/6] rewrite of get_call_chain and _find_call_context, now much cleaner. --- ford/sourceform.py | 276 ++++++++++++++++++++++++---------------- ford/utils.py | 39 +++++- test/test_sourceform.py | 2 +- 3 files changed, 205 insertions(+), 112 deletions(-) diff --git a/ford/sourceform.py b/ford/sourceform.py index 525e3817..484dd170 100644 --- a/ford/sourceform.py +++ b/ford/sourceform.py @@ -923,39 +923,13 @@ def __init__( else: self.print_error(line, "Unexpected USE statement") else: - def get_call_chain(match) -> dict: - bracket = 0 - spacer = False - call_chain = [] - tmp = '' - for c in line[match.start(1)-1::-1]: - if bracket < 0: - break - if c == ')': - bracket += 1 - elif c == '(': - bracket -= 1 - elif bracket == 0: - if c.isalnum() or c == '_': - if spacer: - if len(tmp) == 0: - spacer = False - else: - break - tmp = c + tmp - elif c == '%': - if(len(tmp) != 0): - call_chain.append(tmp.lower()) - tmp = '' - elif c == ' ': - spacer = True - continue - else: - break - if(len(tmp) != 0): - call_chain.append(tmp.lower()) - call_chain.reverse() - return {"name": match.group(1).lower(), "call_chain": call_chain} + + def is_unique_call(call): + """Check if a call is unique in this container, and not an intrinsic call.""" + return ( + call.name not in INTRINSICS + and not any(c.name == call.name for c in self.calls) + ) if self.CALL_RE.search(line): if hasattr(self, "calls"): @@ -965,16 +939,11 @@ def get_call_chain(match) -> dict: # expression doesn't catch that so we first rule such a # GOTO out. if not self.ARITH_GOTO_RE.search(line): - callvals = [] matches = self.CALL_RE.finditer(line) for match in matches: - callvals.append(get_call_chain(match)) - for val in callvals: - if ( - not any(d["name"] == val["name"] for d in self.calls) - and val["name"] not in INTRINSICS - ): - self.calls.append(val) + call = get_call_chain(line,match) + if is_unique_call(call): + self.calls.append(call) else: pass # Not raising an error here as too much possibility that something @@ -982,12 +951,9 @@ def get_call_chain(match) -> dict: if match := self.SUBCALL_RE.match(line): # Need this to catch any subroutines called without argument lists if hasattr(self, "calls"): - callval = get_call_chain(match) - if ( - not any(d["name"] == callval["name"] for d in self.calls) - and callval["name"] not in INTRINSICS - ): - self.calls.append(callval) + call = get_call_chain(line, match) + if is_unique_call(call): + self.calls.append(call) else: self.print_error(line, "Unexpected procedure call") @@ -1106,81 +1072,49 @@ def filter_public(collection: dict) -> dict: if hasattr(self, "calls"): tmplst = [] for call in self.calls: - call_low = call["name"].lower() + call: CallChain + call.name = call.name.lower() + + if call.name == 'c': + pass + + # get the context of the call + context = self._find_call_context(call) + + if context is None: + pass + + # failed to find context, give up and add call's string name to the list + if context is None: + tmplst.append(call.name) + continue + argname = False - for a in getattr(self, "args", []): - # Consider allowing procedures passed as arguments to be included in callgraphs - argname |= call_low == a.name.lower() - if hasattr(self, "retvar"): - argname |= call_low == self.retvar.name.lower() + # arguments and returns are only possible labels if the call is made within self's context + if context == self: + for a in getattr(self, "args", []): + # Consider allowing procedures passed as arguments to be included in callgraphs + argname |= call.name == a.name.lower() + if hasattr(self, "retvar"): + argname |= call.name == self.retvar.name.lower() - # travers the call chain of the call to descover the context the call is made on - context = self - call_chain = call["call_chain"] - if len(call_chain) > 0: - class NoLabel(Exception): - pass - try: - call_type = None - # strip off the "type()" or "class()" if it's there - strip_type = lambda s: re.match(r"^(type|class)\((.*?)(?:\(.*\))?\)$", s, re.IGNORECASE).group(2) if re.match(r"^(type|class)\((.*?)(?:\(.*\))?\)$", s, re.IGNORECASE) else s - try: # try call type is a variable - vars = self.all_vars - if hasattr(self, "args"): - vars = {**vars, **{a.name.lower(): a for a in self.args}} - if hasattr(self, "retvar"): - vars = {**vars, **{self.retvar.name.lower(): self.retvar}} - call_type = vars[call_chain[0].lower()].full_type - call_type = strip_type(call_type) - call_type = self.all_types[call_type.lower()] - except (KeyError, AttributeError): # not a variable, try call type is a procedure - try: - call_type = self.all_procs[call_chain[0].lower()].retvar.full_type - call_type = strip_type(call_type) - call_type = self.all_procs[call_chain[0].lower()].all_types[call_type.lower()] - except (KeyError, AttributeError): # not a procedure, give up - raise NoLabel - for c in [c.lower() for c in call_chain[1:]]: # traverse the call chain - try: # try call type is a variable - new_call_type = None - for var in call_type.variables: - if var.name.lower() == c: - new_call_type = var.full_type - break - if new_call_type is None: - raise KeyError(f"Variable {c} not found in type {call_type.name}") - new_call_type = strip_type(new_call_type) - call_type = call_type.all_types[new_call_type.lower()] - except (KeyError, AttributeError): # not a variable, try call type is a procedure - try: - new_call_type = None - for proc in call_type.boundprocs: - if proc.name.lower() == c: - new_call_type = proc.retvar.full_type - break - if new_call_type is None: - raise KeyError(f"Procedure {c} not found in type {call_type.name}") - except (KeyError, AttributeError): # not a procedure, give up - raise NoLabel - context = call_type - except NoLabel: - pass + # get all the variables in the call's context all_vars = {} if hasattr(context, "all_vars"): all_vars = context.all_vars if hasattr(context, "variables"): all_vars = {**all_vars, **{v.name.lower(): v for v in context.variables}} + # if call isn't to a variable (i.e. an array), and isn't a type, add it to the list if ( not argname - and call_low not in all_vars - and (call_low not in self.all_types or call_low in self.all_procs) + and call.name not in getattr(context, "all_types", {}) + and call.name not in all_vars ): - try: - call = context.all_procs[call["name"].lower()] - except KeyError: - call = call["name"] + # if can't find the call in context, add it as a string + call = context.all_procs.get(call.name, call.name) tmplst.append(call) + self.calls = tmplst if self.obj == "submodule": @@ -1326,6 +1260,110 @@ def prune(self): obj.visible = True obj.prune() + def _find_call_context(self, call): + """ + Traverse the call chain of the call to discover the context the call is made on. + This is done by looking at the first label in the call chain and matching it to + a variable or function in the current scope. Then, traverse to the context of the + variable or function return and repeat until the call chain is exhausted. + + If the traversal fails to find a label in a context, + the function gives up and returns None + """ + + def strip_type(s): + """ + strip the encasing 'type()' or 'class()' from a string if it exists, + and return the inner string (lowercased) + """ + r = re.match(r"^(type|class)\((.*?)(?:\(.*\))?\)$", s, re.IGNORECASE) + return r.group(2).lower() if r else s.lower() + + # context is self if call is not a chain + if len(call.chain) == 0: + return self + + call.chain[0] = call.chain[0].lower() + + call_type = None + # try call type is a variable + vars = getattr(self, "all_vars", {}) + if hasattr(self, "args"): + vars = {**vars, **{a.name.lower(): a for a in self.args}} + if hasattr(self, "retvar"): + vars = {**vars, **{self.retvar.name.lower(): self.retvar}} + if (hasattr(self, 'all_types') + and call.chain[0] in vars + ): + call_type_str = strip_type(vars[call.chain[0]].full_type) + call_type = self.all_types.get(call_type_str, None) + + # if None, not a variable, try call type is a procedure + if (call_type is None + and hasattr(self, 'all_procs') + and hasattr(self, 'all_types') + and call.chain[0] in self.all_procs + and hasattr(self.all_procs[call.chain[0]], 'retvar') + ): + call_type_str = strip_type(self.all_procs[call.chain[0]].retvar.full_type) + call_type = self.all_types.get(call_type_str, None) + + # if None, not a procedure, try call type is an extended type + if (call_type is None + and isinstance(self, FortranType) + and hasattr(self, 'extends')): + extend = self + while getattr(extend, "extends", None) is not None: + if extend.extends.name.lower() == c: + call_type = extend.extends + break + extend = extend.extends + + # if still None, give up + if call_type is None: + return None + + # traverse the call chain + for c in [c.lower() for c in call.chain[1:]]: + new_call_type = None + # try call type is a variable + if hasattr(call_type, 'variables'): + new_call_type_str = None + for v in call_type.variables: + if v.name.lower() == c: + new_call_type_str = strip_type(v.full_type) + break + new_call_type = call_type.all_types.get(new_call_type_str, None) + + # not a variable, try call type is a procedure + if (new_call_type is None + and hasattr(call_type, 'boundprocs') + ): + new_call_type_str = None + for b in call_type.boundprocs: + if b.name.lower() == c: + new_call_type_str = strip_type(v.retvar.full_type) + break + new_call_type = call_type.all_types.get(new_call_type_str, None) + + # not a procedure, try call type is an extended type + if (new_call_type is None + and isinstance(call_type, FortranType)): + extend = call_type + while getattr(extend, "extends", None) is not None: + if extend.extends.name.lower() == c: + new_call_type = extend.extends + break + extend = extend.extends + + # not a subtype, give up + if new_call_type is None: + return None + + call_type = new_call_type + + return call_type + class FortranSourceFile(FortranContainer): """ @@ -2892,6 +2930,24 @@ def parse_type( kind = kind.group(1) if kind else args return ParsedType(vartype, rest, kind=kind) +@dataclass +class CallChain: + name: str + chain: List[str] + +def get_call_chain(line, match) -> CallChain: + sub_line = line[:match.start(1)].lower() + if len(sub_line) == 0 or sub_line[-1] != "%": + # not a chain call + return CallChain(match.group(1).lower(), []) + level = sub_line.count("(") - sub_line.count(")") + sub_line = ford.utils.strip_paren(sub_line, retlevel=level, index = len(sub_line) - 1) + # remove 'call ' from the start if present + sub_line = sub_line[5:] if sub_line.startswith("call ") else sub_line + # get end of sub_line after last non-alphanumeric character + sub_line = re.split(r'([^\w\s_%]|(?<=[^\s%])\s+(?=[^\s%]))+', sub_line)[-1].replace(' ','') + call_chain = sub_line.split("%")[:-1] + return CallChain(match.group(1).lower(), call_chain) def set_base_url(url): FortranBase.base_url = url diff --git a/ford/utils.py b/ford/utils.py index b7ae328f..152c3b77 100644 --- a/ford/utils.py +++ b/ford/utils.py @@ -31,6 +31,7 @@ from urllib.parse import urljoin import pathlib from typing import Union +from io import StringIO NOTE_TYPE = { @@ -87,7 +88,7 @@ def substitute(match): def get_parens(line: str, retlevel: int = 0, retblevel: int = 0) -> str: """ - By default akes a string starting with an open parenthesis and returns the portion + By default takes a string starting with an open parenthesis and returns the portion of the string going to the corresponding close parenthesis. If retlevel != 0 then will return when that level (for parentheses) is reached. Same for retblevel. """ @@ -117,6 +118,42 @@ def get_parens(line: str, retlevel: int = 0, retblevel: int = 0) -> str: return parenstr raise RuntimeError("Couldn't parse parentheses: {}".format(line)) +def strip_paren(line: str, retlevel: int = 0, retblevel: int = 0, index = -1) -> str: + """ + By default takes a string starting with an open parenthesis and returns the portion + of the string going to the corresponding close parenthesis. If retlevel != 0 then + will return when that level (for parentheses) is reached. Same for retblevel. + If index >= 0, then only the characters inside the same scope as index are returned. + """ + if len(line) == 0: + return line + retstr = StringIO() + level = 0 + blevel = 0 + for i,char in enumerate(line): + if char == "(": + level += 1 + elif char == ")": + level -= 1 + elif char == "[": + blevel += 1 + elif char == "]": + blevel -= 1 + elif ( + level == retlevel + and blevel == retblevel + ): + retstr.write(char) + + if index >= 0 and char == ')' and level < retlevel: + # scope of index is yet to start, reset retstr + if i < index: + retstr = StringIO() + # scope of index is over, return retstr + else: + return retstr.getvalue() + + return retstr.getvalue() def paren_split(sep, string): """ diff --git a/test/test_sourceform.py b/test/test_sourceform.py index 057efbc8..4c2a4f1c 100644 --- a/test/test_sourceform.py +++ b/test/test_sourceform.py @@ -231,7 +231,7 @@ def test_function_and_subroutine_call_on_same_line(parse_fortran_file): program = fortran_file.programs[0] assert len(program.calls) == 2 expected_calls = {"bar", "foo"} - assert set([call["name"] for call in program.calls]) == expected_calls + assert set([call.name for call in program.calls]) == expected_calls def test_component_access(parse_fortran_file): From 651ae0a9b5f727a7d0cf9997ed14ce0386364b52 Mon Sep 17 00:00:00 2001 From: Joseph Wood Date: Tue, 23 May 2023 15:04:53 -0400 Subject: [PATCH 3/6] fix sub calls with more than one type in call chain --- ford/sourceform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ford/sourceform.py b/ford/sourceform.py index 484dd170..157d8a67 100644 --- a/ford/sourceform.py +++ b/ford/sourceform.py @@ -587,7 +587,7 @@ class FortranContainer(FortranBase): r"(?:^|[^a-zA-Z0-9_ ]\s*)(\w+)(?=\s*\(\s*(?:.*?)\s*\))", re.IGNORECASE ) SUBCALL_RE = re.compile( - r"^(?:if\s*\(.*\)\s*)?call\s+(?:\w+%)?(\w+)\s*(?:\(\s*(.*?)\s*\))?$", + r"^(?:if\s*\(.*\)\s*)?call\s+(?:\w+%)*(\w+)\s*(?:\(\s*(.*?)\s*\))?$", re.IGNORECASE, ) FORMAT_RE = re.compile(r"^[0-9]+\s+format\s+\(.*\)", re.IGNORECASE) From b0abb1d22df45a52bf26d29f9684bb1a41fce42a Mon Sep 17 00:00:00 2001 From: Joseph Wood Date: Wed, 24 May 2023 13:15:41 -0400 Subject: [PATCH 4/6] remove procedure detection in _find_call_context due to functions in the left side of call chains being illegal in fortran. --- ford/sourceform.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/ford/sourceform.py b/ford/sourceform.py index 157d8a67..ac28e9e1 100644 --- a/ford/sourceform.py +++ b/ford/sourceform.py @@ -1297,18 +1297,8 @@ def strip_type(s): ): call_type_str = strip_type(vars[call.chain[0]].full_type) call_type = self.all_types.get(call_type_str, None) - - # if None, not a variable, try call type is a procedure - if (call_type is None - and hasattr(self, 'all_procs') - and hasattr(self, 'all_types') - and call.chain[0] in self.all_procs - and hasattr(self.all_procs[call.chain[0]], 'retvar') - ): - call_type_str = strip_type(self.all_procs[call.chain[0]].retvar.full_type) - call_type = self.all_types.get(call_type_str, None) - # if None, not a procedure, try call type is an extended type + # if None, not a variable, try call type is an extended type if (call_type is None and isinstance(self, FortranType) and hasattr(self, 'extends')): @@ -1334,19 +1324,8 @@ def strip_type(s): new_call_type_str = strip_type(v.full_type) break new_call_type = call_type.all_types.get(new_call_type_str, None) - - # not a variable, try call type is a procedure - if (new_call_type is None - and hasattr(call_type, 'boundprocs') - ): - new_call_type_str = None - for b in call_type.boundprocs: - if b.name.lower() == c: - new_call_type_str = strip_type(v.retvar.full_type) - break - new_call_type = call_type.all_types.get(new_call_type_str, None) - # not a procedure, try call type is an extended type + # not a variable, try call type is an extended type if (new_call_type is None and isinstance(call_type, FortranType)): extend = call_type From 3f0de23201a12061f8ffbb7cd9c8005b0f79eb92 Mon Sep 17 00:00:00 2001 From: Joseph Wood Date: Wed, 24 May 2023 13:17:31 -0400 Subject: [PATCH 5/6] add tests for the strip_parens util function, and tests for call chain context detection --- test/test_sourceform.py | 96 +++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 12 ++++++ 2 files changed, 108 insertions(+) diff --git a/test/test_sourceform.py b/test/test_sourceform.py index 4c2a4f1c..193cbcb8 100644 --- a/test/test_sourceform.py +++ b/test/test_sourceform.py @@ -1,6 +1,7 @@ from ford.sourceform import ( FortranSourceFile, FortranModule, + FortranBase, parse_type, ParsedType, line_to_variables, @@ -233,6 +234,101 @@ def test_function_and_subroutine_call_on_same_line(parse_fortran_file): expected_calls = {"bar", "foo"} assert set([call.name for call in program.calls]) == expected_calls +@pytest.mark.parametrize( + ["call_segment", "expected"], + [ + (""" + TYPE(t_bar) :: v_bar + TYPE(t_baz) :: var + var = v_bar%p_foo() + """ + , ["p_foo"]), + (""" + TYPE(t_bar) :: v_bar + INTEGER :: var + var = v_bar%v_foo(0) + """ + , []), + (""" + TYPE(t_bar) :: v_bar + INTEGER, DIMENSION(:), ALLOCATABLE :: var + var = v_bar%v_baz%p_baz() + """ + , ["p_baz"]), + (""" + TYPE(t_bar) :: v_bar + TYPE(t_baz) :: var + var = v_bar%t_foo%p_foo() + """ + , ["p_foo"]), + (""" + TYPE(t_bar) :: v_bar + INTEGER :: var + var = v_bar%v_baz%v_faz(0) + """ + , []), + ]) +def test_type_chain_function_and_subroutine_calls(parse_fortran_file,call_segment, expected): + data = """\ + MODULE m_foo + IMPLICIT NONE + TYPE :: t_baz + INTEGER, DIMENSION(:), ALLOCATABLE :: v_faz + CONTAINS + PROCEDURE :: p_baz + END TYPE t_baz + + TYPE :: t_foo + INTEGER, DIMENSION(:), ALLOCATABLE :: v_foo + CONTAINS + PROCEDURE :: p_foo + END TYPE t_foo + + TYPE, EXTENDS(t_foo) :: t_bar + TYPE(t_baz) :: v_baz + END TYPE t_bar + + CONTAINS + + FUNCTION p_baz(self) RESULT(ret_val) + CLASS(t_baz), INTENT(IN) :: self + INTEGER, DIMENSION(:), ALLOCATABLE :: ret_val + END FUNCTION p_baz + + FUNCTION p_foo(self) RESULT(ret_val) + CLASS(t_foo), INTENT(IN) :: self + TYPE(t_baz) :: ret_val + END FUNCTION p_foo + + FUNCTION p_bar() RESULT(ret_val) + TYPE(t_baz) :: ret_val + END FUNCTION p_bar + + SUBROUTINE main + ! Call segment + {} + ! Call segment + END SUBROUTINE main + + END MODULE m_foo + + """.format(call_segment) + + fortran_file = parse_fortran_file(data) + fp = FakeProject() + fortran_file.modules[0].correlate(fp) + + calls = fortran_file.modules[0].subroutines[0].calls + + assert len(calls) == len(expected) + + calls_sorted = sorted(calls, key=lambda x: x.name) + expected_sorted = sorted(expected) + for call, expected_name in zip(calls_sorted, expected_sorted): + assert isinstance(call, FortranBase) + + assert call.name == expected_name + def test_component_access(parse_fortran_file): data = """\ diff --git a/test/test_utils.py b/test/test_utils.py index 997d142d..45cea550 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -40,3 +40,15 @@ def test_str_to_bool_false(string): def test_str_to_bool_already_bool(): assert ford.utils.str_to_bool(True) assert not ford.utils.str_to_bool(False) + +@pytest.mark.parametrize(("string", "level", "index", "expected"), + [("abcdefghi", 0, -1, "abcdefghi"), + ("abc(def)ghi", 1, -1, "def"), + ("abc(def)ghi", 0, -1, "abcghi"), + ("(abc)def(ghi)", 1, 1, "abc"), + ("(abc)def(ghi)", 1, 9, "ghi"), + ("(abc)def(ghi)", 1, -1, "abcghi"), + ("(a(b)c)def(gh(i))", 1, 2, "ac"), + ]) +def test_strip_paren(string, level, index, expected): + assert ford.utils.strip_paren(string, retlevel = level, index = index) == expected From 39ef6a851dd2d3bed6a937e7bab43fc80855fa38 Mon Sep 17 00:00:00 2001 From: Joseph Wood Date: Wed, 24 May 2023 17:03:28 -0400 Subject: [PATCH 6/6] fix strip_paren doc comment --- ford/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ford/utils.py b/ford/utils.py index 152c3b77..fb33c24b 100644 --- a/ford/utils.py +++ b/ford/utils.py @@ -120,10 +120,10 @@ def get_parens(line: str, retlevel: int = 0, retblevel: int = 0) -> str: def strip_paren(line: str, retlevel: int = 0, retblevel: int = 0, index = -1) -> str: """ - By default takes a string starting with an open parenthesis and returns the portion - of the string going to the corresponding close parenthesis. If retlevel != 0 then - will return when that level (for parentheses) is reached. Same for retblevel. - If index >= 0, then only the characters inside the same scope as index are returned. + Takes a string with parentheses and returns only the portion of the string + that is in the same level of nested parentheses as specified by retlevel. + If index is specified, then characters in the same level but not in the same + scope as the char at index are also stripped. """ if len(line) == 0: return line