Skip to content

Commit

Permalink
fix(utils): improve module name inspection methods
Browse files Browse the repository at this point in the history
  • Loading branch information
entelecheia committed Aug 18, 2023
1 parent 5702d87 commit a638659
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions src/hyfi/utils/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
import sys
from pathlib import Path
from typing import Any
from typing import Any, List, Optional

from hyfi.utils.iolibs import IOLIBs
from hyfi.utils.logging import LOGGING
Expand Down Expand Up @@ -204,21 +204,36 @@ def viewsource(obj: str) -> None:
print(PKGs.getsource(obj))

@staticmethod
def get_caller_module_name(caller_stack_depth: int = 2) -> str:
def get_module_name_stack() -> List[str]:
"""Get the name of the module that called this function."""
try:
_stack = inspect.stack()
if len(_stack) < caller_stack_depth + 1:
logger.info(
"Returning top level module name (depth %d)", len(_stack) - 1
)
return inspect.getmodule(_stack[-1][0]).__name__ # type: ignore
return inspect.getmodule(_stack[caller_stack_depth][0]).__name__ # type: ignore
return [
getattr(inspect.getmodule(_stack[i][0]), "__name__", "")
for i in range(1, len(_stack))
]
except Exception as e:
logger.error(
f"Error getting caller module name at depth {caller_stack_depth}: {e}"
)
return ""
logger.error(f"Error getting module name stack: {e}")
return []

@staticmethod
def get_caller_module_name(caller_stack_depth: int = 2) -> str:
"""Get the name of the module that called this function."""
_stack = PKGs.get_module_name_stack()
if len(_stack) < caller_stack_depth + 1:
logger.info("Returning top level module name (depth %d)", len(_stack) - 1)
return _stack[-1]
return _stack[caller_stack_depth]

@staticmethod
def get_next_level_caller_package_name() -> Optional[str]:
"""Get the name of the package that called this function."""
_stack = PKGs.get_module_name_stack()
package_name = _stack[0].split(".")[0]
for name in _stack:
name = name.split(".")[0]
if name != package_name:
return name

@staticmethod
def is_importable(module_name: str) -> bool:
Expand Down

0 comments on commit a638659

Please sign in to comment.