Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

calls to variables within a type no longer get added to call graph #520

Merged
201 changes: 164 additions & 37 deletions ford/sourceform.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,10 @@ class FortranContainer(FortranBase):
)
ARITH_GOTO_RE = re.compile(r"go\s*to\s*\([0-9,\s]+\)", re.IGNORECASE)
CALL_RE = re.compile(
r"(?:^|[^a-zA-Z0-9_% ]\s*(?:\w+%)?)(\w+)(?=\s*\(\s*(?:.*?)\s*\))", re.IGNORECASE
r"(?:^|[^a-zA-Z0-9_ ]\s*)(\w+)(?=\s*\(\s*(?:.*?)\s*\))", re.IGNORECASE
)
SUBCALL_RE = re.compile(
r"^(?:if\s*\(.*\)\s*)?call\s+(?:\w+%)?(\w+)\s*(?:\(\s*(.*?)\s*\))?$",
r"^(?:if\s*\(.*\)\s*)?call\s+(?:\w+%)*(\w+)\s*(?:\(\s*(.*?)\s*\))?$",
re.IGNORECASE,
)
FORMAT_RE = re.compile(r"^[0-9]+\s+format\s+\(.*\)", re.IGNORECASE)
Expand Down Expand Up @@ -917,6 +917,14 @@ def __init__(
else:
self.print_error(line, "Unexpected USE statement")
else:

def is_unique_call(call):
"""Check if a call is unique in this container, and not an intrinsic call."""
return (
call.name not in INTRINSICS
and not any(c.name == call.name for c in self.calls)
)

if self.CALL_RE.search(line):
if hasattr(self, "calls"):
# Arithmetic GOTOs looks little like function references:
Expand All @@ -925,23 +933,21 @@ def __init__(
# expression doesn't catch that so we first rule such a
# GOTO out.
if not self.ARITH_GOTO_RE.search(line):
callvals = self.CALL_RE.findall(line)
for val in callvals:
if (
val.lower() not in self.calls
and val.lower() not in INTRINSICS
):
self.calls.append(val.lower())
matches = self.CALL_RE.finditer(line)
for match in matches:
call = get_call_chain(line,match)
if is_unique_call(call):
self.calls.append(call)
else:
pass
# Not raising an error here as too much possibility that something
# has been misidentified as a function call
if match := self.SUBCALL_RE.match(line):
# Need this to catch any subroutines called without argument lists
if hasattr(self, "calls"):
callval = match.group(1).lower()
if callval not in self.calls and callval not in INTRINSICS:
self.calls.append(callval)
call = get_call_chain(line, match)
if is_unique_call(call):
self.calls.append(call)
else:
self.print_error(line, "Unexpected procedure call")

Expand Down Expand Up @@ -986,11 +992,13 @@ def correlate(self, project):
self.all_procs.update(self.parent_submodule.all_procs)
self.all_absinterfaces.update(self.parent_submodule.all_absinterfaces)
self.all_types.update(self.parent_submodule.all_types)
self.all_vars.update(self.parent_submodule.pub_vars)
elif type(getattr(self, "ancestor_module", "")) not in [str, type(None)]:
self.ancestor_module.descendants.append(self)
self.all_procs.update(self.ancestor_module.all_procs)
self.all_absinterfaces.update(self.ancestor_module.all_absinterfaces)
self.all_types.update(self.ancestor_module.all_types)
self.all_vars.update(self.ancestor_module.pub_vars)

# Module procedures will be missing (some/all?) metadata, so
# now we copy it from the interface
Expand Down Expand Up @@ -1049,38 +1057,59 @@ def filter_public(collection: dict) -> dict:
typelist[dtype] = set([])
typeorder = toposort.toposort_flatten(typelist)

# Correlate types
for dtype in typeorder:
if dtype in self.types:
dtype.correlate(project)

# Match up called procedures
if hasattr(self, "calls"):
tmplst = []
for call in self.calls:
call_low = call.lower()
call: CallChain
call.name = call.name.lower()

if call.name == 'c':
pass

# get the context of the call
context = self._find_call_context(call)

if context is None:
pass

# failed to find context, give up and add call's string name to the list
if context is None:
tmplst.append(call.name)
continue

argname = False
for a in getattr(self, "args", []):
# Consider allowing procedures passed as arguments to be included in callgraphs
argname |= call_low == a.name.lower()
if hasattr(self, "retvar"):
argname |= call_low == self.retvar.name.lower()
# arguments and returns are only possible labels if the call is made within self's context
if context == self:
for a in getattr(self, "args", []):
# Consider allowing procedures passed as arguments to be included in callgraphs
argname |= call.name == a.name.lower()
if hasattr(self, "retvar"):
argname |= call.name == self.retvar.name.lower()

# get all the variables in the call's context
all_vars = {}
if hasattr(context, "all_vars"):
all_vars = context.all_vars
if hasattr(context, "variables"):
all_vars = {**all_vars, **{v.name.lower(): v for v in context.variables}}

# if call isn't to a variable (i.e. an array), and isn't a type, add it to the list
if (
not argname
and call_low not in self.all_vars
and (call_low not in self.all_types or call_low in self.all_procs)
and call.name not in getattr(context, "all_types", {})
and call.name not in all_vars
):
# if can't find the call in context, add it as a string
call = context.all_procs.get(call.name, call.name)
tmplst.append(call)
self.calls = tmplst

procedures = (
{proc.name.lower(): proc for proc in self.parent.routines}
if self.parobj == "sourcefile"
else {}
)
procedures.update({proc.name.lower(): proc for proc in project.procedures})
procedures.update(self.all_procs)

for i, call in enumerate(self.calls):
try:
self.calls[i] = procedures[call.lower()]
except KeyError:
pass
self.calls = tmplst

if self.obj == "submodule":
self.ancestry = []
Expand All @@ -1091,9 +1120,6 @@ def filter_public(collection: dict) -> dict:
self.ancestry.insert(0, item.ancestor_module)

# Recurse
for dtype in typeorder:
if dtype in self.types:
dtype.correlate(project)
for entity in self.iterator(
"functions",
"subroutines",
Expand Down Expand Up @@ -1228,6 +1254,89 @@ def prune(self):
obj.visible = True
obj.prune()

def _find_call_context(self, call):
"""
Traverse the call chain of the call to discover the context the call is made on.
This is done by looking at the first label in the call chain and matching it to
a variable or function in the current scope. Then, traverse to the context of the
variable or function return and repeat until the call chain is exhausted.

If the traversal fails to find a label in a context,
the function gives up and returns None
"""

def strip_type(s):
"""
strip the encasing 'type()' or 'class()' from a string if it exists,
and return the inner string (lowercased)
"""
r = re.match(r"^(type|class)\((.*?)(?:\(.*\))?\)$", s, re.IGNORECASE)
return r.group(2).lower() if r else s.lower()

# context is self if call is not a chain
if len(call.chain) == 0:
return self

call.chain[0] = call.chain[0].lower()

call_type = None
# try call type is a variable
vars = getattr(self, "all_vars", {})
if hasattr(self, "args"):
vars = {**vars, **{a.name.lower(): a for a in self.args}}
if hasattr(self, "retvar"):
vars = {**vars, **{self.retvar.name.lower(): self.retvar}}
if (hasattr(self, 'all_types')
and call.chain[0] in vars
):
call_type_str = strip_type(vars[call.chain[0]].full_type)
call_type = self.all_types.get(call_type_str, None)

# if None, not a variable, try call type is an extended type
if (call_type is None
and isinstance(self, FortranType)
and hasattr(self, 'extends')):
extend = self
while getattr(extend, "extends", None) is not None:
if extend.extends.name.lower() == c:
call_type = extend.extends
break
extend = extend.extends

# if still None, give up
if call_type is None:
return None

# traverse the call chain
for c in [c.lower() for c in call.chain[1:]]:
new_call_type = None
# try call type is a variable
if hasattr(call_type, 'variables'):
new_call_type_str = None
for v in call_type.variables:
if v.name.lower() == c:
new_call_type_str = strip_type(v.full_type)
break
new_call_type = call_type.all_types.get(new_call_type_str, None)

# not a variable, try call type is an extended type
if (new_call_type is None
and isinstance(call_type, FortranType)):
extend = call_type
while getattr(extend, "extends", None) is not None:
if extend.extends.name.lower() == c:
new_call_type = extend.extends
break
extend = extend.extends

# not a subtype, give up
if new_call_type is None:
return None

call_type = new_call_type

return call_type


class FortranSourceFile(FortranContainer):
"""
Expand Down Expand Up @@ -2794,6 +2903,24 @@ def parse_type(
kind = kind.group(1) if kind else args
return ParsedType(vartype, rest, kind=kind)

@dataclass
class CallChain:
name: str
chain: List[str]

def get_call_chain(line, match) -> CallChain:
sub_line = line[:match.start(1)].lower()
if len(sub_line) == 0 or sub_line[-1] != "%":
# not a chain call
return CallChain(match.group(1).lower(), [])
level = sub_line.count("(") - sub_line.count(")")
sub_line = ford.utils.strip_paren(sub_line, retlevel=level, index = len(sub_line) - 1)
# remove 'call ' from the start if present
sub_line = sub_line[5:] if sub_line.startswith("call ") else sub_line
# get end of sub_line after last non-alphanumeric character
sub_line = re.split(r'([^\w\s_%]|(?<=[^\s%])\s+(?=[^\s%]))+', sub_line)[-1].replace(' ','')
call_chain = sub_line.split("%")[:-1]
return CallChain(match.group(1).lower(), call_chain)

def set_base_url(url):
FortranBase.base_url = url
Expand Down
39 changes: 38 additions & 1 deletion ford/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from urllib.parse import urljoin
import pathlib
from typing import Union
from io import StringIO
import itertools


Expand Down Expand Up @@ -88,7 +89,7 @@ def substitute(match):

def get_parens(line: str, retlevel: int = 0, retblevel: int = 0) -> str:
"""
By default akes a string starting with an open parenthesis and returns the portion
By default takes a string starting with an open parenthesis and returns the portion
of the string going to the corresponding close parenthesis. If retlevel != 0 then
will return when that level (for parentheses) is reached. Same for retblevel.
"""
Expand Down Expand Up @@ -118,6 +119,42 @@ def get_parens(line: str, retlevel: int = 0, retblevel: int = 0) -> str:
return parenstr
raise RuntimeError("Couldn't parse parentheses: {}".format(line))

def strip_paren(line: str, retlevel: int = 0, retblevel: int = 0, index = -1) -> str:
"""
Takes a string with parentheses and returns only the portion of the string
that is in the same level of nested parentheses as specified by retlevel.
If index is specified, then characters in the same level but not in the same
scope as the char at index are also stripped.
"""
if len(line) == 0:
return line
retstr = StringIO()
level = 0
blevel = 0
for i,char in enumerate(line):
if char == "(":
level += 1
elif char == ")":
level -= 1
elif char == "[":
blevel += 1
elif char == "]":
blevel -= 1
elif (
level == retlevel
and blevel == retblevel
):
retstr.write(char)

if index >= 0 and char == ')' and level < retlevel:
# scope of index is yet to start, reset retstr
if i < index:
retstr = StringIO()
# scope of index is over, return retstr
else:
return retstr.getvalue()

return retstr.getvalue()

def paren_split(sep, string):
"""
Expand Down
Loading