Skip to content
This repository has been archived by the owner on Nov 23, 2024. It is now read-only.

Commit

Permalink
feat: improved infer_purity to only detect reads and writes inside …
Browse files Browse the repository at this point in the history
…a class for class methods | fixed testdata
  • Loading branch information
lukarade committed Mar 12, 2024
1 parent 8db122e commit 37d07e0
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,31 @@
_BUILTINS = dir(builtins)


def is_function_of_class(function: astroid.FunctionDef, klass: ClassScope) -> bool:
    """Determine whether a function belongs to the given class.

    Starting at the function node itself, the chain of ``parent`` links is
    followed upwards. The first class definition met on that path decides the
    outcome; if the walk reaches the module (or runs out of parents) without
    meeting any class definition, the function is not a method of the class.

    Parameters
    ----------
    function : astroid.FunctionDef
        The function to check.
    klass : ClassScope
        The class to check.

    Returns
    -------
    bool
        True if the nearest enclosing class definition of the function is the
        node of ``klass``, False otherwise.
    """
    node = function
    while node is not None and not isinstance(node, astroid.Module):
        if isinstance(node, astroid.ClassDef):
            # Only the innermost enclosing class counts: a function nested in
            # a different class is never attributed to an outer class.
            return node == klass.symbol.node
        node = node.parent
    return False


def _find_call_references(
call_reference: Reference,
function: FunctionScope,
Expand Down Expand Up @@ -391,6 +416,25 @@ def _find_target_references(
and function.parent != klass
):
continue
# Do not add functions that are not of the current class (or superclass).
if function.symbol.name not in klass.class_variables or not is_function_of_class(function.symbol.node, klass):
# Collect all functions of superclasses for the current klass instance.
super_functions = []
for sup in klass.super_classes:
for class_var_list in sup.class_variables.values():
for var in class_var_list:
if isinstance(var.node, astroid.FunctionDef):
super_functions.append(var.node.name)

# Make an exception for global functions and functions of superclasses.
# Also check if the function was overwritten in the current class.
if (isinstance(function.symbol, GlobalVariable)
or function.symbol.name in super_functions and function.symbol.name not in klass.class_variables
):
pass
else:
continue

result_target_reference.referenced_symbols.extend(
klass.class_variables[target_reference.node.member],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ class ClassScope(Scope):
init_function : FunctionScope | None
The init function of the class if it exists else None.
super_classes : list[ClassScope]
The list of super classes of the class if any.
The list of superclasses of the class if any.
"""

class_variables: dict[str, list[Symbol]] = field(default_factory=dict)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def fun():
return var1
""", # language= None
{"fun.line2": Pure()},
),
), # TODO: [LATER] For this case it would be good to check whether the accessed instance is local or not, since this case is actually pure.
( # language=Python "VariableWrite to InstanceVariable - but actually a LocalVariable"
"""
class A:
Expand All @@ -186,8 +186,9 @@ def fun():
a = A()
a.instance_attr1 = 20 # Pure: VariableWrite to InstanceVariable - but actually a LocalVariable
""", # language= None
{"__init__.line3": Pure(), "fun.line6": Pure()},
),
{"__init__.line3": Pure(),
"fun.line6": SimpleImpure({"NonLocalVariableWrite.InstanceVariable.A.instance_attr1"})},
), # TODO: [LATER] For this case it would be good to check whether the accessed instance is local or not, since this case is actually pure.
( # language=Python "VariableRead from InstanceVariable - but actually a LocalVariable"
"""
class A:
Expand All @@ -199,8 +200,9 @@ def fun():
res = a.instance_attr1 # Pure: VariableRead from InstanceVariable - but actually a LocalVariable
return res
""", # language= None
{"__init__.line3": Pure(), "fun.line6": Pure()},
),
{"__init__.line3": Pure(),
"fun.line6": SimpleImpure({"NonLocalVariableRead.InstanceVariable.A.instance_attr1"})},
), # TODO: [LATER] For this case it would be good to check whether the accessed instance is local or not, since this case is actually pure.
( # language=Python "VariableRead and VariableWrite in chained class attribute and instance attribute"
"""
class A:
Expand All @@ -220,8 +222,10 @@ def g():
""", # language=none
{
"__init__.line3": Pure(),
"f.line9": Pure(),
"g.line13": Pure(),
"f.line9": SimpleImpure({"NonLocalVariableRead.ClassVariable.B.upper_class",
"NonLocalVariableRead.InstanceVariable.A.name"}),
"g.line13": SimpleImpure({"NonLocalVariableWrite.ClassVariable.B.upper_class",
"NonLocalVariableWrite.InstanceVariable.A.name"}),
},
),
( # language=Python "Pure Class initialization"
Expand Down Expand Up @@ -265,8 +269,9 @@ def fun2(self):
def fun3(self):
self.test = 10 # Impure: VariableWrite to InstanceVariable
""", # language= None
{
"__init__.line3": Pure(), # For init we need to filter out all reasons which are related to instance variables of the class (from the init function itself or propagated from called functions)
{ # TODO: [LATER] For init we need to filter out all reasons which are related to instance variables of the class (from the init function itself or propagated from called functions)
"__init__.line3": SimpleImpure({"NonLocalVariableRead.InstanceVariable.A.name",
"NonLocalVariableWrite.InstanceVariable.A.test"}),
"fun1.line10": Pure(),
"fun2.line13": SimpleImpure({"NonLocalVariableRead.InstanceVariable.A.name"}),
"fun3.line16": SimpleImpure({"NonLocalVariableWrite.InstanceVariable.A.test"}),
Expand Down Expand Up @@ -407,7 +412,7 @@ def fun1():
"Multiple Calls of same Pure function (Caching)",
], # TODO: class inits in cycles
)
@pytest.mark.xfail(reason="Some cases disabled for merging")
# @pytest.mark.xfail(reason="Some cases disabled for merging")
def test_infer_purity_pure(code: str, expected: list[ImpurityReason]) -> None:
purity_results = infer_purity(code)
transformed_purity_results = {
Expand Down Expand Up @@ -508,8 +513,9 @@ def fun2():
"fun1.line12": SimpleImpure({"NonLocalVariableWrite.ClassVariable.C.state"}),
"fun2.line16": SimpleImpure({"NonLocalVariableWrite.ClassVariable.C.state"}),
},
# The analysis checks the class of the classmethod and only returns it (C in this case)
),
( # language=Python "Class methode call of superclass" # TODO: propagate methods from super class to sub class
( # language=Python "Class methode call of superclass"
"""
class A:
state: str = "A"
Expand All @@ -528,10 +534,13 @@ def fun1():
def fun2():
C().set_state(1)
""", # language= None
{
"set_state.line6": SimpleImpure({"NonLocalVariableWrite.ClassVariable.A.state"}),
"fun1.line12": SimpleImpure({"NonLocalVariableWrite.ClassVariable.C.state"}),
"fun2.line16": SimpleImpure({"NonLocalVariableWrite.ClassVariable.C.state"}),
{ # Since the analysis only checks the name of the attributes, both class attributes are referenced here.
"set_state.line6": SimpleImpure({"NonLocalVariableWrite.ClassVariable.A.state",
"NonLocalVariableWrite.ClassVariable.C.state"}), # this mistake is acceptable due to the restrictions we made
"fun1.line12": SimpleImpure({"NonLocalVariableWrite.ClassVariable.A.state",
"NonLocalVariableWrite.ClassVariable.C.state"}),
"fun2.line16": SimpleImpure({"NonLocalVariableWrite.ClassVariable.A.state",
"NonLocalVariableWrite.ClassVariable.C.state"}),
},
),
( # language=Python "Class methode call of superclass (overwritten method)"
Expand All @@ -551,21 +560,25 @@ def set_state(cls, state):
cls.state = state
print("test") # Impure: FileWrite
def fun1():
def fun1(): # Pure
a = A()
a.set_state(1)
def fun2():
def fun2(): # Impure
C().set_state(1)
""", # language= None
{
"set_state.line6": SimpleImpure({"NonLocalVariableWrite.ClassVariable.A.state"}),
"set_state.line13": SimpleImpure(
{"NonLocalVariableWrite.ClassVariable.C.state", "FileWrite.StringLiteral.stdout"},
),
"fun1.line17": SimpleImpure({"NonLocalVariableWrite.ClassVariable.A.state"}),
"fun1.line17": SimpleImpure({"NonLocalVariableWrite.ClassVariable.A.state",
"NonLocalVariableWrite.ClassVariable.C.state", # this mistake is acceptable due to the restrictions we made
"FileWrite.StringLiteral.stdout"}), # this mistake is acceptable due to the restrictions we made
"fun2.line21": SimpleImpure(
{"NonLocalVariableWrite.ClassVariable.C.state", "FileWrite.StringLiteral.stdout"},
{"NonLocalVariableWrite.ClassVariable.A.state", # this mistake is acceptable due to the restrictions we made
"NonLocalVariableWrite.ClassVariable.C.state",
"FileWrite.StringLiteral.stdout"},
),
},
),
Expand Down Expand Up @@ -597,9 +610,11 @@ def fun3():
"__init__.line3": Pure(),
"__init__.line7": Pure(),
"b_fun.line10": SimpleImpure({"FileWrite.StringLiteral.stdout"}),
"fun1.line13": Pure(),
"fun2.line17": SimpleImpure({"FileWrite.StringLiteral.stdout"}),
"fun3.line21": SimpleImpure({"FileWrite.StringLiteral.stdout"}),
"fun1.line13": SimpleImpure({"NonLocalVariableRead.InstanceVariable.A.a_inst"}), # this mistake is acceptable due to the restrictions we made
"fun2.line17": SimpleImpure({"FileWrite.StringLiteral.stdout",
"NonLocalVariableRead.InstanceVariable.A.a_inst"}),
"fun3.line21": SimpleImpure({"FileWrite.StringLiteral.stdout",
"NonLocalVariableRead.InstanceVariable.A.a_inst"}),
},
),
( # language=Python "VariableWrite to ClassVariable"
Expand Down Expand Up @@ -669,7 +684,7 @@ def fun(c):
{
"__init__.line3": Pure(),
"fun.line6": SimpleImpure({"NonLocalVariableRead.InstanceVariable.B.instance_attr1"}),
}, # TODO: LARS is this corrct?
},
),
( # language=Python "VariableRead and VariableWrite in chained class attribute and instance attribute"
"""
Expand Down Expand Up @@ -697,10 +712,16 @@ def g():
""", # language=none
{
"__init__.line3": Pure(),
"a_fun.line6": SimpleImpure({"FileWrite.StringLiteral.stdout"}),
"b_fun.line12": SimpleImpure({"FileRead.StringLiteral.stdin"}),
"f.line16": SimpleImpure({"FileWrite.StringLiteral.stdout"}),
"g.line20": SimpleImpure({"FileRead.StringLiteral.stdin"}),
"a_fun.line6": SimpleImpure({"FileWrite.StringLiteral.stdout",
"NonLocalVariableRead.InstanceVariable.A.name"}), # this mistake is acceptable due to the restrictions we made
"b_fun.line12": SimpleImpure({"FileRead.StringLiteral.stdin",
"NonLocalVariableRead.ClassVariable.B.upper_class"}), # this mistake is acceptable due to the restrictions we made
"f.line16": SimpleImpure({"FileWrite.StringLiteral.stdout",
"NonLocalVariableRead.ClassVariable.B.upper_class", # this mistake is acceptable due to the restrictions we made
"NonLocalVariableRead.InstanceVariable.A.name"}), # this mistake is acceptable due to the restrictions we made
"g.line20": SimpleImpure({"FileRead.StringLiteral.stdin",
"NonLocalVariableRead.ClassVariable.B.upper_class", # this mistake is acceptable due to the restrictions we made
"NonLocalVariableRead.InstanceVariable.A.name"}), # this mistake is acceptable due to the restrictions we made
},
),
( # language=Python "Function call of functions with same name and different purity"
Expand Down Expand Up @@ -1178,7 +1199,7 @@ def try_except(num1):
"Try Except",
],
)
@pytest.mark.xfail(reason="Some cases disabled for merging")
# @pytest.mark.xfail(reason="Some cases disabled for merging")
def test_infer_purity_impure(code: str, expected: dict[str, SimpleImpure]) -> None:
purity_results = infer_purity(code)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def set_state(node, state):
ReferenceTestNode(
"node.state.line10",
"FunctionDef.set_state",
["ClassVariable.A.state.line3", "ClassVariable.C.state.line6"],
["ClassVariable.C.state.line6"],
),
ReferenceTestNode("node.line10", "FunctionDef.set_state", ["Parameter.node.line9"]),
],
Expand All @@ -1029,8 +1029,7 @@ def set_state(cls, state):
ReferenceTestNode(
"cls.state.line10",
"FunctionDef.set_state",
["ClassVariable.A.state.line3", "ClassVariable.C.state.line6"],
# TODO: [LATER] A.state should be removed!
["ClassVariable.C.state.line6"],
),
ReferenceTestNode("cls.line10", "FunctionDef.set_state", ["Parameter.cls.line9"]),
],
Expand Down

0 comments on commit 37d07e0

Please sign in to comment.