Skip to content

Commit

Permalink
Fix tests with node-type key error problems.
Browse files Browse the repository at this point in the history
  • Loading branch information
aravij committed Jul 13, 2020
1 parent d78d424 commit fac167b
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion aibolit/ast_framework/_auxiliary_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class ASTNodeReference(NamedTuple):
tree.WhileStatement: ASTNodeType.WHILE_STATEMENT,
}

common_attributes: Set[str] = {'type', 'line'}
common_attributes: Set[str] = {'node_type', 'line'}

attributes_by_node_type: Dict[ASTNodeType, Set[str]] = {
ASTNodeType.ANNOTATION_DECLARATION: {
Expand Down
8 changes: 4 additions & 4 deletions aibolit/ast_framework/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def get_binary_operation_name(self, node: int) -> str:

def get_line_number_from_children(self, node: int) -> int:
for child in self.tree.succ[node]:
cur_line = self.get_attr(child, 'line', -1)
if cur_line >= 0:
cur_line = self.get_attr(child, 'line')
if cur_line is not None:
return cur_line
return 0

Expand Down Expand Up @@ -241,7 +241,7 @@ def _add_javalang_standard_node(tree: DiGraph, javalang_node: Node) -> Tuple[int
@staticmethod
def _add_javalang_collection_node(tree: DiGraph, collection_node: Set[Any]) -> int:
node_index = len(tree) + 1
tree.add_node(node_index, type=ASTNodeType.COLLECTION)
tree.add_node(node_index, node_type=ASTNodeType.COLLECTION)
# we expect only strings in collection
# we add them here as children
for item in collection_node:
Expand All @@ -256,7 +256,7 @@ def _add_javalang_collection_node(tree: DiGraph, collection_node: Set[Any]) -> i
@staticmethod
def _add_javalang_string_node(tree: DiGraph, string_node: str) -> int:
node_index = len(tree) + 1
tree.add_node(node_index, type=ASTNodeType.STRING, string=string_node)
tree.add_node(node_index, node_type=ASTNodeType.STRING, string=string_node)
return node_index

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions aibolit/ast_framework/ast_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, graph: DiGraph, node_index: int):
self._node_index = node_index

def __dir__(self) -> List[str]:
node_type = self._graph.nodes[self._node_index]['type']
node_type = self._graph.nodes[self._node_index]['node_type']
return ['children'] + list(common_attributes) + list(attributes_by_node_type[node_type])

@property
Expand All @@ -43,7 +43,7 @@ def children(self) -> Iterator['ASTNode']:

def __getattr__(self, attribute_name: str):
if attribute_name not in common_attributes:
node_type = self._graph.nodes[self._node_index]['type']
node_type = self._graph.nodes[self._node_index]['node_type']
if(attribute_name not in attributes_by_node_type[node_type]):
raise AttributeError(f'{node_type} node does not have "{attribute_name}" attribute.')

Expand All @@ -57,7 +57,7 @@ def __getattr__(self, attribute_name: str):

def __str__(self) -> str:
text_representation = f'node index: {self._node_index}'
node_type = self.__getattr__('type')
node_type = self.__getattr__('node_type')
for attribute_name in sorted(common_attributes | attributes_by_node_type[node_type]):
text_representation += f'\n{attribute_name}: {self.__getattr__(attribute_name)}'

Expand Down
8 changes: 4 additions & 4 deletions aibolit/patterns/var_middle/var_middle.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def value(self, filename):
@staticmethod
def _on_entering_node(node: int, ast: DiGraph, scope_status: ScopeStatus,
lines_with_error: List[LineNumber]):
node_type = ast.nodes[node]['type']
node_type = ast.nodes[node]['node_type']

# if the variable is declared mark it and check the scope
if node_type in VarMiddle._var_declaration_node_types:
Expand All @@ -60,7 +60,7 @@ def _on_entering_node(node: int, ast: DiGraph, scope_status: ScopeStatus,

# mark scope for super constructor calling
elif node_type == ASTNodeType.STATEMENT_EXPRESSION:
children_types = {ast.nodes[child]['type'] for child in ast.succ[node]}
children_types = {ast.nodes[child]['node_type'] for child in ast.succ[node]}
if ASTNodeType.SUPER_CONSTRUCTOR_INVOCATION in children_types:
scope_status.add_flag(ScopeStatusFlags.INSIDE_CALLING_SUPER_CLASS_CONSTRUCTOR_SUBTREE)

Expand All @@ -80,15 +80,15 @@ def _on_entering_node(node: int, ast: DiGraph, scope_status: ScopeStatus,

@staticmethod
def _on_leaving_node(node: int, ast: DiGraph, scope_status: ScopeStatus):
node_type = ast.nodes[node]['type']
node_type = ast.nodes[node]['node_type']

# on the end of variable declaration remove according flag
if node_type in VarMiddle._var_declaration_node_types:
scope_status.remove_flag(ScopeStatusFlags.INSIDE_VARIABLE_DECLARATION_SUBTREE)

# on the end of super constructor call remove according flag
elif node_type == ASTNodeType.STATEMENT_EXPRESSION:
children_types = {ast.nodes[child]['type'] for child in ast.succ[node]}
children_types = {ast.nodes[child]['node_type'] for child in ast.succ[node]}
if ASTNodeType.SUPER_CONSTRUCTOR_INVOCATION in children_types:
scope_status.remove_flag(ScopeStatusFlags.INSIDE_CALLING_SUPER_CLASS_CONSTRUCTOR_SUBTREE)

Expand Down

0 comments on commit fac167b

Please sign in to comment.