diff --git a/tools/run_all_tests.py b/tools/run_all_tests.py index fdf503d9da..51d775adbc 100644 --- a/tools/run_all_tests.py +++ b/tools/run_all_tests.py @@ -70,6 +70,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--discover_only", action="store_true", help="Only discover and print tests, don't run them.") parser.add_argument("--quiet", action="store_true", help="Don't print to console, only log to file.") parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="Timeout for each test in seconds.") + parser.add_argument("--extension", type=str, default=None, help="Run tests only for the given extension.") # parse arguments args = parser.parse_args() return args @@ -83,6 +84,7 @@ def test_all( per_test_timeouts: dict[str, float] = {}, discover_only: bool = False, quiet: bool = False, + extension: str | None = None, ) -> bool: """Run all tests under the given directory. @@ -96,7 +98,8 @@ def test_all( discover_only: If True, only discover and print the tests without running them. Defaults to False. quiet: If False, print the output of the tests to the terminal console (in addition to the log file). Defaults to False. - + extension: Run tests only for the given extension. Defaults to None, which means all extensions' + tests will be run. Returns: True if all un-skipped tests pass or `discover_only` is True. Otherwise, False. @@ -126,6 +129,23 @@ def test_all( break else: raise ValueError(f"Test to skip '{test_to_skip}' not found in tests.") + + # Filter tests by extension + if extension is not None: + all_tests_in_selected_extension = [] + + for test_path in all_test_paths: + # Extract extension name from test path + extension_name = test_path[test_path.find("extensions") :].split("/")[1] + + # Skip tests that are not in the selected extension + if extension_name != extension: + continue + + all_tests_in_selected_extension.append(test_path) + + all_test_paths = all_tests_in_selected_extension + # Remove tests to skip from the list of tests to run if len(tests_to_skip) != 0: for test_path in all_test_paths: @@ -303,6 +323,7 @@ def test_all( per_test_timeouts=PER_TEST_TIMEOUTS, discover_only=args.discover_only, quiet=args.quiet, + extension=args.extension, ) # update exit status based on all tests passing or not if not test_success: