-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_tests_and_install.py
executable file
·157 lines (130 loc) · 5.31 KB
/
run_tests_and_install.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python
"""
Custom runner script to run TorchDevice tests, build, and install.
Usage examples:
python run_tests_and_install.py # Run all tests, build, and install
python run_tests_and_install.py --test-only # Run tests only
python run_tests_and_install.py --update-expected # Update expected output files
python run_tests_and_install.py tests/some_test.py # Run specific test file(s)
"""
import os
import sys
import argparse
import subprocess
import time
import logging
from pathlib import Path
# Configure basic logging for the runner: messages at INFO level and above
# are emitted through the root logger's default handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project root: the absolute path of the directory that contains this script
# (run_tests_and_install.py).
PROJECT_ROOT = Path(__file__).parent.absolute()

# Put the tests directory at the front of sys.path (only if it is not already
# listed) so packages that live directly under tests/ can be imported by this
# process. Per the original note this is where the shared "common" package
# lives -- TODO confirm against the repository layout.
tests_dir = PROJECT_ROOT / "tests"
if str(tests_dir) not in sys.path:
    sys.path.insert(0, str(tests_dir))
def discover_test_files(test_root: Path) -> list:
    """Recursively collect the test scripts below ``test_root``.

    Args:
        test_root: Directory to search.

    Returns:
        A sorted list of ``Path`` objects, one per regular file whose name
        matches ``test_*.py`` anywhere under ``test_root``. The list is empty
        when nothing matches.
    """
    # rglob yields entries in filesystem order, which varies between machines
    # and runs; sorting makes the test execution order reproducible.
    # Directories whose name happens to match the pattern are skipped because
    # they cannot be executed as test scripts.
    return sorted(
        candidate
        for candidate in test_root.rglob("test_*.py")
        if candidate.is_file()
    )
def run_test_file(test_file: Path, update_expected: bool) -> bool:
    """Run one test file in a child Python interpreter.

    The child inherits a copy of the current environment with ``PYTHONPATH``
    rebuilt as: the project's ``tests`` directory, then any inherited
    ``PYTHONPATH`` value (kept verbatim), then the project root.

    Args:
        test_file: Path to the test file.
        update_expected: When true, ``--update-expected`` is appended to the
            child's command line so the test file sees it in ``sys.argv``.

    Returns:
        True if the child exited with status 0, False otherwise.
    """
    env = os.environ.copy()

    # Build PYTHONPATH from explicit entries only. Unconditionally joining the
    # inherited value used to leave a trailing separator when PYTHONPATH was
    # unset, i.e. an empty entry that Python resolves against the current
    # working directory, so imports in the child depended on where the runner
    # was launched from. The project root is listed explicitly instead, and
    # last, so it never outranks the tests directory or user-supplied entries.
    path_entries = [str(PROJECT_ROOT / "tests")]
    inherited_pythonpath = env.get("PYTHONPATH", "")
    if inherited_pythonpath:
        path_entries.append(inherited_pythonpath)
    path_entries.append(str(PROJECT_ROOT))
    env["PYTHONPATH"] = os.pathsep.join(path_entries)

    # Build the command; the flag is forwarded so the test file's own
    # sys.argv contains it.
    cmd = [sys.executable, str(test_file)]
    if update_expected:
        cmd.append("--update-expected")

    logger.info(f"Running test file: {test_file}")
    result = subprocess.run(cmd, env=env)

    passed = result.returncode == 0
    if passed:
        logger.info(f"Test file passed: {test_file}")
    else:
        logger.error(f"Test file failed: {test_file}")
    return passed
def build_package() -> bool:
    """Build the TorchDevice package by running ``setup.py build``.

    The command runs with the project root as its working directory; the
    working directory of the calling process is left untouched.

    Returns:
        True if the build command exited with status 0, False otherwise.
    """
    logger.info("Building TorchDevice package...")
    build_cmd = [sys.executable, 'setup.py', 'build']
    # Pass cwd to the child instead of calling os.chdir(), which would
    # permanently change the working directory of this process.
    process = subprocess.run(
        build_cmd,
        cwd=PROJECT_ROOT,
        capture_output=True,
        text=True,
    )
    if process.returncode != 0:
        # Some tools report failures on stdout; fall back to it so the log
        # is never empty when stderr has nothing to show.
        details = process.stderr or process.stdout
        logger.error(f"❌ Package build failed: {details}")
        return False

    logger.info("✅ Package built successfully!")
    return True
def install_package() -> bool:
    """Install the TorchDevice package in development (editable) mode.

    Runs ``pip install -e .`` through the current interpreter with the
    project root as the command's working directory; the working directory
    of the calling process is left untouched.

    Returns:
        True if pip exited with status 0, False otherwise.
    """
    logger.info("Installing TorchDevice package...")
    install_cmd = [sys.executable, '-m', 'pip', 'install', '-e', '.']
    # Pass cwd to the child instead of calling os.chdir(), which would
    # permanently change the working directory of this process.
    process = subprocess.run(
        install_cmd,
        cwd=PROJECT_ROOT,
        capture_output=True,
        text=True,
    )
    if process.returncode != 0:
        # Fall back to stdout so the log is never empty when the failure
        # details were not written to stderr.
        details = process.stderr or process.stdout
        logger.error(f"❌ Package installation failed: {details}")
        return False

    logger.info("✅ Package installed successfully!")
    return True
def main():
    """Parse the command line, run the selected tests, then build and install.

    Test paths given on the command line are interpreted relative to the
    project root; directories are searched recursively for test files. With
    no paths, the whole ``tests`` directory is used.

    Returns:
        0 on success (including when no test files are found), 1 if any test
        file fails or if the build or installation step fails.
    """
    cli = argparse.ArgumentParser(description="Run TorchDevice tests, build, and install")
    cli.add_argument('--test-only', action='store_true', help="Run tests only")
    cli.add_argument('--update-expected', action='store_true', help="Update expected output files")
    cli.add_argument('test_paths', nargs='*', help="Test files or directories to run")
    options = cli.parse_args()

    # Reset sys.argv to just the script name now that the runner's own
    # options have been parsed.
    sys.argv = [sys.argv[0]]

    # Make the project root importable from this process.
    root_entry = str(PROJECT_ROOT)
    if root_entry not in sys.path:
        sys.path.insert(0, root_entry)

    started_at = time.time()

    # Work out which test files to run.
    if options.test_paths:
        test_files = []
        for raw_path in options.test_paths:
            candidate = PROJECT_ROOT / raw_path
            if candidate.is_dir():
                test_files.extend(discover_test_files(candidate))
            else:
                test_files.append(candidate)
    else:
        test_files = discover_test_files(PROJECT_ROOT / "tests")

    if not test_files:
        logger.warning("No test files found.")
        return 0

    # Collect into a list (not a lazy generator) so every test file runs
    # even after an earlier one has failed.
    outcomes = [
        run_test_file(test_file, options.update_expected)
        for test_file in test_files
    ]

    if not all(outcomes):
        logger.error("Some tests failed. Skipping build and install.")
        return 1

    if options.test_only:
        logger.info("Test-only mode: Skipping build and install.")
    elif not build_package():
        logger.error("Package build failed.")
        return 1
    elif not install_package():
        logger.error("Package installation failed.")
        return 1

    logger.info(f"Process completed in {time.time() - started_at:.2f} seconds")
    return 0
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    raise SystemExit(main())