diff --git a/nest/core/decorators/utils.py b/nest/core/decorators/utils.py index 69fda4b..58a2209 100644 --- a/nest/core/decorators/utils.py +++ b/nest/core/decorators/utils.py @@ -1,12 +1,30 @@ import ast import inspect -from typing import Callable, List +from typing import Callable, List, Dict, Type, Iterable import click from nest.common.constants import INJECTABLE_TOKEN +def _is_valid_instance_variable(target): + """ + Checks if the target is a valid instance variable. + + Args: + target: The AST target node to check. + dependencies: A list of attribute names to exclude. + + Returns: + bool: True if the target is a valid instance variable, False otherwise. + """ + return ( + isinstance(target, ast.Attribute) + and isinstance(target.value, ast.Name) + and target.value.id == "self" + ) + + def get_instance_variables(cls): """ Retrieves instance variables assigned in the __init__ method of a class, @@ -23,67 +41,68 @@ def get_instance_variables(cls): tree = ast.parse(source) # Getting the parameter names to exclude dependencies - dependencies = set( - param.name - for param in inspect.signature(cls.__init__).parameters.values() - if param.annotation != param.empty - and getattr(param.annotation, "__injectable__", False) - ) - - instance_vars = {} - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - for target in node.targets: - if ( - isinstance(target, ast.Attribute) - and isinstance(target.value, ast.Name) - and target.value.id == "self" - ): - # Exclude dependencies - if target.attr not in dependencies: - # Here you can either store the source code of the value or - # evaluate it in the class' context, depending on your needs - instance_vars[target.attr] = ast.get_source_segment( - source, node.value - ) - return instance_vars - except Exception as e: + dependencies: Dict[str, ...] = parse_dependencies(cls, check_parent=True) + assign_nodes = filter(lambda node: isinstance(node, ast.Assign), ast.walk(tree)) + return { + # Here you can either store the source code of the value or + # evaluate it in the class' context, depending on your needs + target.attr: ast.get_source_segment(source, node.value) + for node in assign_nodes + for target in node.targets + if _is_valid_instance_variable(target) + and target.attr not in dependencies # Exclude dependencies + } + except Exception: return {} -def get_non_dependencies_params(cls): +def get_non_dependencies_params(cls: Type): source = inspect.getsource(cls.__init__).strip() tree = ast.parse(source) - non_dependencies = {} - for node in ast.walk(tree): - if isinstance(node, ast.Attribute): - non_dependencies[node.attr] = node.value.id - return non_dependencies + return { + node.attr: node.value.id + for node in ast.walk(tree) + if isinstance(node, ast.Attribute) + } + +def _check_injectable_not_inherited(param: inspect.Parameter) -> bool: + return ( + param.annotation != param.empty + and hasattr(param.annotation, "__dict__") + and INJECTABLE_TOKEN in param.annotation.__dict__ + ) -def parse_dependencies(cls): + +def _check_injectable_inherited(param: inspect.Parameter) -> bool: + return param.annotation != param.empty and getattr( + param.annotation, INJECTABLE_TOKEN, False + ) + + +def parse_dependencies(cls: Type, check_inherited: bool = False) -> Dict[str, Type]: + """ + Returns: + mapping of injectable parameters name to there annotation + """ signature = inspect.signature(cls.__init__) - dependecies = {} - for param in signature.parameters.values(): - try: - if ( - param.annotation != param.empty - and hasattr(param.annotation, "__dict__") - and INJECTABLE_TOKEN in param.annotation.__dict__ - ): - dependecies[param.name] = param.annotation - except Exception as e: - raise e - return dependecies + filter_by = ( + _check_injectable_inherited + if check_inherited + else _check_injectable_not_inherited + ) + params: Iterable[inspect.Parameter] = filter(filter_by, signature.parameters.values()) + return {param.name: param.annotation for param in params} def parse_params(func: Callable) -> List[click.Option]: + """ + Returns: + all parameters with annotation + """ signature = inspect.signature(func) - params = [] - for param in signature.parameters.values(): - try: - if param.annotation != param.empty: - params.append(param.annotation) - except Exception as e: - raise e - return params + return [ + param.annotation + for param in signature.parameters.values() + if param.annotation != param.empty + ]