Skip to content

Commit

Permalink
Support for relative imports to reuse step impls
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Vishwasrao <kunal.vishwasrao@gmail.com>
  • Loading branch information
kunalvishwasrao authored Jul 2, 2024
1 parent 105f304 commit 32e9662
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
23 changes: 22 additions & 1 deletion getgauge/impl_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import traceback
from os import path
from contextlib import contextmanager

from getgauge import logger
from getgauge.registry import registry
Expand All @@ -16,6 +17,7 @@
env_dir = os.path.join(project_root, 'env', 'default')
requirements_file = os.path.join(project_root, 'requirements.txt')
sys.path.append(project_root)
temporary_sys_path = []
PLUGIN_JSON = 'python.json'
VERSION = 'version'
PYTHON_PROPERTIES = 'python.properties'
Expand All @@ -30,6 +32,12 @@ def load_impls(step_impl_dirs=impl_dirs):
logger.error('Make sure `STEP_IMPL_DIR` env var is set to a valid directory path.')
return
base_dir = project_root if impl_dir.startswith(project_root) else os.path.dirname(impl_dir)
# Handle multi-level relative imports
for _ in range(impl_dir.count('..')):
base_dir = os.path.dirname(base_dir).replace("/", os.path.sep).replace("\\", os.path.sep)
# Add temporary sys path for relative imports that is not already added
if '..' in impl_dir and base_dir not in temporary_sys_path:
temporary_sys_path.append(base_dir)
_import_impl(base_dir, impl_dir)


Expand Down Expand Up @@ -57,12 +65,25 @@ def _import_impl(base_dir, step_impl_dir):
elif path.isdir(file_path):
_import_impl(base_dir, file_path)

@contextmanager
def use_temporary_sys_path():
original_sys_path = sys.path[:]
sys.path.extend(temporary_sys_path)
try:
yield
finally:
sys.path = original_sys_path

def _import_file(base_dir, file_path):
rel_path = os.path.normpath(file_path.replace(base_dir + os.path.sep, ''))
try:
module_name = os.path.splitext(rel_path.replace(os.path.sep, '.'))[0]
m = importlib.import_module(module_name)
# Use temporary sys path for relative imports
if '..' in file_path:
with use_temporary_sys_path():
m = importlib.import_module(module_name)
else:
m = importlib.import_module(module_name)
# Get all classes in the imported module
classes = inspect.getmembers(m, lambda member: inspect.isclass(member) and member.__module__ == module_name)
if len(classes) > 0:
Expand Down
2 changes: 1 addition & 1 deletion python.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"id": "python",
"version": "0.4.3",
"version": "0.4.4",
"description": "Python support for gauge",
"run": {
"windows": [
Expand Down

0 comments on commit 32e9662

Please sign in to comment.